Commit 77bb0595 authored by rusty1s's avatar rusty1s
Browse files

added todos

parent 6613b175
import torch
# import torch
from torch_scatter import scatter_add
# from torch_sparse.tensor import SparseTensor
......@@ -17,7 +17,8 @@ from torch_scatter import scatter_add
# if reduce in ['add', 'mean']:
# return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
# else:
# return torch_sparse.spmm_cuda.spmm_arg(rowptr, col, value, mat, reduce)
# return torch_sparse.spmm_cuda.spmm_arg(
# rowptr, col, value, mat, reduce)
def spmm(index, value, m, n, matrix):
......
import warnings
import torch
import torch_scatter
from torch_scatter import scatter_add, segment_add
from torch_scatter import segment_csr
__cache__ = {'enabled': True}
......@@ -66,9 +65,17 @@ class SparseStorage(object):
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
]
def __init__(self, index, value=None, sparse_size=None, rowcount=None,
rowptr=None, colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=False):
def __init__(self,
index,
value=None,
sparse_size=None,
rowcount=None,
rowptr=None,
colcount=None,
colptr=None,
csr2csc=None,
csc2csr=None,
is_sorted=False):
assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2
......@@ -189,11 +196,13 @@ class SparseStorage(object):
@cached_property
def rowcount(self):
# TODO
one = torch.ones_like(self.row)
return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])
@cached_property
def rowptr(self):
# TODO
rowcount = self.rowcount
rowptr = rowcount.new_zeros(rowcount.numel() + 1)
torch.cumsum(rowcount, dim=0, out=rowptr[1:])
......@@ -201,11 +210,13 @@ class SparseStorage(object):
@cached_property
def colcount(self):
# TODO
one = torch.ones_like(self.col)
return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])
@cached_property
def colptr(self):
# TODO
colcount = self.colcount
colptr = colcount.new_zeros(colcount.numel() + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
......@@ -236,10 +247,9 @@ class SparseStorage(object):
value = self.value
if self.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
idx = mask.cumsum(0) - 1
op = getattr(torch_scatter, f'scatter_{reduce}')
value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1)
dim_size = idx[-1].item() + 1
value = segment_csr(idx, value, dim_size=dim_size, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return self.__class__(index, value, self.sparse_size(), is_sorted=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment