"torchvision/vscode:/vscode.git/clone" did not exist on "260992372dc4b4c51fdf8c454eb2c291ba0e9b80"
Commit 4a68dd60 authored by rusty1s
Browse files

add mul diag traceable

parent f6bf81df
......@@ -40,3 +40,8 @@ from .tensor import SparseTensor
from .transpose import t
from .narrow import narrow
from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import set_diag, remove_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from typing import Optional
import torch
from torch_scatter import gather_csr
from torch_sparse.utils import is_scalar
from torch_sparse.tensor import SparseTensor
def sparse_add(matA, matB):
nnzA, nnzB = matA.nnz(), matB.nnz()
valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)
if matA.is_cuda:
@torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Traceable element-wise add of a dense tensor to `src`'s entries.

    `other` must broadcast along exactly one sparse dimension: shape
    ``(src.size(0), 1, ...)`` for row-wise or ``(1, src.size(1), ...)``
    for column-wise addition.  Missing explicit values are treated as 1.

    Raises:
        ValueError: If `other` matches neither broadcast shape.
    """
    # NOTE(review): this span interleaved two functions in the scraped
    # diff; this is the reconstructed post-commit traceable `add`.
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        # TODO: `gather_csr` is not yet traceable.
        # other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    if value is not None:
        value = other.add_(value)
    else:
        value = other.add_(1)
    return src.set_value(value, layout='coo')
def add(src, other):
    """Add `other` (scalar, broadcastable tensor, or sparse matrix) to `src`.

    Implicit (unstored) values of `src` are treated as 1.  Sparse-sparse
    addition is not implemented yet.
    """
    if is_scalar(other):
        # A scalar only shifts the stored non-zero values.
        return add_nnz(src, other)

    if torch.is_tensor(other):
        rowptr, col, value = src.csr()
        row_wise = other.size(0) == src.size(0) and other.size(1) == 1
        col_wise = other.size(0) == 1 and other.size(1) == src.size(1)
        if row_wise:
            summand = gather_csr(other.squeeze(1), rowptr)
            layout = 'csr'
        elif col_wise:
            summand = other.squeeze(0)[col]
            layout = 'coo'
        else:
            raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                             f' ...) or (1, {src.size(1)}, ...), but got size '
                             f'{other.size()}.')
        # `summand` is freshly gathered, so the in-place add is safe.
        new_value = summand.add_(src.storage.value if src.has_value() else 1)
        return src.set_value(new_value, layout=layout)

    if isinstance(other, src.__class__):
        raise NotImplementedError
    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
def add_(src, other):
    """In-place variant of :func:`add`; mutates `src`'s values when present.

    Raises:
        ValueError: On broadcast-shape mismatch or unsupported `other` type.
    """
    # NOTE(review): the scraped diff spliced the new jit `add` tail into
    # this span; this is the reconstructed pre-commit `add_`.
    if is_scalar(other):
        return add_nnz_(src, other)
    elif torch.is_tensor(other):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
            other = gather_csr(other.squeeze(1), rowptr)
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                value = other.add_(1)
            return src.set_value_(value, layout='csr')
        if other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
            other = other.squeeze(0)[col]
            if src.has_value():
                value = src.storage.value.add_(other)
            else:
                value = other.add_(1)
            return src.set_value_(value, layout='coo')
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    elif isinstance(other, src.__class__):
        raise NotImplementedError
    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Traceable in-place add of a broadcastable dense tensor to `src`.

    Raises:
        ValueError: If `other` is neither row-wise nor col-wise broadcastable.
    """
    # NOTE(review): dropped the dangling `elif`/`raise` leftovers that the
    # diff scrape appended after this function.
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        # TODO: `gather_csr` is not yet traceable.
        # other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    if value is not None:
        value = value.add_(other)
    else:
        value = other.add_(1)
    return src.set_value_(value, layout='coo')
def add_nnz(src, other, layout=None):
    """Out-of-place add of `other` onto the stored non-zero values only.

    If `src` carries no explicit values, its entries count as 1.

    Raises:
        ValueError: If `other` is neither a scalar nor a tensor.
    """
    # NOTE(review): removed the foreign jit-`add_` tail the scraped diff
    # interleaved into this function.
    if is_scalar(other):
        if src.has_value():
            value = src.storage.value + other
        else:
            value = torch.full((src.nnz(), ), 1 + other, device=src.device)
        return src.set_value(value, layout='coo')
    if torch.is_tensor(other):
        if src.has_value():
            value = src.storage.value + other
        else:
            value = other + 1
        return src.set_value(value, layout='coo')
    raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                     '`torch.tensor`.')
@torch.jit.script
def add_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Traceable out-of-place add of `other` onto `src`'s stored values."""
    stored = src.storage.value()
    if stored is None:
        # Implicit values of 1, so the result is simply `other + 1`.
        new_value = other.add(1)
    else:
        new_value = stored.add(other)
    return src.set_value(new_value, layout=layout)
def add_nnz_(src, other, layout=None):
    """In-place variant of :func:`add_nnz`; mutates stored values if present.

    Raises:
        ValueError: If `other` is neither a scalar nor a tensor.
    """
    # NOTE(review): the scrape cut this function after the scalar branch;
    # tensor branch and raise reconstructed from the visible remainder.
    if is_scalar(other):
        if src.has_value():
            value = src.storage.value.add_(other)
        else:
            value = torch.full((src.nnz(), ), 1 + other, device=src.device)
        return src.set_value_(value, layout='coo')
    if torch.is_tensor(other):
        if src.has_value():
            value = src.storage.value.add_(other)
        else:
            value = other + 1  # No inplace operation possible.
        return src.set_value_(value, layout='coo')
    raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                     '`torch.tensor`.')
@torch.jit.script
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """Traceable in-place add of `other` onto `src`'s stored values."""
    value = src.storage.value()
    if value is not None:
        value = value.add_(other)
    else:
        # No stored values to mutate; materialize `other + 1` out-of-place.
        value = other.add(1)
    return src.set_value_(value, layout=layout)
# Attach the functional add operators as SparseTensor methods.
SparseTensor.add = lambda self, other: add(self, other)
SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
    self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
    self, other, layout)
from typing import Optional
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Return a copy of `src` with the `k`-th diagonal entries removed.

    Cached row/column counts are patched instead of recomputed.
    """
    # NOTE(review): the scraped diff interleaved the pre- and post-commit
    # bodies here; this is the reconstructed traceable version.
    row, col, value = src.coo()

    inv_mask = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[inv_mask], col[inv_mask]

    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        mask = ~inv_mask  # Positions of the removed diagonal entries.
        if rowcount is not None:
            rowcount = rowcount.clone()
            rowcount[row[mask]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[mask]] -= 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)
@torch.jit.script
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=0)
row, col, value = src.coo()
if row.is_cuda:
......@@ -47,7 +48,7 @@ def set_diag(src, values=None, k=0):
inv_mask = ~mask
start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
diag = torch.arange(start, start + num_diag, device=src.device)
diag = torch.arange(start, start + num_diag, device=row.device)
new_row = row.new_empty(mask.size(0))
new_row[mask] = row
......@@ -57,25 +58,33 @@ def set_diag(src, values=None, k=0):
new_col[mask] = row
new_col[inv_mask] = diag.add_(k)
new_value = None
if src.has_value():
new_value: Optional[torch.Tensor] = None
if value is not None:
new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value
new_value[inv_mask] = values if values is not None else 1
rowcount = None
if src.storage.has_rowcount():
rowcount = src.storage.rowcount.clone()
if values is not None:
new_value[inv_mask] = values
else:
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
device=value.device)
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.clone()
rowcount[start:start + num_diag] += 1
colcount = None
if src.storage.has_colcount():
colcount = src.storage.colcount.clone()
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1
storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount,
is_sorted=True)
storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
value=new_value, sparse_sizes=src.sparse_sizes(),
rowcount=rowcount, colptr=None, colcount=colcount,
csr2csc=None, csc2csr=None, is_sorted=True)
return src.from_storage(storage)
return src.__class__.from_storage(storage)
# Attach the diagonal operators as SparseTensor methods.
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
    self, values, k)
from typing import Optional
import torch
from torch_scatter import gather_csr
from torch_sparse.storage import get_layout
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
def index_select(src, dim, idx):
@torch.jit.script
def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
assert idx.dim() == 1
if dim == 0:
old_rowptr, col, value = src.csr()
rowcount = src.storage.rowcount
rowcount = src.storage.rowcount()
rowcount = rowcount[idx]
......@@ -22,69 +25,81 @@ def index_select(src, dim, idx):
device=col.device).repeat_interleave(rowcount)
perm = torch.arange(row.size(0), device=row.device)
perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
# TODO
# perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col = col[perm]
if src.has_value():
if value is not None:
value = value[perm]
sparse_size = torch.Size([idx.size(0), src.sparse_size(1)])
sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
value=value, sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1:
old_colptr, row, value = src.csc()
colcount = src.storage.colcount
colcount = src.storage.colcount()
colcount = colcount[idx]
col = torch.arange(idx.size(0),
device=row.device).repeat_interleave(colcount)
colptr = row.new_zeros(idx.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:])
col = torch.arange(idx.size(0),
device=row.device).repeat_interleave(colcount)
perm = torch.arange(col.size(0), device=col.device)
perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
# TODO
# perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row = row[perm]
csc2csr = (idx.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr]
if src.has_value():
if value is not None:
value = value[perm][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), idx.size(0)])
sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size, colptr=colptr,
colcount=colcount, csc2csr=csc2csr,
is_sorted=True)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
else:
storage = src.storage.apply_value(
lambda x: x.index_select(dim - 1, idx))
value = src.storage.value()
if value is not None:
return src.set_value(value.index_select(dim - 1, idx),
layout='coo')
else:
raise ValueError
return src.from_storage(storage)
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    """Select non-zero entries of `src` by index.

    `idx` indexes entries in COO ('coo'/'csr', the default) or CSC order,
    depending on `layout`.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # Translate CSC positions into COO/CSR positions.
        idx = src.storage.csc2csr()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    # There is no other information we can maintain...
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)
# Attach the index-select operators as SparseTensor methods.
SparseTensor.index_select = lambda self, dim, idx: index_select(self, dim, idx)
tmp = lambda self, idx, layout=None: index_select_nnz(  # noqa
    self, idx, layout)
SparseTensor.index_select_nnz = tmp
import torch
from typing import Optional
from torch_sparse.storage import get_layout
import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
def masked_select(src, dim, mask):
@torch.jit.script
def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
assert mask.dim() == 1
......@@ -11,29 +15,33 @@ def masked_select(src, dim, mask):
if dim == 0:
row, col, value = src.coo()
rowcount = src.storage.rowcount
rowcount = src.storage.rowcount()
rowcount = rowcount[mask]
mask = mask[row]
row = torch.arange(rowcount.size(0),
device=row.device).repeat_interleave(rowcount)
col = col[mask]
if src.has_value():
if value is not None:
value = value[mask]
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)])
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size,
rowcount=rowcount, is_sorted=True)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1:
row, col, value = src.coo()
csr2csc = src.storage.csr2csc
row, col = row[csr2csc], col[csr2csc]
colcount = src.storage.colcount
csr2csc = src.storage.csr2csc()
row = row[csr2csc]
col = col[csr2csc]
colcount = src.storage.colcount()
colcount = colcount[mask]
......@@ -44,39 +52,47 @@ def masked_select(src, dim, mask):
csc2csr = (colcount.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr]
if src.has_value():
if value is not None:
value = value[csr2csc][mask][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)])
storage = src.storage.__class__(row=row, col=col, value=value,
sparse_size=sparse_size,
colcount=colcount, csc2csr=csc2csr,
is_sorted=True)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colcount=colcount, colptr=None, csr2csc=None,
csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
else:
idx = mask.nonzero().view(-1)
storage = src.storage.apply_value(
lambda x: x.index_select(dim - 1, idx))
return src.from_storage(storage)
def masked_select_nnz(src, mask, layout=None):
value = src.storage.value()
if value is not None:
idx = mask.nonzero().flatten()
return src.set_value(value.index_select(dim - 1, idx),
layout='coo')
else:
raise ValueError
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    """Keep the non-zero entries of `src` selected by a boolean `mask`.

    `mask` is interpreted in COO ('coo'/'csr', the default) or CSC order,
    depending on `layout`.
    """
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        # Reorder the CSC-ordered mask into COO/CSR order.
        mask = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    row, col = row[mask], col[mask]

    if value is not None:
        value = value[mask]

    # There is no other information we can maintain...
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)
# Attach the masked-select operators as SparseTensor methods.
SparseTensor.masked_select = lambda self, dim, mask: masked_select(
    self, dim, mask)
tmp = lambda src, mask, layout=None: masked_select_nnz(  # noqa
    src, mask, layout)
SparseTensor.masked_select_nnz = tmp
from typing import Optional
import torch
from torch_scatter import gather_csr
from torch_sparse.utils import is_scalar
def mul(src, other):
    """Multiply `src` by `other` (scalar or broadcastable dense tensor).

    Implicit (unstored) values of `src` behave as 1, so the result's values
    are just the broadcast `other` in that case.

    Raises:
        ValueError: On broadcast-shape mismatch or unsupported `other` type.
    """
    # NOTE(review): the scrape cut this function mid-body; the trailing
    # raises are reconstructed by symmetry with the sibling `add`.
    if is_scalar(other):
        return mul_nnz(src, other)
    elif torch.is_tensor(other):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
            other = gather_csr(other.squeeze(1), rowptr)
            if src.has_value():
                value = other.mul_(src.storage.value)
            else:
                value = other
            return src.set_value(value, layout='csr')
        if other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
            other = other.squeeze(0)[col]
            if src.has_value():
                value = other.mul_(src.storage.value)
            else:
                value = other
            return src.set_value(value, layout='coo')
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    elif isinstance(other, src.__class__):
        raise NotImplementedError
    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Traceable element-wise multiply of a dense tensor with `src`.

    `other` must broadcast row-wise ``(src.size(0), 1, ...)`` or col-wise
    ``(1, src.size(1), ...)``.

    Raises:
        ValueError: If `other` matches neither broadcast shape.
    """
    # NOTE(review): dropped the dangling old-function leftovers the diff
    # scrape appended; tail restored from the added lines of the commit.
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        # TODO: `gather_csr` is not yet traceable.
        # other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    if value is not None:
        value = other.mul_(value)
    else:
        value = other
    return src.set_value(value, layout='coo')
def mul_(src, other):
    """In-place variant of :func:`mul`; mutates `src`'s values when present.

    Raises:
        ValueError: On broadcast-shape mismatch or unsupported `other` type.
    """
    # NOTE(review): the scraped diff spliced the new jit `mul` tail into
    # this span; trailing raises reconstructed by symmetry with `add_`.
    if is_scalar(other):
        return mul_nnz_(src, other)
    elif torch.is_tensor(other):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
            other = gather_csr(other.squeeze(1), rowptr)
            if src.has_value():
                value = src.storage.value.mul_(other)
            else:
                value = other
            return src.set_value_(value, layout='csr')
        if other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
            other = other.squeeze(0)[col]
            if src.has_value():
                value = src.storage.value.mul_(other)
            else:
                value = other
            return src.set_value_(value, layout='coo')
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    elif isinstance(other, src.__class__):
        raise NotImplementedError
    raise ValueError('Argument `other` needs to be of type `int`, `float`, '
                     '`torch.tensor` or `torch_sparse.SparseTensor`.')
@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """Traceable in-place multiply of a broadcastable dense tensor with `src`.

    Raises:
        ValueError: If `other` is neither row-wise nor col-wise broadcastable.
    """
    # NOTE(review): dropped the dangling `elif`/`raise` leftovers; tail
    # restored from the commit's added lines.
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
        # TODO: `gather_csr` is not yet traceable.
        # other = gather_csr(other.squeeze(1), rowptr)
        pass
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
        other = other.squeeze(0)[col]
    else:
        raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
                         f' ...) or (1, {src.size(1)}, ...), but got size '
                         f'{other.size()}.')
    if value is not None:
        value = value.mul_(other)
    else:
        value = other
    return src.set_value_(value, layout='coo')
def mul_nnz(src, other, layout=None):
    """Out-of-place multiply of `src`'s stored non-zero values by `other`.

    NOTE(review): `layout` is accepted but never read in this version.
    """
    if not (torch.is_tensor(other) or is_scalar(other)):
        raise ValueError('Argument `other` needs to be of type `int`, '
                         '`float` or `torch.tensor`.')
    # Without explicit values the entries act as 1, so the product is `other`.
    new_value = src.storage.value * other if src.has_value() else other
    return src.set_value(new_value, layout='coo')
def mul_nnz_(src, other, layout=None):
    """In-place variant of :func:`mul_nnz`; mutates stored values if present.

    Raises:
        ValueError: If `other` is neither a scalar nor a tensor.
    """
    # NOTE(review): removed the foreign jit-`mul_` tail the scraped diff
    # appended after this function.
    if torch.is_tensor(other) or is_scalar(other):
        if src.has_value():
            value = src.storage.value.mul_(other)
        else:
            value = other
        return src.set_value_(value, layout='coo')
    raise ValueError('Argument `other` needs to be of type `int`, `float` or '
                     '`torch.tensor`.')
@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Traceable out-of-place multiply of `src`'s stored values by `other`."""
    stored = src.storage.value()
    if stored is None:
        # Implicit all-ones values: the product is just `other`.
        new_value = other
    else:
        new_value = stored.mul(other)
    return src.set_value(new_value, layout=layout)
@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """Traceable in-place multiply of `src`'s stored values by `other`."""
    stored = src.storage.value()
    if stored is None:
        # Nothing stored to mutate; adopt `other` directly.
        new_value = other
    else:
        new_value = stored.mul_(other)
    return src.set_value_(new_value, layout=layout)
# Attach the functional mul operators as SparseTensor methods.
SparseTensor.mul = lambda self, other: mul(self, other)
SparseTensor.mul_ = lambda self, other: mul_(self, other)
SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
    self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
    self, other, layout)
......@@ -4,7 +4,8 @@ from torch_sparse.tensor import SparseTensor
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int, length: int):
def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
start = src.size(dim) + start if start < 0 else start
......
......@@ -4,7 +4,7 @@ from torch_sparse.narrow import narrow
@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
    """Select the `idx`-th slice of `src` along `dim` (kept as a size-1 dim)."""
    # NOTE(review): the scrape duplicated old/new signature lines; only the
    # annotated post-commit signature is kept.
    return narrow(src, dim, start=idx, length=1)
......
......@@ -5,14 +5,6 @@ import torch
import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from torch_sparse.utils import is_scalar
......@@ -403,12 +395,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
......@@ -417,14 +403,6 @@ class SparseTensor(object):
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_
# Python Bindings #############################################################
......
......@@ -6,7 +6,7 @@ from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor):
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment