Commit 436b2e50 authored by rusty1s's avatar rusty1s
Browse files

more memory efficient

parent 906c97e4
...@@ -19,3 +19,21 @@ def test_getitem(dtype, device): ...@@ -19,3 +19,21 @@ def test_getitem(dtype, device):
assert mat[..., :10].sizes() == [50, 10] assert mat[..., :10].sizes() == [50, 10]
assert mat[idx1, idx2].sizes() == [10, 10] assert mat[idx1, idx2].sizes() == [10, 10]
assert mat[idx1.tolist()].sizes() == [10, 40] assert mat[idx1.tolist()].sizes() == [10, 40]
@pytest.mark.parametrize('device', devices)
def test_to_symmetric(device):
row = torch.tensor([0, 0, 0, 1, 1], device=device)
col = torch.tensor([0, 1, 2, 0, 2], device=device)
value = torch.arange(1, 6, device=device)
mat = SparseTensor(row=row, col=col, value=value)
assert not mat.is_symmetric()
mat = mat.to_symmetric()
assert mat.is_symmetric()
assert mat.to_dense().tolist() == [
[2, 6, 3],
[6, 0, 5],
[3, 5, 0],
]
...@@ -382,7 +382,6 @@ class SparseStorage(object): ...@@ -382,7 +382,6 @@ class SparseStorage(object):
ptr = mask.nonzero().flatten() ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))]) ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value, return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None, sparse_sizes=self._sparse_sizes, rowcount=None,
......
...@@ -3,6 +3,7 @@ from typing import Optional, List, Tuple, Dict, Union, Any ...@@ -3,6 +3,7 @@ from typing import Optional, List, Tuple, Dict, Union, Any
import torch import torch
import scipy.sparse import scipy.sparse
from torch_scatter import segment_csr
from torch_sparse.storage import SparseStorage, get_layout from torch_sparse.storage import SparseStorage, get_layout
...@@ -270,17 +271,33 @@ class SparseTensor(object): ...@@ -270,17 +271,33 @@ class SparseTensor(object):
return bool((value1 == value2).all()) return bool((value1 == value2).all())
def to_symmetric(self, reduce: str = "sum"): def to_symmetric(self, reduce: str = "sum"):
N = max(self.size(0), self.size(1))
row, col, value = self.coo() row, col, value = self.coo()
idx = col.new_full((2 * col.numel() + 1, ), -1)
idx[1:row.numel() + 1] = row
idx[row.numel() + 1:] = col
idx[1:] *= N
idx[1:row.numel() + 1] += col
idx[row.numel() + 1:] += row
idx, perm = idx.sort()
perm = perm[1:].sub_(1)
mask = idx[1:] > idx[:-1]
idx2 = perm[mask]
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
if value is not None: if value is not None:
value = torch.cat([value, value], dim=0) ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
value = torch.cat([value, value])[perm]
value = segment_csr(value, ptr, reduce=reduce)
N = max(self.size(0), self.size(1)) new_row = torch.cat([row, col], dim=0, out=perm)[idx2]
new_col = torch.cat([col, row], dim=0, out=perm)[idx2]
out = SparseTensor(row=row, rowptr=None, col=col, value=value, out = SparseTensor(row=new_row, rowptr=None, col=new_col, value=value,
sparse_sizes=(N, N), is_sorted=False) sparse_sizes=(N, N), is_sorted=True)
out = out.coalesce(reduce)
return out return out
def detach_(self): def detach_(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment