Commit 77bb0595 authored by rusty1s's avatar rusty1s
Browse files

added todos

parent 6613b175
import torch # import torch
from torch_scatter import scatter_add from torch_scatter import scatter_add
# from torch_sparse.tensor import SparseTensor # from torch_sparse.tensor import SparseTensor
...@@ -17,7 +17,8 @@ from torch_scatter import scatter_add ...@@ -17,7 +17,8 @@ from torch_scatter import scatter_add
# if reduce in ['add', 'mean']: # if reduce in ['add', 'mean']:
# return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce) # return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
# else: # else:
# return torch_sparse.spmm_cuda.spmm_arg(rowptr, col, value, mat, reduce) # return torch_sparse.spmm_cuda.spmm_arg(
# rowptr, col, value, mat, reduce)
def spmm(index, value, m, n, matrix): def spmm(index, value, m, n, matrix):
......
import warnings import warnings
import torch import torch
import torch_scatter from torch_scatter import segment_csr
from torch_scatter import scatter_add, segment_add
__cache__ = {'enabled': True} __cache__ = {'enabled': True}
...@@ -66,9 +65,17 @@ class SparseStorage(object): ...@@ -66,9 +65,17 @@ class SparseStorage(object):
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr' 'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
] ]
def __init__(self, index, value=None, sparse_size=None, rowcount=None, def __init__(self,
rowptr=None, colcount=None, colptr=None, csr2csc=None, index,
csc2csr=None, is_sorted=False): value=None,
sparse_size=None,
rowcount=None,
rowptr=None,
colcount=None,
colptr=None,
csr2csc=None,
csc2csr=None,
is_sorted=False):
assert index.dtype == torch.long assert index.dtype == torch.long
assert index.dim() == 2 and index.size(0) == 2 assert index.dim() == 2 and index.size(0) == 2
...@@ -189,11 +196,13 @@ class SparseStorage(object): ...@@ -189,11 +196,13 @@ class SparseStorage(object):
@cached_property @cached_property
def rowcount(self): def rowcount(self):
# TODO
one = torch.ones_like(self.row) one = torch.ones_like(self.row)
return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0]) return segment_add(one, self.row, dim=0, dim_size=self._sparse_size[0])
@cached_property @cached_property
def rowptr(self): def rowptr(self):
# TODO
rowcount = self.rowcount rowcount = self.rowcount
rowptr = rowcount.new_zeros(rowcount.numel() + 1) rowptr = rowcount.new_zeros(rowcount.numel() + 1)
torch.cumsum(rowcount, dim=0, out=rowptr[1:]) torch.cumsum(rowcount, dim=0, out=rowptr[1:])
...@@ -201,11 +210,13 @@ class SparseStorage(object): ...@@ -201,11 +210,13 @@ class SparseStorage(object):
@cached_property @cached_property
def colcount(self): def colcount(self):
# TODO
one = torch.ones_like(self.col) one = torch.ones_like(self.col)
return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1]) return scatter_add(one, self.col, dim=0, dim_size=self._sparse_size[1])
@cached_property @cached_property
def colptr(self): def colptr(self):
# TODO
colcount = self.colcount colcount = self.colcount
colptr = colcount.new_zeros(colcount.numel() + 1) colptr = colcount.new_zeros(colcount.numel() + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:]) torch.cumsum(colcount, dim=0, out=colptr[1:])
...@@ -236,10 +247,9 @@ class SparseStorage(object): ...@@ -236,10 +247,9 @@ class SparseStorage(object):
value = self.value value = self.value
if self.has_value(): if self.has_value():
assert reduce in ['add', 'mean', 'min', 'max']
idx = mask.cumsum(0) - 1 idx = mask.cumsum(0) - 1
op = getattr(torch_scatter, f'scatter_{reduce}') dim_size = idx[-1].item() + 1
value = op(value, idx, dim=0, dim_size=idx[-1].item() + 1) value = segment_csr(idx, value, dim_size=dim_size, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return self.__class__(index, value, self.sparse_size(), is_sorted=True) return self.__class__(index, value, self.sparse_size(), is_sorted=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment