Commit 0692b905 authored by rusty1s's avatar rusty1s
Browse files

fixes

parent e696cfd6
......@@ -38,8 +38,7 @@ def cat(tensors, dim):
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None,
is_sorted=True)
rowptr=torch.cat(rowptrs) if has_rowptr else None, is_sorted=True)
if dim == 1:
raise NotImplementedError
......
......@@ -61,8 +61,8 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
deterministic = src.storage._csr2csc is not None or deterministic
if sparse_dims[0] == 0 and deterministic and src.has_value():
csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
out = segment_csr(value[csr2csc], colptr)
csr2csc = src.storage.csr2csc
out = segment_csr(value[csr2csc], src.storage.colptr)
out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
return out
......
......@@ -81,7 +81,7 @@ class SparseTensor(object):
return self.storage.rowptr, self.storage.col, self.storage.value
def csc(self):
perm = self.storage.csr2csc
perm = self.storage.csr2csc # Compute `csr2csc` first.
return (self.storage.colptr, self.storage.row[perm],
self.storage.value[perm] if self.has_value() else None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment