Commit e0c91397 authored by rusty1s's avatar rusty1s
Browse files

fix

parent 436b2e50
......@@ -282,10 +282,9 @@ class SparseTensor(object):
idx[row.numel() + 1:] += row
idx, perm = idx.sort()
perm = perm[1:].sub_(1)
mask = idx[1:] > idx[:-1]
idx2 = perm[mask]
perm = perm[1:].sub_(1)
idx = perm[mask]
if value is not None:
ptr = mask.nonzero().flatten()
......@@ -293,8 +292,8 @@ class SparseTensor(object):
value = torch.cat([value, value])[perm]
value = segment_csr(value, ptr, reduce=reduce)
new_row = torch.cat([row, col], dim=0, out=perm)[idx2]
new_col = torch.cat([col, row], dim=0, out=perm)[idx2]
new_row = torch.cat([row, col], dim=0, out=perm)[idx]
new_col = torch.cat([col, row], dim=0, out=perm)[idx]
out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
sparse_sizes=(N, N), is_sorted=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment