Commit 34b25b3c authored by rusty1s's avatar rusty1s
Browse files

fixes

parent 4a68dd60
...@@ -71,10 +71,13 @@ def test_jit(): ...@@ -71,10 +71,13 @@ def test_jit():
# scipy = adj.to_scipy(layout='csr') # scipy = adj.to_scipy(layout='csr')
# mat = SparseTensor.from_scipy(scipy) # mat = SparseTensor.from_scipy(scipy)
print() print()
print(adj)
# adj = t(adj) # adj = t(adj)
adj = adj.t() adj = adj.t()
adj = adj.remove_diag(k=0)
print(adj.to_dense())
adj = adj + torch.tensor([1, 2, 3]).view(1, 3)
print(adj) print(adj)
print(adj.to_dense())
# print(adj.t) # print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col} # adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
......
...@@ -20,7 +20,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -20,7 +20,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f'{other.size()}.') f'{other.size()}.')
if value is not None: if value is not None:
value = other.add_(value) value = other.to(value.dtype).add_(value)
else: else:
value = other.add_(1) value = other.add_(1)
return src.set_value(value, layout='coo') return src.set_value(value, layout='coo')
...@@ -41,7 +41,7 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -41,7 +41,7 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f'{other.size()}.') f'{other.size()}.')
if value is not None: if value is not None:
value = value.add_(other) value = value.add_(other.to(value.dtype))
else: else:
value = other.add_(1) value = other.add_(1)
return src.set_value_(value, layout='coo') return src.set_value_(value, layout='coo')
...@@ -52,7 +52,7 @@ def add_nnz(src: SparseTensor, other: torch.Tensor, ...@@ -52,7 +52,7 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
if value is not None: if value is not None:
value = value.add(other) value = value.add(other.to(value.dtype))
else: else:
value = other.add(1) value = other.add(1)
return src.set_value(value, layout=layout) return src.set_value(value, layout=layout)
...@@ -63,7 +63,7 @@ def add_nnz_(src: SparseTensor, other: torch.Tensor, ...@@ -63,7 +63,7 @@ def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
if value is not None: if value is not None:
value = value.add_(other) value = value.add_(other.to(value.dtype))
else: else:
value = other.add(1) value = other.add(1)
return src.set_value_(value, layout=layout) return src.set_value_(value, layout=layout)
...@@ -75,3 +75,6 @@ SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz( ...@@ -75,3 +75,6 @@ SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
self, other, layout) self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_( SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
self, other, layout) self, other, layout)
SparseTensor.__add__ = SparseTensor.add
SparseTensor.__radd__ = SparseTensor.add
SparseTensor.__iadd__ = SparseTensor.add_
...@@ -20,7 +20,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -20,7 +20,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f'{other.size()}.') f'{other.size()}.')
if value is not None: if value is not None:
value = other.mul_(value) value = other.to(value.dtype).mul_(value)
else: else:
value = other value = other
return src.set_value(value, layout='coo') return src.set_value(value, layout='coo')
...@@ -41,7 +41,7 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -41,7 +41,7 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f'{other.size()}.') f'{other.size()}.')
if value is not None: if value is not None:
value = value.mul_(other) value = value.mul_(other.to(value.dtype))
else: else:
value = other value = other
return src.set_value_(value, layout='coo') return src.set_value_(value, layout='coo')
...@@ -52,7 +52,7 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor, ...@@ -52,7 +52,7 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
if value is not None: if value is not None:
value = value.mul(other) value = value.mul(other.to(value.dtype))
else: else:
value = other value = other
return src.set_value(value, layout=layout) return src.set_value(value, layout=layout)
...@@ -63,7 +63,7 @@ def mul_nnz_(src: SparseTensor, other: torch.Tensor, ...@@ -63,7 +63,7 @@ def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
if value is not None: if value is not None:
value = value.mul_(other) value = value.mul_(other.to(value.dtype))
else: else:
value = other value = other
return src.set_value_(value, layout=layout) return src.set_value_(value, layout=layout)
...@@ -75,3 +75,6 @@ SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz( ...@@ -75,3 +75,6 @@ SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
self, other, layout) self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_( SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
self, other, layout) self, other, layout)
SparseTensor.__mul__ = SparseTensor.mul
SparseTensor.__rmul__ = SparseTensor.mul
SparseTensor.__imul__ = SparseTensor.mul_
...@@ -345,7 +345,10 @@ class SparseTensor(object): ...@@ -345,7 +345,10 @@ class SparseTensor(object):
def to_dense(self, options: Optional[torch.Tensor] = None): def to_dense(self, options: Optional[torch.Tensor] = None):
row, col, value = self.coo() row, col, value = self.coo()
if options is not None: if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype,
device=self.device())
elif options is not None:
mat = torch.zeros(self.sizes(), dtype=options.dtype, mat = torch.zeros(self.sizes(), dtype=options.dtype,
device=self.device()) device=self.device())
else: else:
...@@ -373,24 +376,6 @@ class SparseTensor(object): ...@@ -373,24 +376,6 @@ class SparseTensor(object):
# Standard Operators ###################################################### # Standard Operators ######################################################
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
# def __imul__(self, other):
# return self.mul_(other)
# def __matmul__(self, other): # def __matmul__(self, other):
# return matmul(self, other, reduce='sum') # return matmul(self, other, reduce='sum')
...@@ -400,8 +385,6 @@ class SparseTensor(object): ...@@ -400,8 +385,6 @@ class SparseTensor(object):
# SparseTensor.mean = torch_sparse.reduce.mean # SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min # SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max # SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul # SparseTensor.matmul = matmul
# Python Bindings ############################################################# # Python Bindings #############################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment