Commit 4a68dd60 authored by rusty1s's avatar rusty1s
Browse files

add mul diag traceable

parent f6bf81df
...@@ -40,3 +40,8 @@ from .tensor import SparseTensor ...@@ -40,3 +40,8 @@ from .tensor import SparseTensor
from .transpose import t from .transpose import t
from .narrow import narrow from .narrow import narrow
from .select import select from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import set_diag, remove_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from typing import Optional
import torch import torch
from torch_scatter import gather_csr from torch_scatter import gather_csr
from torch_sparse.utils import is_scalar from torch_sparse.tensor import SparseTensor
def sparse_add(matA, matB):
nnzA, nnzB = matA.nnz(), matB.nnz()
valA = torch.full((nnzA, ), 1, dtype=torch.uint8, device=matA.device)
valB = torch.full((nnzB, ), 2, dtype=torch.uint8, device=matB.device)
if matA.is_cuda: @torch.jit.script
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
pass pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else: else:
matA_ = matA.set_value(valA, layout='csr').to_scipy(layout='csr')
matB_ = matB.set_value(valB, layout='csr').to_scipy(layout='csr')
matC_ = matA_ + matB_
rowptr = torch.from_numpy(matC_.indptr).to(torch.long)
matC_ = matC_.tocoo()
row = torch.from_numpy(matC_.row).to(torch.long)
col = torch.from_numpy(matC_.col).to(torch.long)
index = torch.stack([row, col], dim=0)
valC_ = torch.from_numpy(matC_.data)
value = None
if matA.has_value() or matB.has_value():
maskA, maskB = valC_ != 2, valC_ >= 2
size = matA.size() if matA.dim() >= matB.dim() else matA.size()
size = (valC_.size(0), ) + size[2:]
value = torch.zeros(size, dtype=matA.dtype, device=matA.device)
value[maskA] += matA.storage.value if matA.has_value() else 1
value[maskB] += matB.storage.value if matB.has_value() else 1
storage = matA.storage.__class__(index, value, matA.sparse_size(),
rowptr=rowptr, is_sorted=True)
return matA.__class__.from_storage(storage)
def add(src, other):
if is_scalar(other):
return add_nnz(src, other)
elif torch.is_tensor(other):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
value = other.add_(src.storage.value if src.has_value() else 1)
return src.set_value(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
value = other.add_(src.storage.value if src.has_value() else 1)
return src.set_value(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,' raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size ' f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.') f'{other.size()}.')
elif isinstance(other, src.__class__): if value is not None:
raise NotImplementedError value = other.add_(value)
else:
raise ValueError('Argument `other` needs to be of type `int`, `float`, ' value = other.add_(1)
'`torch.tensor` or `torch_sparse.SparseTensor`.') return src.set_value(value, layout='coo')
def add_(src, other):
if is_scalar(other):
return add_nnz_(src, other)
elif torch.is_tensor(other):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other.add_(1)
return src.set_value_(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other.add_(1)
return src.set_value_(value, layout='coo')
@torch.jit.script
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,' raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size ' f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.') f'{other.size()}.')
elif isinstance(other, src.__class__): if value is not None:
raise NotImplementedError value = value.add_(other)
else:
raise ValueError('Argument `other` needs to be of type `int`, `float`, ' value = other.add_(1)
'`torch.tensor` or `torch_sparse.SparseTensor`.') return src.set_value_(value, layout='coo')
def add_nnz(src, other, layout=None):
if is_scalar(other):
if src.has_value():
value = src.storage.value + other
else:
value = torch.full((src.nnz(), ), 1 + other, device=src.device)
return src.set_value(value, layout='coo')
if torch.is_tensor(other):
if src.has_value():
value = src.storage.value + other
else:
value = other + 1
return src.set_value(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or ' @torch.jit.script
'`torch.tensor`.') def add_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
if value is not None:
value = value.add(other)
else:
value = other.add(1)
return src.set_value(value, layout=layout)
def add_nnz_(src, other, layout=None): @torch.jit.script
if is_scalar(other): def add_nnz_(src: SparseTensor, other: torch.Tensor,
if src.has_value(): layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value.add_(other) value = src.storage.value()
else: if value is not None:
value = torch.full((src.nnz(), ), 1 + other, device=src.device) value = value.add_(other)
return src.set_value_(value, layout='coo') else:
value = other.add(1)
return src.set_value_(value, layout=layout)
if torch.is_tensor(other):
if src.has_value():
value = src.storage.value.add_(other)
else:
value = other + 1 # No inplace operation possible.
return src.set_value_(value, layout='coo')
raise ValueError('Argument `other` needs to be of type `int`, `float` or ' SparseTensor.add = lambda self, other: add(self, other)
'`torch.tensor`.') SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
self, other, layout)
from typing import Optional
import torch import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def remove_diag(src, k=0): @torch.jit.script
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
row, col, value = src.coo() row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k) inv_mask = row != col if k == 0 else row != (col - k)
new_row, new_col = row[inv_mask], col[inv_mask] new_row, new_col = row[inv_mask], col[inv_mask]
if src.has_value(): if value is not None:
value = value[inv_mask] value = value[inv_mask]
if src.storage.has_rowcount() or src.storage.has_colcount(): rowcount = src.storage._rowcount
colcount = src.storage._colcount
if rowcount is not None or colcount is not None:
mask = ~inv_mask mask = ~inv_mask
if rowcount is not None:
rowcount = None rowcount = rowcount.clone()
if src.storage.has_rowcount(): rowcount[row[mask]] -= 1
rowcount = src.storage.rowcount.clone() if colcount is not None:
rowcount[row[mask]] -= 1 colcount = colcount.clone()
colcount[col[mask]] -= 1
colcount = None
if src.storage.has_colcount(): storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
colcount = src.storage.colcount.clone() sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
colcount[col[mask]] -= 1 colptr=None, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
storage = src.storage.__class__(row=new_row, col=new_col, value=value, return src.from_storage(storage)
sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount,
is_sorted=True) @torch.jit.script
return src.__class__.from_storage(storage) def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=0)
def set_diag(src, values=None, k=0):
if values is not None and not src.has_value():
raise ValueError('Sparse matrix has no values')
src = src.remove_diag(k=0)
row, col, value = src.coo() row, col, value = src.coo()
if row.is_cuda: if row.is_cuda:
...@@ -47,7 +48,7 @@ def set_diag(src, values=None, k=0): ...@@ -47,7 +48,7 @@ def set_diag(src, values=None, k=0):
inv_mask = ~mask inv_mask = ~mask
start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel() start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
diag = torch.arange(start, start + num_diag, device=src.device) diag = torch.arange(start, start + num_diag, device=row.device)
new_row = row.new_empty(mask.size(0)) new_row = row.new_empty(mask.size(0))
new_row[mask] = row new_row[mask] = row
...@@ -57,25 +58,33 @@ def set_diag(src, values=None, k=0): ...@@ -57,25 +58,33 @@ def set_diag(src, values=None, k=0):
new_col[mask] = row new_col[mask] = row
new_col[inv_mask] = diag.add_(k) new_col[inv_mask] = diag.add_(k)
new_value = None new_value: Optional[torch.Tensor] = None
if src.has_value(): if value is not None:
new_value = value.new_empty((mask.size(0), ) + value.size()[1:]) new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value new_value[mask] = value
new_value[inv_mask] = values if values is not None else 1 if values is not None:
new_value[inv_mask] = values
rowcount = None else:
if src.storage.has_rowcount(): new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
rowcount = src.storage.rowcount.clone() device=value.device)
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.clone()
rowcount[start:start + num_diag] += 1 rowcount[start:start + num_diag] += 1
colcount = None colcount = src.storage._colcount
if src.storage.has_colcount(): if colcount is not None:
colcount = src.storage.colcount.clone() colcount = colcount.clone()
colcount[start + k:start + num_diag + k] += 1 colcount[start + k:start + num_diag + k] += 1
storage = src.storage.__class__(row=new_row, col=new_col, value=new_value, storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
sparse_size=src.sparse_size(), value=new_value, sparse_sizes=src.sparse_sizes(),
rowcount=rowcount, colcount=colcount, rowcount=rowcount, colptr=None, colcount=colcount,
is_sorted=True) csr2csc=None, csc2csr=None, is_sorted=True)
return src.from_storage(storage)
return src.__class__.from_storage(storage) SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
self, values, k)
from typing import Optional
import torch import torch
from torch_scatter import gather_csr from torch_scatter import gather_csr
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.storage import get_layout from torch_sparse.tensor import SparseTensor
def index_select(src, dim, idx): @torch.jit.script
def index_select(src: SparseTensor, dim: int,
idx: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
assert idx.dim() == 1 assert idx.dim() == 1
if dim == 0: if dim == 0:
old_rowptr, col, value = src.csr() old_rowptr, col, value = src.csr()
rowcount = src.storage.rowcount rowcount = src.storage.rowcount()
rowcount = rowcount[idx] rowcount = rowcount[idx]
...@@ -22,69 +25,81 @@ def index_select(src, dim, idx): ...@@ -22,69 +25,81 @@ def index_select(src, dim, idx):
device=col.device).repeat_interleave(rowcount) device=col.device).repeat_interleave(rowcount)
perm = torch.arange(row.size(0), device=row.device) perm = torch.arange(row.size(0), device=row.device)
perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr) # TODO
# perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
col = col[perm] col = col[perm]
if src.has_value(): if value is not None:
value = value[perm] value = value[perm]
sparse_size = torch.Size([idx.size(0), src.sparse_size(1)]) sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])
storage = src.storage.__class__(row=row, rowptr=rowptr, col=col, storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
value=value, sparse_size=sparse_size, sparse_sizes=sparse_sizes, rowcount=rowcount,
rowcount=rowcount, is_sorted=True) colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1: elif dim == 1:
old_colptr, row, value = src.csc() old_colptr, row, value = src.csc()
colcount = src.storage.colcount colcount = src.storage.colcount()
colcount = colcount[idx] colcount = colcount[idx]
col = torch.arange(idx.size(0),
device=row.device).repeat_interleave(colcount)
colptr = row.new_zeros(idx.size(0) + 1) colptr = row.new_zeros(idx.size(0) + 1)
torch.cumsum(colcount, dim=0, out=colptr[1:]) torch.cumsum(colcount, dim=0, out=colptr[1:])
col = torch.arange(idx.size(0),
device=row.device).repeat_interleave(colcount)
perm = torch.arange(col.size(0), device=col.device) perm = torch.arange(col.size(0), device=col.device)
perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr) # TODO
# perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
row = row[perm] row = row[perm]
csc2csr = (idx.size(0) * row + col).argsort() csc2csr = (idx.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr] row, col = row[csc2csr], col[csc2csr]
if src.has_value(): if value is not None:
value = value[perm][csc2csr] value = value[perm][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), idx.size(0)]) sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])
storage = src.storage.__class__(row=row, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=sparse_size, colptr=colptr, sparse_sizes=sparse_sizes, rowcount=None,
colcount=colcount, csc2csr=csc2csr, colptr=colptr, colcount=colcount, csr2csc=None,
is_sorted=True) csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
else: else:
storage = src.storage.apply_value( value = src.storage.value()
lambda x: x.index_select(dim - 1, idx)) if value is not None:
return src.set_value(value.index_select(dim - 1, idx),
layout='coo')
else:
raise ValueError
return src.from_storage(storage)
@torch.jit.script
def index_select_nnz(src, idx, layout=None): def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert idx.dim() == 1 assert idx.dim() == 1
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
idx = idx[src.storage.csc2csr] idx = src.storage.csc2csr()[idx]
row, col, value = src.coo() row, col, value = src.coo()
row, col = row[idx], col[idx] row, col = row[idx], col[idx]
if src.has_value(): if value is not None:
value = value[idx] value = value[idx]
# There is no other information we can maintain... return SparseTensor(row=row, rowptr=None, col=col, value=value,
storage = src.storage.__class__(row=row, col=col, value=value, sparse_sizes=src.sparse_sizes(), is_sorted=True)
sparse_size=src.sparse_size(),
is_sorted=True)
return src.from_storage(storage) SparseTensor.index_select = lambda self, dim, idx: index_select(self, dim, idx)
tmp = lambda self, idx, layout=None: index_select_nnz( # noqa
self, idx, layout)
SparseTensor.index_select_nnz = tmp
import torch from typing import Optional
from torch_sparse.storage import get_layout import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
def masked_select(src, dim, mask): @torch.jit.script
def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
assert mask.dim() == 1 assert mask.dim() == 1
...@@ -11,29 +15,33 @@ def masked_select(src, dim, mask): ...@@ -11,29 +15,33 @@ def masked_select(src, dim, mask):
if dim == 0: if dim == 0:
row, col, value = src.coo() row, col, value = src.coo()
rowcount = src.storage.rowcount rowcount = src.storage.rowcount()
rowcount = rowcount[mask] rowcount = rowcount[mask]
mask = mask[row] mask = mask[row]
row = torch.arange(rowcount.size(0), row = torch.arange(rowcount.size(0),
device=row.device).repeat_interleave(rowcount) device=row.device).repeat_interleave(rowcount)
col = col[mask] col = col[mask]
if src.has_value(): if value is not None:
value = value[mask] value = value[mask]
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)]) sparse_sizes = torch.Size([rowcount.size(0), src.sparse_size(1)])
storage = src.storage.__class__(row=row, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=sparse_size, sparse_sizes=sparse_sizes, rowcount=rowcount,
rowcount=rowcount, is_sorted=True) colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1: elif dim == 1:
row, col, value = src.coo() row, col, value = src.coo()
csr2csc = src.storage.csr2csc csr2csc = src.storage.csr2csc()
row, col = row[csr2csc], col[csr2csc] row = row[csr2csc]
colcount = src.storage.colcount col = col[csr2csc]
colcount = src.storage.colcount()
colcount = colcount[mask] colcount = colcount[mask]
...@@ -44,39 +52,47 @@ def masked_select(src, dim, mask): ...@@ -44,39 +52,47 @@ def masked_select(src, dim, mask):
csc2csr = (colcount.size(0) * row + col).argsort() csc2csr = (colcount.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr] row, col = row[csc2csr], col[csc2csr]
if src.has_value(): if value is not None:
value = value[csr2csc][mask][csc2csr] value = value[csr2csc][mask][csc2csr]
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)]) sparse_sizes = torch.Size([src.sparse_size(0), colcount.size(0)])
storage = src.storage.__class__(row=row, col=col, value=value, storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_size=sparse_size, sparse_sizes=sparse_sizes, rowcount=None,
colcount=colcount, csc2csr=csc2csr, colcount=colcount, colptr=None, csr2csc=None,
is_sorted=True) csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
else: else:
idx = mask.nonzero().view(-1) value = src.storage.value()
storage = src.storage.apply_value( if value is not None:
lambda x: x.index_select(dim - 1, idx)) idx = mask.nonzero().flatten()
return src.set_value(value.index_select(dim - 1, idx),
return src.from_storage(storage) layout='coo')
else:
raise ValueError
def masked_select_nnz(src, mask, layout=None):
@torch.jit.script
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1 assert mask.dim() == 1
if get_layout(layout) == 'csc': if get_layout(layout) == 'csc':
mask = mask[src.storage.csc2csr] mask = mask[src.storage.csc2csr()]
row, col, value = src.coo() row, col, value = src.coo()
row, col = row[mask], col[mask] row, col = row[mask], col[mask]
if src.has_value(): if value is not None:
value = value[mask] value = value[mask]
# There is no other information we can maintain... return SparseTensor(row=row, rowptr=None, col=col, value=value,
storage = src.storage.__class__(row=row, col=col, value=value, sparse_sizes=src.sparse_sizes(), is_sorted=True)
sparse_size=src.sparse_size(),
is_sorted=True)
return src.from_storage(storage) SparseTensor.masked_select = lambda self, dim, mask: masked_select(
self, dim, mask)
tmp = lambda src, mask, layout=None: masked_select_nnz( # noqa
src, mask, layout)
SparseTensor.masked_select_nnz = tmp
from typing import Optional
import torch import torch
from torch_scatter import gather_csr from torch_scatter import gather_csr
from torch_sparse.utils import is_scalar from torch_sparse.tensor import SparseTensor
def mul(src, other): @torch.jit.script
if is_scalar(other): def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return mul_nnz(src, other) rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
elif torch.is_tensor(other): # TODO
rowptr, col, value = src.csr() # other = gather_csr(other.squeeze(1), rowptr)
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... pass
other = gather_csr(other.squeeze(1), rowptr) elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
if src.has_value(): other = other.squeeze(0)[col]
value = other.mul_(src.storage.value) else:
else:
value = other
return src.set_value(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = other.mul_(src.storage.value)
else:
value = other
return src.set_value(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,' raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size ' f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.') f'{other.size()}.')
elif isinstance(other, src.__class__): if value is not None:
raise NotImplementedError value = other.mul_(value)
else:
raise ValueError('Argument `other` needs to be of type `int`, `float`, ' value = other
'`torch.tensor` or `torch_sparse.SparseTensor`.') return src.set_value(value, layout='coo')
def mul_(src, other): @torch.jit.script
if is_scalar(other): def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return mul_nnz_(src, other) rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
elif torch.is_tensor(other): # TODO
rowptr, col, value = src.csr() # other = gather_csr(other.squeeze(1), rowptr)
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise... pass
other = gather_csr(other.squeeze(1), rowptr) elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
if src.has_value(): other = other.squeeze(0)[col]
value = src.storage.value.mul_(other) else:
else:
value = other
return src.set_value_(value, layout='csr')
if other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
if src.has_value():
value = src.storage.value.mul_(other)
else:
value = other
return src.set_value_(value, layout='coo')
raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,' raise ValueError(f'Size mismatch: Expected size ({src.size(0)}, 1,'
f' ...) or (1, {src.size(1)}, ...), but got size ' f' ...) or (1, {src.size(1)}, ...), but got size '
f'{other.size()}.') f'{other.size()}.')
elif isinstance(other, src.__class__): if value is not None:
raise NotImplementedError value = value.mul_(other)
else:
raise ValueError('Argument `other` needs to be of type `int`, `float`, ' value = other
'`torch.tensor` or `torch_sparse.SparseTensor`.') return src.set_value_(value, layout='coo')
def mul_nnz(src, other, layout=None): @torch.jit.script
if torch.is_tensor(other) or is_scalar(other): def mul_nnz(src: SparseTensor, other: torch.Tensor,
if src.has_value(): layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value * other value = src.storage.value()
else: if value is not None:
value = other value = value.mul(other)
return src.set_value(value, layout='coo') else:
value = other
raise ValueError('Argument `other` needs to be of type `int`, `float` or ' return src.set_value(value, layout=layout)
'`torch.tensor`.')
@torch.jit.script
def mul_nnz_(src, other, layout=None): def mul_nnz_(src: SparseTensor, other: torch.Tensor,
if torch.is_tensor(other) or is_scalar(other): layout: Optional[str] = None) -> SparseTensor:
if src.has_value(): value = src.storage.value()
value = src.storage.value.mul_(other) if value is not None:
else: value = value.mul_(other)
value = other else:
return src.set_value_(value, layout='coo') value = other
return src.set_value_(value, layout=layout)
raise ValueError('Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.')
SparseTensor.mul = lambda self, other: mul(self, other)
SparseTensor.mul_ = lambda self, other: mul_(self, other)
SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
self, other, layout)
...@@ -4,7 +4,8 @@ from torch_sparse.tensor import SparseTensor ...@@ -4,7 +4,8 @@ from torch_sparse.tensor import SparseTensor
@torch.jit.script @torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int, length: int): def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
start = src.size(dim) + start if start < 0 else start start = src.size(dim) + start if start < 0 else start
......
...@@ -4,7 +4,7 @@ from torch_sparse.narrow import narrow ...@@ -4,7 +4,7 @@ from torch_sparse.narrow import narrow
@torch.jit.script @torch.jit.script
def select(src: SparseTensor, dim: int, idx: int): def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
return narrow(src, dim, start=idx, length=1) return narrow(src, dim, start=idx, length=1)
......
...@@ -5,14 +5,6 @@ import torch ...@@ -5,14 +5,6 @@ import torch
import scipy.sparse import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout from torch_sparse.storage import SparseStorage, get_layout
# from torch_sparse.index_select import index_select, index_select_nnz
# from torch_sparse.masked_select import masked_select, masked_select_nnz
# from torch_sparse.diag import remove_diag, set_diag
# import torch_sparse.reduce
# from torch_sparse.matmul import matmul
# from torch_sparse.add import add, add_, add_nnz, add_nnz_
# from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
from torch_sparse.utils import is_scalar from torch_sparse.utils import is_scalar
...@@ -403,12 +395,6 @@ class SparseTensor(object): ...@@ -403,12 +395,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum') # return matmul(self, other, reduce='sum')
# SparseTensor.narrow = narrow
# SparseTensor.select = select
# SparseTensor.index_select = index_select
# SparseTensor.index_select_nnz = index_select_nnz
# SparseTensor.masked_select = masked_select
# SparseTensor.masked_select_nnz = masked_select_nnz
# SparseTensor.reduction = torch_sparse.reduce.reduction # SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum # SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean # SparseTensor.mean = torch_sparse.reduce.mean
...@@ -417,14 +403,6 @@ class SparseTensor(object): ...@@ -417,14 +403,6 @@ class SparseTensor(object):
# SparseTensor.remove_diag = remove_diag # SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag # SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul # SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add_ = add_
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz_ = add_nnz_
# SparseTensor.mul = mul
# SparseTensor.mul_ = mul_
# SparseTensor.mul_nnz = mul_nnz
# SparseTensor.mul_nnz_ = mul_nnz_
# Python Bindings ############################################################# # Python Bindings #############################################################
......
...@@ -6,7 +6,7 @@ from torch_sparse.tensor import SparseTensor ...@@ -6,7 +6,7 @@ from torch_sparse.tensor import SparseTensor
@torch.jit.script @torch.jit.script
def t(src: SparseTensor): def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc() csr2csc = src.storage.csr2csc()
row, col, value = src.coo() row, col, value = src.coo()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment