Commit e6f5c3f0 authored by rusty1s's avatar rusty1s
Browse files

removed jit in external functions due to wrong caching behaviour

parent 3c259af5
...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr ...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo') return src.set_value(value, layout='coo')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo') return src.set_value_(value, layout='coo')
@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor, def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
...@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor, ...@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout) return src.set_value(value, layout=layout)
@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor, def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
if dim < 0: if dim < 0:
...@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor: ...@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.') '[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor: def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
row, col, value = src.coo() row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k) inv_mask = row != col if k == 0 else row != (col - k)
...@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: ...@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone() colcount = colcount.clone()
colcount[col[mask]] -= 1 colcount[col[mask]] -= 1
storage = SparseStorage( storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
row=new_row, sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
rowptr=None, colptr=None, colcount=colcount, csr2csc=None,
col=new_col, csc2csr=None, is_sorted=True)
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
@torch.jit.script def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor: k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k) src = remove_diag(src, k=k)
row, col, value = src.coo() row, col, value = src.coo()
...@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor, ...@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if values is not None: if values is not None:
new_value[inv_mask] = values new_value[inv_mask] = values
else: else:
new_value[inv_mask] = torch.ones((num_diag, ), new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
dtype=value.dtype,
device=value.device) device=value.device)
rowcount = src.storage._rowcount rowcount = src.storage._rowcount
...@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor, ...@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount = colcount.clone() colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1 colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage( storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
row=new_row, value=new_value, sparse_sizes=src.sparse_sizes(),
rowptr=None, rowcount=rowcount, colptr=None, colcount=colcount,
col=new_col, csr2csc=None, csc2csr=None, is_sorted=True)
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor: def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k) num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0: if k < 0:
......
...@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout ...@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def index_select(src: SparseTensor, dim: int, def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor: idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
...@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int, ...@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise ValueError raise ValueError
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor, def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
assert idx.dim() == 1 assert idx.dim() == 1
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def masked_select(src: SparseTensor, dim: int, def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor: mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
...@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int, ...@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise ValueError raise ValueError
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor, def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1 assert mask.dim() == 1
......
...@@ -4,7 +4,6 @@ import torch ...@@ -4,7 +4,6 @@ import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
...@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: ...@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc, other) csr2csc, other)
@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other) return spmm_sum(src, other)
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
...@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor: ...@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr, csr2csc, other) colptr, csr2csc, other)
@torch.jit.script
def spmm_min(src: SparseTensor, def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other) return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
@torch.jit.script
def spmm_max(src: SparseTensor, def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other) return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor, def spmm(src: SparseTensor, other: torch.Tensor,
reduce: str = "sum") -> torch.Tensor: reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
...@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor, ...@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise ValueError raise ValueError
@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0) assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr() rowptrA, colA, valueA = src.csr()
...@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor: ...@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1) M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum( rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K) rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor( return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
row=None, sparse_sizes=torch.Size([M, K]), is_sorted=True)
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor: def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
return spspmm_sum(src, other) return spspmm_sum(src, other)
@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor, def spspmm(src: SparseTensor, other: SparseTensor,
reduce: str = "sum") -> SparseTensor: reduce: str = "sum") -> SparseTensor:
if reduce == 'sum' or reduce == 'add': if reduce == 'sum' or reduce == 'add':
...@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor, ...@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError raise ValueError
def matmul(src: SparseTensor, def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"): reduce: str = "sum"):
if torch.is_tensor(other): if torch.is_tensor(other):
return spmm(src, other, reduce) return spmm(src, other, reduce)
......
...@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor ...@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute from torch_sparse.permute import permute
@torch.jit.script
def partition_kway( def partition_kway(
src: SparseTensor, src: SparseTensor,
num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
......
...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr ...@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo') return src.set_value(value, layout='coo')
@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
...@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor: ...@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo') return src.set_value_(value, layout='coo')
@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor, def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
...@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor, ...@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout) return src.set_value(value, layout=layout)
@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor, def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor: layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value() value = src.storage.value()
......
...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage ...@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int, def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor: length: int) -> SparseTensor:
if dim < 0: if dim < 0:
...@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int, ...@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise ValueError raise ValueError
@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int], def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length: Tuple[int, int]) -> SparseTensor: length: Tuple[int, int]) -> SparseTensor:
# This function builds the inverse operation of `cat_diag` and should hence # This function builds the inverse operation of `cat_diag` and should hence
......
...@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage ...@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor: def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_symmetric() assert src.is_symmetric()
......
...@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr ...@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None, def reduction(src: SparseTensor, dim: Optional[int] = None,
reduce: str = 'sum') -> torch.Tensor: reduce: str = 'sum') -> torch.Tensor:
value = src.storage.value() value = src.storage.value()
...@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None, ...@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise ValueError raise ValueError
@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor: def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='sum') return reduction(src, dim, reduce='sum')
@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor: def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='mean') return reduction(src, dim, reduce='mean')
@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor: def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='min') return reduction(src, dim, reduce='min')
@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor: def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='max') return reduction(src, dim, reduce='max')
......
import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow from torch_sparse.narrow import narrow
@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor: def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
return narrow(src, dim, start=idx, length=1) return narrow(src, dim, start=idx, length=1)
......
...@@ -30,8 +30,7 @@ class SparseStorage(object): ...@@ -30,8 +30,7 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor] _csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor] _csc2csr: Optional[torch.Tensor]
def __init__(self, def __init__(self, row: Optional[torch.Tensor] = None,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None, rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None, col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None,
...@@ -192,8 +191,7 @@ class SparseStorage(object): ...@@ -192,8 +191,7 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]: def value(self) -> Optional[torch.Tensor]:
return self._value return self._value
def set_value_(self, def set_value_(self, value: Optional[torch.Tensor],
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
...@@ -205,8 +203,7 @@ class SparseStorage(object): ...@@ -205,8 +203,7 @@ class SparseStorage(object):
self._value = value self._value = value
return self return self
def set_value(self, def set_value(self, value: Optional[torch.Tensor],
value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
...@@ -215,18 +212,11 @@ class SparseStorage(object): ...@@ -215,18 +212,11 @@ class SparseStorage(object):
assert value.device == self._col.device assert value.device == self._col.device
assert value.size(0) == self._col.numel() assert value.size(0) == self._col.numel()
return SparseStorage( return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
row=self._row, value=value, sparse_sizes=self._sparse_sizes,
rowptr=self._rowptr, rowcount=self._rowcount, colptr=self._colptr,
col=self._col, colcount=self._colcount, csr2csc=self._csr2csc,
value=value, csc2csr=self._csc2csr, is_sorted=True)
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def sparse_sizes(self) -> List[int]: def sparse_sizes(self) -> List[int]:
return self._sparse_sizes return self._sparse_sizes
...@@ -264,18 +254,11 @@ class SparseStorage(object): ...@@ -264,18 +254,11 @@ class SparseStorage(object):
if colcount is not None: if colcount is not None:
colcount = colcount[:-diff_1] colcount = colcount[:-diff_1]
return SparseStorage( return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
row=self._row, value=self._value, sparse_sizes=sparse_sizes,
rowptr=rowptr, rowcount=rowcount, colptr=colptr,
col=self._col, colcount=colcount, csr2csc=self._csr2csc,
value=self._value, csc2csr=self._csc2csr, is_sorted=True)
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def has_rowcount(self) -> bool: def has_rowcount(self) -> bool:
return self._rowcount is not None return self._rowcount is not None
...@@ -320,10 +303,8 @@ class SparseStorage(object): ...@@ -320,10 +303,8 @@ class SparseStorage(object):
if colptr is not None: if colptr is not None:
colcount = colptr[1:] - colptr[:-1] colcount = colptr[1:] - colptr[:-1]
else: else:
colcount = scatter_add( colcount = scatter_add(torch.ones_like(self._col), self._col,
torch.ones_like(self._col), dim_size=self._sparse_sizes[1])
self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount self._colcount = colcount
return colcount return colcount
...@@ -375,18 +356,10 @@ class SparseStorage(object): ...@@ -375,18 +356,10 @@ class SparseStorage(object):
value = segment_csr(value, ptr, reduce=reduce) value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return SparseStorage( return SparseStorage(row=row, rowptr=None, col=col, value=value,
row=row, sparse_sizes=self._sparse_sizes, rowcount=None,
rowptr=None, colptr=None, colcount=None, csr2csc=None,
col=col, csc2csr=None, is_sorted=True)
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
def fill_cache_(self): def fill_cache_(self):
self.row() self.row()
...@@ -421,18 +394,12 @@ class SparseStorage(object): ...@@ -421,18 +394,12 @@ class SparseStorage(object):
return count return count
def copy(self): def copy(self):
return SparseStorage( return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
row=self._row, value=self._value,
rowptr=self._rowptr, sparse_sizes=self._sparse_sizes,
col=self._col, rowcount=self._rowcount, colptr=self._colptr,
value=self._value, colcount=self._colcount, csr2csc=self._csr2csc,
sparse_sizes=self._sparse_sizes, csc2csr=self._csc2csr, is_sorted=True)
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def clone(self): def clone(self):
row = self._row row = self._row
...@@ -460,18 +427,11 @@ class SparseStorage(object): ...@@ -460,18 +427,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.clone() csc2csr = csc2csr.clone()
return SparseStorage( return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=self._sparse_sizes,
rowptr=rowptr, rowcount=rowcount, colptr=colptr,
col=col, colcount=colcount, csr2csc=csr2csc,
value=value, csc2csr=csc2csr, is_sorted=True)
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def type_as(self, tensor=torch.Tensor): def type_as(self, tensor=torch.Tensor):
value = self._value value = self._value
...@@ -512,18 +472,11 @@ class SparseStorage(object): ...@@ -512,18 +472,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking) csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage( return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=self._sparse_sizes,
rowptr=rowptr, rowcount=rowcount, colptr=colptr,
col=col, colcount=colcount, csr2csc=csr2csc,
value=value, csc2csr=csc2csr, is_sorted=True)
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def pin_memory(self): def pin_memory(self):
row = self._row row = self._row
...@@ -551,18 +504,11 @@ class SparseStorage(object): ...@@ -551,18 +504,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr csc2csr = self._csc2csr
if csc2csr is not None: if csc2csr is not None:
csc2csr = csc2csr.pin_memory() csc2csr = csc2csr.pin_memory()
return SparseStorage( return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
row=row, sparse_sizes=self._sparse_sizes,
rowptr=rowptr, rowcount=rowcount, colptr=colptr,
col=col, colcount=colcount, csr2csc=csr2csc,
value=value, csc2csr=csc2csr, is_sorted=True)
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
def is_pinned(self) -> bool: def is_pinned(self) -> bool:
is_pinned = True is_pinned = True
......
...@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage ...@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor) -> SparseTensor: def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc() csr2csc = src.storage.csr2csc()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment