Commit 335dfed0 authored by rusty1s

bugfixes

parent 6a7f10e5
@@ -175,8 +175,9 @@ class SparseStorage(object):
             value = torch.full((self.nnz(), ), device=self.index.device)
         elif torch.is_tensor(value) and get_layout(layout) == 'csc':
             value = value[self.csc2csr]
-        assert value.device == self._index.device
-        assert value.size(0) == self._index.size(1)
+        if torch.is_tensor(value):
+            assert value.device == self._index.device
+            assert value.size(0) == self._index.size(1)
         return self.__class__(
             self._index,
             value,
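This hunk guards the value checks so that `value=None` (the implicit all-ones value) no longer hits the device/size asserts. A minimal standalone sketch of the guarded check, with a hypothetical helper name `check_value` that is not part of the library:

    import torch

    def check_value(index, value):
        # Only validate when a value tensor is actually present; a `None`
        # value (implicit all-ones weights) must skip these checks.
        if torch.is_tensor(value):
            assert value.device == index.device
            assert value.size(0) == index.size(1)
        return value

    index = torch.tensor([[0, 1, 1], [1, 0, 2]])
    check_value(index, None)           # passes: nothing to validate
    check_value(index, torch.ones(3))  # passes: matching device and length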
@@ -274,27 +274,32 @@ class SparseTensor(object):
 
         return self.from_storage(storage)
 
     def to(self, *args, **kwargs):
-        storage = None
+        args = list(args)
+        non_blocking = getattr(kwargs, 'non_blocking', False)
+
+        storage = None
         if 'device' in kwargs:
             device = kwargs['device']
             del kwargs['device']
-            storage = self.storage.apply(lambda x: x.to(
-                device, non_blocking=getattr(kwargs, 'non_blocking', False)))
-
-        for arg in args[:]:
-            if isinstance(arg, str) or isinstance(arg, torch.device):
-                storage = self.storage.apply(lambda x: x.to(
-                    arg, non_blocking=getattr(kwargs, 'non_blocking', False)))
-                args.remove(arg)
+            storage = self.storage.apply(
+                lambda x: x.to(device, non_blocking=non_blocking))
+        else:
+            for arg in args[:]:
+                if isinstance(arg, str) or isinstance(arg, torch.device):
+                    storage = self.storage.apply(
+                        lambda x: x.to(arg, non_blocking=non_blocking))
+                    args.remove(arg)
 
-        if storage is not None:
-            self = self.from_storage(storage)
+        storage = self.storage if storage is None else storage
 
         if len(args) > 0 or len(kwargs) > 0:
-            self = self.type(*args, **kwargs)
+            storage = storage.apply_value(lambda x: x.type(*args, **kwargs))
 
-        return self
+        if storage == self.storage:  # Nothing changed...
+            return self
+        else:
+            return self.from_storage(storage)
 
     def bfloat16(self):
         return self.type(torch.bfloat16)
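The rewritten `to()` separates device moves from dtype conversions and only rebuilds the storage when something actually changed. Below is a standalone sketch of that dispatch pattern applied to a plain tensor; the helper name `to_like` is hypothetical, and it reads `non_blocking` via `kwargs.pop` for brevity, whereas the SparseTensor method routes the same steps through `storage.apply` and `storage.apply_value`:

    import torch

    def to_like(tensor, *args, **kwargs):
        args = list(args)
        non_blocking = kwargs.pop('non_blocking', False)

        out = tensor
        if 'device' in kwargs:
            # Explicit device keyword: move first.
            out = out.to(kwargs.pop('device'), non_blocking=non_blocking)
        else:
            # Peel device-like positional arguments ('cuda', torch.device(...)).
            for arg in args[:]:
                if isinstance(arg, (str, torch.device)):
                    out = out.to(arg, non_blocking=non_blocking)
                    args.remove(arg)

        # Whatever remains is treated as a dtype conversion.
        if len(args) > 0 or len(kwargs) > 0:
            out = out.type(*args, **kwargs)
        return out

    x = torch.ones(3)
    print(to_like(x, 'cpu', torch.float64).dtype)  # torch.float64
    print(to_like(x) is x)                         # True: nothing changed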
@@ -454,41 +459,6 @@ SparseTensor.matmul = matmul
 # SparseTensor.add = add
 # SparseTensor.add_nnz = add_nnz
 
-# def remove_diag(self):
-#     raise NotImplementedError
-
-# def set_diag(self, value):
-#     raise NotImplementedError
-
-# def __reduce(self, dim, reduce, only_nnz):
-#     raise NotImplementedError
-
-# def sum(self, dim):
-#     return self.__reduce(dim, reduce='add', only_nnz=True)
-
-# def prod(self, dim):
-#     return self.__reduce(dim, reduce='mul', only_nnz=True)
-
-# def min(self, dim, only_nnz=False):
-#     return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
-
-# def max(self, dim, only_nnz=False):
-#     return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
-
-# def mean(self, dim, only_nnz=False):
-#     return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
-
-# def matmul(self, mat, reduce='add'):
-#     assert self.numel() == self.nnz()  # Disallow multi-dimensional value
-#     if torch.is_tensor(mat):
-#         raise NotImplementedError
-#     elif isinstance(mat, self.__class__):
-#         assert reduce == 'add'
-#         assert mat.numel() == mat.nnz()  # Disallow multi-dimensional value
-#         raise NotImplementedError
-#     raise ValueError('Argument needs to be of type `torch.tensor` or '
-#                      'type `torch_sparse.SparseTensor`.')
-
 
 # def __add__(self, other):
 #     return self.add(other)