Commit cdd1cb4d authored by rusty1s's avatar rusty1s
Browse files

set_value takes scalars

parent d742afd9
......@@ -158,18 +158,22 @@ class SparseStorage(object):
return self._value
def set_value_(self, value, layout=None):
assert value.device == self._index.device
assert value.size(0) == self._index.size(1)
if value is not None and get_layout(layout) == 'csc':
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
assert value.device == self.index.device
assert value.size(0) == self.index.size(1)
self._value = value
return self
def set_value(self, value, layout=None):
if isinstance(value, int) or isinstance(value, float):
value = torch.full((self.nnz(), ), device=self.index.device)
elif torch.is_tensor(value) and get_layout(layout) == 'csc':
value = value[self.csc2csr]
assert value.device == self._index.device
assert value.size(0) == self._index.size(1)
if value is not None and get_layout(layout) == 'csc':
value = value[self.csc2csr]
return self.__class__(
self._index,
value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment