"src/kernel/vscode:/vscode.git/clone" did not exist on "35bed2a9d04d01a0b11b98e59710f21a6e702ef7"
Commit e6f5c3f0 authored by rusty1s

removed jit in external functions due to wrong caching behaviour

parent 3c259af5
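
Note: the diff below strips the `@torch.jit.script` decorator from the package-level ("external") functions, presumably because eagerly scripting them at import time cached the compiled functions in a way that misbehaved. A minimal, self-contained sketch of the resulting pattern — assuming only stock PyTorch, and using a made-up `scale` function rather than code from this repository — is:

# Illustrative sketch only (not part of this commit); `scale` is a hypothetical
# placeholder, not one of the torch_sparse functions changed below.
import torch

def scale(src: torch.Tensor, factor: float) -> torch.Tensor:
    # Left as a plain Python function: nothing is compiled or cached at import.
    return src * factor

# Callers that still want TorchScript can compile explicitly, on demand:
scripted_scale = torch.jit.script(scale)
print(scripted_scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
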
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
@@ -24,7 +23,6 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
@@ -44,7 +42,6 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo')
@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
@@ -55,7 +52,6 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout)
@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
assert len(tensors) > 0
if dim < 0:
@@ -142,7 +141,6 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.')
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0
......
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k)
@@ -25,24 +24,14 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
colcount = colcount.clone()
colcount[col[mask]] -= 1
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
@torch.jit.script
def set_diag(src: SparseTensor,
values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k)
row, col, value = src.coo()
@@ -69,8 +58,7 @@ def set_diag(src: SparseTensor,
if values is not None:
new_value[inv_mask] = values
else:
new_value[inv_mask] = torch.ones((num_diag, ),
dtype=value.dtype,
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
device=value.device)
rowcount = src.storage._rowcount
@@ -83,22 +71,13 @@ def set_diag(src: SparseTensor,
colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1
storage = SparseStorage(
row=new_row,
rowptr=None,
col=new_col,
value=new_value,
sparse_sizes=src.sparse_sizes(),
rowcount=rowcount,
colptr=None,
colcount=colcount,
csr2csc=None,
csc2csr=None,
is_sorted=True)
storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
value=new_value, sparse_sizes=src.sparse_sizes(),
rowcount=rowcount, colptr=None, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=True)
return src.from_storage(storage)
@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0:
......
@@ -6,7 +6,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
@@ -79,7 +78,6 @@ def index_select(src: SparseTensor, dim: int,
raise ValueError
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert idx.dim() == 1
......
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
@@ -73,7 +72,6 @@ def masked_select(src: SparseTensor, dim: int,
raise ValueError
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1
......
@@ -4,7 +4,6 @@ import torch
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
@@ -24,12 +23,10 @@ def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
csr2csc, other)
@torch.jit.script
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other)
@torch.jit.script
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
@@ -51,21 +48,18 @@ def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
colptr, csr2csc, other)
@torch.jit.script
def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
@torch.jit.script
def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
@torch.jit.script
def spmm(src: SparseTensor, other: torch.Tensor,
reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add':
@@ -80,7 +74,6 @@ def spmm(src: SparseTensor, other: torch.Tensor,
raise ValueError
@torch.jit.script
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr()
@@ -88,21 +81,14 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
return SparseTensor(
row=None,
rowptr=rowptrC,
col=colC,
value=valueC,
sparse_sizes=torch.Size([M, K]),
is_sorted=True)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=torch.Size([M, K]), is_sorted=True)
@torch.jit.script
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
return spspmm_sum(src, other)
@torch.jit.script
def spspmm(src: SparseTensor, other: SparseTensor,
reduce: str = "sum") -> SparseTensor:
if reduce == 'sum' or reduce == 'add':
@@ -113,8 +99,7 @@ def spspmm(src: SparseTensor, other: SparseTensor,
raise ValueError
def matmul(src: SparseTensor,
other: Union[torch.Tensor, SparseTensor],
def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
reduce: str = "sum"):
if torch.is_tensor(other):
return spmm(src, other, reduce)
......
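The `matmul` wrapper above dispatches on the type of `other`: a dense `torch.Tensor` goes through `spmm`, while a second `SparseTensor` goes through `spspmm`. A hedged usage sketch follows — it assumes `SparseTensor` and `matmul` are importable from the `torch_sparse` package, which this diff does not itself show:

# Usage sketch only; the import path is assumed, not confirmed by this diff.
import torch
from torch_sparse import SparseTensor, matmul

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 2])
value = torch.ones(4)
src = SparseTensor(row=row, col=col, value=value,
                   sparse_sizes=torch.Size([3, 3]))

dense = torch.randn(3, 4)
out = matmul(src, dense, reduce='sum')  # dense `other` -> spmm_sum path
print(out.size())  # torch.Size([3, 4])
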
@@ -5,7 +5,6 @@ from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
@torch.jit.script
def partition_kway(
src: SparseTensor,
num_parts: int) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
......
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo')
@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo')
@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout)
@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......
@@ -5,7 +5,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor:
if dim < 0:
@@ -80,7 +79,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise ValueError
@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length: Tuple[int, int]) -> SparseTensor:
# This function builds the inverse operation of `cat_diag` and should hence
......
@@ -3,7 +3,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_symmetric()
......
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None,
reduce: str = 'sum') -> torch.Tensor:
value = src.storage.value()
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise ValueError
@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='sum')
@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='mean')
@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='min')
@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='max')
......
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow
@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
return narrow(src, dim, start=idx, length=1)
......
@@ -30,8 +30,7 @@ class SparseStorage(object):
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self,
row: Optional[torch.Tensor] = None,
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
@@ -192,8 +191,7 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self,
value: Optional[torch.Tensor],
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
@@ -205,8 +203,7 @@ class SparseStorage(object):
self._value = value
return self
def set_value(self,
value: Optional[torch.Tensor],
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
@@ -215,18 +212,11 @@ class SparseStorage(object):
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def sparse_sizes(self) -> List[int]:
return self._sparse_sizes
@@ -264,18 +254,11 @@ class SparseStorage(object):
if colcount is not None:
colcount = colcount[:-diff_1]
return SparseStorage(
row=self._row,
rowptr=rowptr,
col=self._col,
value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool:
return self._rowcount is not None
@@ -320,10 +303,8 @@ class SparseStorage(object):
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(
torch.ones_like(self._col),
self._col,
dim_size=self._sparse_sizes[1])
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
@@ -375,18 +356,10 @@ class SparseStorage(object):
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
def fill_cache_(self):
self.row()
@@ -421,18 +394,12 @@ class SparseStorage(object):
return count
def copy(self):
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def clone(self):
row = self._row
@@ -460,18 +427,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def type_as(self, tensor=torch.Tensor):
value = self._value
@@ -512,18 +472,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def pin_memory(self):
row = self._row
@@ -551,18 +504,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def is_pinned(self) -> bool:
is_pinned = True
......
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
......