Commit 801970bf authored by rusty1s's avatar rusty1s
Browse files

add fix

parent 8481d0d6
......@@ -42,9 +42,9 @@ def add(src, other):
return add_nnz(src, other)
elif torch.is_tensor(other):
(row, col), value = src.coo()
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), src.storage.rowptr)
other = gather_csr(other.squeeze(1), rowptr)
value = other.add_(src.storage.value if src.has_value() else 1)
return src.set_value(value, layout='csr')
......@@ -69,9 +69,9 @@ def add_(src, other):
return add_nnz_(src, other)
elif torch.is_tensor(other):
(row, col), value = src.coo()
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), src.storage.rowptr)
other = gather_csr(other.squeeze(1), rowptr)
if src.has_value():
value = src.storage.value.add_(other)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment