Commit 97f2f4e9 (quyuanhao123): Initial commit
from typing import Optional
import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
def masked_select(src: SparseTensor, dim: int,
mask: torch.Tensor) -> SparseTensor:
dim = src.dim() + dim if dim < 0 else dim
assert mask.dim() == 1
storage = src.storage
if dim == 0:
row, col, value = src.coo()
rowcount = src.storage.rowcount()
rowcount = rowcount[mask]
mask = mask[row]
row = torch.arange(rowcount.size(0),
device=row.device).repeat_interleave(rowcount)
col = col[mask]
if value is not None:
value = value[mask]
sparse_sizes = (rowcount.size(0), src.sparse_size(1))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colcount=None, colptr=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1:
row, col, value = src.coo()
csr2csc = src.storage.csr2csc()
row = row[csr2csc]
col = col[csr2csc]
colcount = src.storage.colcount()
colcount = colcount[mask]
mask = mask[col]
col = torch.arange(colcount.size(0),
device=col.device).repeat_interleave(colcount)
row = row[mask]
csc2csr = (colcount.size(0) * row + col).argsort()
row, col = row[csc2csr], col[csc2csr]
if value is not None:
value = value[csr2csc][mask][csc2csr]
sparse_sizes = (src.sparse_size(0), colcount.size(0))
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colcount=colcount, colptr=None, csr2csc=None,
csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
else:
value = src.storage.value()
if value is not None:
idx = mask.nonzero().flatten()
return src.set_value(value.index_select(dim - 1, idx),
layout='coo')
else:
raise ValueError
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
assert mask.dim() == 1
if get_layout(layout) == 'csc':
mask = mask[src.storage.csc2csr()]
row, col, value = src.coo()
row, col = row[mask], col[mask]
if value is not None:
value = value[mask]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=src.sparse_sizes(), is_sorted=True)
SparseTensor.masked_select = lambda self, dim, mask: masked_select(
self, dim, mask)
tmp = lambda self, mask, layout=None: masked_select_nnz( # noqa
self, mask, layout)
SparseTensor.masked_select_nnz = tmp
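# --- Example usage (illustrative sketch, not part of the original file) ---
# Shows how `masked_select` drops rows and compacts the surviving indices.
# The toy tensors below are made up; running this assumes the compiled
# `torch_sparse` ops are available for the cached CSR conversions.
def _masked_select_example():
    adj = SparseTensor(row=torch.tensor([0, 0, 1, 2]),
                       col=torch.tensor([0, 1, 0, 2]),
                       sparse_sizes=(3, 3))
    keep = torch.tensor([True, False, True])
    # Keep rows 0 and 2 -> a 2 x 3 sparse matrix with re-labeled rows.
    return adj.masked_select(0, keep)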
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
row = src.storage._row
csr2csc = src.storage._csr2csc
colptr = src.storage._colptr
if value is not None:
value = value.to(other.dtype)
if value is not None and value.requires_grad:
row = src.storage.row()
if other.requires_grad:
row = src.storage.row()
csr2csc = src.storage.csr2csc()
colptr = src.storage.colptr()
return torch.ops.torch_sparse.spmm_sum(row, rowptr, col, value, colptr,
csr2csc, other)
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
return spmm_sum(src, other)
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
rowptr, col, value = src.csr()
row = src.storage._row
rowcount = src.storage._rowcount
csr2csc = src.storage._csr2csc
colptr = src.storage._colptr
if value is not None:
value = value.to(other.dtype)
if value is not None and value.requires_grad:
row = src.storage.row()
if other.requires_grad:
row = src.storage.row()
rowcount = src.storage.rowcount()
csr2csc = src.storage.csr2csc()
colptr = src.storage.colptr()
return torch.ops.torch_sparse.spmm_mean(row, rowptr, col, value, rowcount,
colptr, csr2csc, other)
def spmm_min(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
if value is not None:
value = value.to(other.dtype)
return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)
def spmm_max(src: SparseTensor,
other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
if value is not None:
value = value.to(other.dtype)
return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)
def spmm(src: SparseTensor, other: torch.Tensor,
reduce: str = "sum") -> torch.Tensor:
if reduce == 'sum' or reduce == 'add':
return spmm_sum(src, other)
elif reduce == 'mean':
return spmm_mean(src, other)
elif reduce == 'min':
return spmm_min(src, other)[0]
elif reduce == 'max':
return spmm_max(src, other)[0]
else:
raise ValueError
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
assert src.sparse_size(1) == other.sparse_size(0)
rowptrA, colA, valueA = src.csr()
rowptrB, colB, valueB = other.csr()
value = valueA if valueA is not None else valueB
if valueA is not None and valueA.dtype == torch.half:
valueA = valueA.to(torch.float)
if valueB is not None and valueB.dtype == torch.half:
valueB = valueB.to(torch.float)
M, K = src.sparse_size(0), other.sparse_size(1)
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
if valueC is not None and value is not None:
valueC = valueC.to(value.dtype)
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
sparse_sizes=(M, K), is_sorted=True)
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
return spspmm_sum(src, other)
def spspmm(src: SparseTensor, other: SparseTensor,
reduce: str = "sum") -> SparseTensor:
if reduce == 'sum' or reduce == 'add':
return spspmm_sum(src, other)
elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
raise NotImplementedError
else:
raise ValueError
@torch.jit._overload # noqa: F811
def matmul(src, other, reduce): # noqa: F811
# type: (SparseTensor, torch.Tensor, str) -> torch.Tensor
pass
@torch.jit._overload # noqa: F811
def matmul(src, other, reduce): # noqa: F811
# type: (SparseTensor, SparseTensor, str) -> SparseTensor
pass
def matmul(src, other, reduce="sum"): # noqa: F811
if isinstance(other, torch.Tensor):
return spmm(src, other, reduce)
elif isinstance(other, SparseTensor):
return spspmm(src, other, reduce)
raise ValueError
SparseTensor.spmm = lambda self, other, reduce="sum": spmm(self, other, reduce)
SparseTensor.spspmm = lambda self, other, reduce="sum": spspmm(
self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
self, other, reduce)
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
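# --- Example usage (illustrative sketch, not part of the original file) ---
# `matmul` dispatches to sparse-dense (`spmm`) or sparse-sparse (`spspmm`)
# multiplication. The fused kernels live in the compiled
# `torch.ops.torch_sparse` extension, so this sketch assumes that extension
# is built; the shapes below are arbitrary.
def _matmul_example():
    adj = SparseTensor.from_edge_index(torch.tensor([[0, 1, 2], [1, 2, 0]]),
                                       sparse_sizes=(3, 3))
    x = torch.randn(3, 4)
    out_sum = adj @ x                 # equivalent to adj.matmul(x, 'sum')
    out_mean = adj.matmul(x, 'mean')  # row-wise mean aggregation
    return out_sum, out_mean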
from typing import Tuple, Optional
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
sorted_weight = weight.sort()[0]
diff = sorted_weight[1:] - sorted_weight[:-1]
if diff.sum() == 0:
return None
weight_min, weight_max = sorted_weight[0], sorted_weight[-1]
srange = weight_max - weight_min
min_diff = diff.min()
scale = (min_diff / srange).item()
tick, arange = scale.as_integer_ratio()
weight_ratio = (weight - weight_min).div_(srange).mul_(arange).add_(tick)
return weight_ratio.to(torch.long)
def partition(
src: SparseTensor, num_parts: int, recursive: bool = False,
weighted: bool = False, node_weight: Optional[torch.Tensor] = None
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
assert num_parts >= 1
if num_parts == 1:
partptr = torch.tensor([0, src.size(0)], device=src.device())
perm = torch.arange(src.size(0), device=src.device())
return src, partptr, perm
rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu()
if value is not None and weighted:
assert value.numel() == col.numel()
value = value.view(-1).detach().cpu()
if value.is_floating_point():
value = weight2metis(value)
else:
value = None
if node_weight is not None:
assert node_weight.numel() == rowptr.numel() - 1
node_weight = node_weight.view(-1).detach().cpu()
if node_weight.is_floating_point():
node_weight = weight2metis(node_weight)
cluster = torch.ops.torch_sparse.partition2(rowptr, col, value,
node_weight, num_parts,
recursive)
else:
cluster = torch.ops.torch_sparse.partition(rowptr, col, value,
num_parts, recursive)
cluster = cluster.to(src.device())
cluster, perm = cluster.sort()
out = permute(src, perm)
partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
return out, partptr, perm
SparseTensor.partition = partition
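# --- Example usage (illustrative sketch, not part of the original file) ---
# `partition` wraps METIS graph partitioning: it returns the permuted
# matrix, a `partptr` delimiting each cluster's node range, and the node
# permutation `perm`. Assumes torch_sparse was compiled with METIS support;
# the small chain graph below is made up.
def _partition_example():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
    out, partptr, perm = adj.partition(num_parts=2, recursive=False)
    return out, partptr, perm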
from typing import Optional
import torch
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = other.to(value.dtype).mul_(value)
else:
value = other
return src.set_value(value, layout='coo')
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = value.mul_(other.to(value.dtype))
else:
value = other
return src.set_value_(value, layout='coo')
def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
if value is not None:
value = value.mul(other.to(value.dtype))
else:
value = other
return src.set_value(value, layout=layout)
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
if value is not None:
value = value.mul_(other.to(value.dtype))
else:
value = other
return src.set_value_(value, layout=layout)
SparseTensor.mul = lambda self, other: mul(self, other)
SparseTensor.mul_ = lambda self, other: mul_(self, other)
SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
self, other, layout)
SparseTensor.__mul__ = SparseTensor.mul
SparseTensor.__rmul__ = SparseTensor.mul
SparseTensor.__imul__ = SparseTensor.mul_
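# --- Example usage (illustrative sketch, not part of the original file) ---
# `mul` broadcasts a dense tensor against the sparse matrix: shape [N, 1]
# scales row-wise (via `gather_csr`), shape [1, M] scales column-wise.
# Assumes torch_scatter and the compiled torch_sparse ops are available.
def _mul_example():
    adj = SparseTensor.from_edge_index(torch.tensor([[0, 1, 2], [1, 2, 0]]),
                                       sparse_sizes=(3, 3)).fill_value(1.)
    row_scale = torch.tensor([[1.], [2.], [3.]])  # [3, 1] -> row-wise
    col_scale = torch.tensor([[1., 2., 3.]])      # [1, 3] -> column-wise
    return adj * row_scale, adj * col_scale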
from typing import Tuple
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor:
if dim < 0:
dim = src.dim() + dim
if start < 0:
start = src.size(dim) + start
if dim == 0:
rowptr, col, value = src.csr()
rowptr = rowptr.narrow(0, start=start, length=length + 1)
row_start = rowptr[0]
rowptr = rowptr - row_start
row_length = rowptr[-1]
row = src.storage._row
if row is not None:
row = row.narrow(0, row_start, row_length) - start
col = col.narrow(0, row_start, row_length)
if value is not None:
value = value.narrow(0, row_start, row_length)
sparse_sizes = (length, src.sparse_size(1))
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.narrow(0, start=start, length=length)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
elif dim == 1:
        # In contrast to the `dim=0` case, this is faster than going
        # through `csc()`.
row, col, value = src.coo()
mask = (col >= start) & (col < start + length)
row = row[mask]
col = col[mask] - start
if value is not None:
value = value[mask]
sparse_sizes = (src.sparse_size(0), length)
colptr = src.storage._colptr
if colptr is not None:
colptr = colptr.narrow(0, start=start, length=length + 1)
colptr = colptr - colptr[0]
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount.narrow(0, start=start, length=length)
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=colcount, csr2csc=None,
csc2csr=None, is_sorted=True)
return src.from_storage(storage)
else:
value = src.storage.value()
if value is not None:
return src.set_value(value.narrow(dim - 1, start, length),
layout='coo')
else:
raise ValueError
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length: Tuple[int, int]) -> SparseTensor:
# This function builds the inverse operation of `cat_diag` and should hence
# only be used on *diagonally stacked* sparse matrices.
# That's the reason why this method is marked as *private*.
rowptr, col, value = src.csr()
rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
row_start = int(rowptr[0])
rowptr = rowptr - row_start
row_length = int(rowptr[-1])
row = src.storage._row
if row is not None:
row = row.narrow(0, row_start, row_length) - start[0]
col = col.narrow(0, row_start, row_length) - start[1]
if value is not None:
value = value.narrow(0, row_start, row_length)
sparse_sizes = length
rowcount = src.storage._rowcount
if rowcount is not None:
rowcount = rowcount.narrow(0, start[0], length[0])
colptr = src.storage._colptr
if colptr is not None:
colptr = colptr.narrow(0, start[1], length[1] + 1)
colptr = colptr - int(colptr[0]) # i.e. `row_start`
colcount = src.storage._colcount
if colcount is not None:
colcount = colcount.narrow(0, start[1], length[1])
csr2csc = src.storage._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start
csc2csr = src.storage._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return src.from_storage(storage)
SparseTensor.narrow = lambda self, dim, start, length: narrow(
self, dim, start, length)
SparseTensor.__narrow_diag__ = lambda self, start, length: __narrow_diag__(
self, start, length)
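# --- Example usage (illustrative sketch, not part of the original file) ---
# `narrow` slices a contiguous row or column range without materializing a
# dense matrix; dim=0 reuses CSR pointers, dim=1 masks COO columns. The
# 4 x 4 identity below is only for demonstration, and the CSR path assumes
# the compiled torch_sparse ops are available.
def _narrow_example():
    adj = SparseTensor.from_dense(torch.eye(4))
    rows_1_2 = adj.narrow(0, start=1, length=2)  # 2 x 4 sub-matrix
    cols_0_2 = adj.narrow(1, start=0, length=3)  # 4 x 3 sub-matrix
    return rows_1_2, cols_0_2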
from typing import Tuple, List
import torch
from torch_sparse.tensor import SparseTensor
def padded_index(src: SparseTensor, binptr: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.
Tensor, List[int], List[int]]:
return torch.ops.torch_sparse.padded_index(src.storage.rowptr(),
src.storage.col(),
src.storage.rowcount(), binptr)
def padded_index_select(src: torch.Tensor, index: torch.Tensor,
fill_value: float = 0.) -> torch.Tensor:
fill_value = torch.tensor(fill_value, dtype=src.dtype)
return torch.ops.torch_sparse.padded_index_select(src, index, fill_value)
SparseTensor.padded_index = padded_index
import torch
from torch_sparse.tensor import SparseTensor
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_quadratic()
return src.index_select(0, perm).index_select(1, perm)
SparseTensor.permute = lambda self, perm: permute(self, perm)
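# --- Example usage (illustrative sketch, not part of the original file) ---
# `permute` relabels a square matrix symmetrically by applying the same
# permutation to rows and columns. It relies on `SparseTensor.index_select`,
# which lives in a torch_sparse module not shown in this excerpt.
def _permute_example():
    adj = SparseTensor.from_dense(torch.tensor([[1., 2.], [3., 4.]]))
    perm = torch.tensor([1, 0])
    # Expected dense result: [[4., 3.], [2., 1.]].
    return permute(adj, perm)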
from typing import Optional
import torch
from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor
def reduction(src: SparseTensor, dim: Optional[int] = None,
reduce: str = 'sum') -> torch.Tensor:
value = src.storage.value()
if dim is None:
if value is not None:
if reduce == 'sum' or reduce == 'add':
return value.sum()
elif reduce == 'mean':
return value.mean()
elif reduce == 'min':
return value.min()
elif reduce == 'max':
return value.max()
else:
raise ValueError
else:
if reduce == 'sum' or reduce == 'add':
return torch.tensor(src.nnz(), dtype=src.dtype(),
device=src.device())
elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
return torch.tensor(1, dtype=src.dtype(), device=src.device())
else:
raise ValueError
else:
if dim < 0:
dim = src.dim() + dim
if dim == 0 and value is not None:
col = src.storage.col()
return scatter(value, col, 0, None, src.size(1), reduce)
elif dim == 0 and value is None:
if reduce == 'sum' or reduce == 'add':
return src.storage.colcount().to(src.dtype())
elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
return torch.ones(src.size(1), dtype=src.dtype())
else:
raise ValueError
elif dim == 1 and value is not None:
return segment_csr(value, src.storage.rowptr(), None, reduce)
elif dim == 1 and value is None:
if reduce == 'sum' or reduce == 'add':
return src.storage.rowcount().to(src.dtype())
elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
return torch.ones(src.size(0), dtype=src.dtype())
else:
raise ValueError
elif dim > 1 and value is not None:
if reduce == 'sum' or reduce == 'add':
return value.sum(dim=dim - 1)
elif reduce == 'mean':
return value.mean(dim=dim - 1)
elif reduce == 'min':
return value.min(dim=dim - 1)[0]
elif reduce == 'max':
return value.max(dim=dim - 1)[0]
else:
raise ValueError
else:
raise ValueError
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='sum')
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='mean')
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='min')
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='max')
SparseTensor.sum = lambda self, dim=None: sum(self, dim)
SparseTensor.mean = lambda self, dim=None: mean(self, dim)
SparseTensor.min = lambda self, dim=None: min(self, dim)
SparseTensor.max = lambda self, dim=None: max(self, dim)
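# --- Example usage (illustrative sketch, not part of the original file) ---
# `reduction` covers full, column-wise (dim=0) and row-wise (dim=1)
# reductions; the per-dimension paths use torch_scatter's `scatter` and
# `segment_csr`, so this sketch assumes torch_scatter and the compiled
# torch_sparse ops are installed. The toy values below are made up.
def _reduce_example():
    adj = SparseTensor.from_dense(torch.tensor([[1., 2.], [0., 3.]]))
    total = adj.sum()          # scalar: 6.
    row_sums = adj.sum(dim=1)  # per row: [3., 3.]
    col_maxs = adj.max(dim=0)  # per column: [1., 3.]
    return total, row_sums, col_maxs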
import torch
from torch_sparse.tensor import SparseTensor
def random_walk(src: SparseTensor, start: torch.Tensor,
walk_length: int) -> torch.Tensor:
rowptr, col, _ = src.csr()
return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
SparseTensor.random_walk = random_walk
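# --- Example usage (illustrative sketch, not part of the original file) ---
# `random_walk` samples uniform random walks via the compiled extension;
# each row of the result starts at the given node and is expected to hold
# `walk_length + 1` node indices. Graph and walk length are arbitrary.
def _random_walk_example():
    adj = SparseTensor.from_edge_index(torch.tensor([[0, 1, 2], [1, 2, 0]]),
                                       sparse_sizes=(3, 3))
    start = torch.tensor([0, 1, 2])
    return adj.random_walk(start, walk_length=4)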
from typing import Tuple
import torch
from torch_sparse.tensor import SparseTensor
def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
) -> Tuple[SparseTensor, torch.Tensor]:
row, col, value = src.coo()
rowptr = src.storage.rowptr()
data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
row, col, edge_index = data
if value is not None:
value = value[edge_index]
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(node_idx.size(0), node_idx.size(0)),
is_sorted=True)
return out, edge_index
SparseTensor.saint_subgraph = saint_subgraph
from typing import Optional, Tuple
import torch
from torch_sparse.tensor import SparseTensor
def sample(src: SparseTensor, num_neighbors: int,
subset: Optional[torch.Tensor] = None) -> torch.Tensor:
rowptr, col, _ = src.csr()
rowcount = src.storage.rowcount()
if subset is not None:
rowcount = rowcount[subset]
rowptr = rowptr[subset]
else:
rowptr = rowptr[:-1]
rand = torch.rand((rowcount.size(0), num_neighbors), device=col.device)
rand.mul_(rowcount.to(rand.dtype).view(-1, 1))
rand = rand.to(torch.long)
rand.add_(rowptr.view(-1, 1))
return col[rand]
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, value = src.csr()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, subset, num_neighbors, replace)
if value is not None:
value = value[e_id]
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
sparse_sizes=(subset.size(0), n_id.size(0)),
is_sorted=True)
return out, n_id
SparseTensor.sample = sample
SparseTensor.sample_adj = sample_adj
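# --- Example usage (illustrative sketch, not part of the original file) ---
# `sample_adj` draws up to `num_neighbors` neighbors per node in `subset`
# via the compiled sampler and returns the induced bipartite adjacency plus
# the original ids (`n_id`) of every referenced node. Toy graph below;
# assumes the compiled torch_sparse ops are available.
def _sample_adj_example():
    adj = SparseTensor.from_edge_index(
        torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]]), sparse_sizes=(3, 3))
    subset = torch.tensor([0, 1])
    out, n_id = adj.sample_adj(subset, num_neighbors=2, replace=False)
    return out, n_id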
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
return narrow(src, dim, start=idx, length=1)
SparseTensor.select = lambda self, dim, idx: select(self, dim, idx)
import torch
from torch_sparse import coalesce
def spadd(indexA, valueA, indexB, valueB, m, n):
"""Matrix addition of two sparse matrices.
Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of the sparse matrices.
n (int): The second dimension of the sparse matrices.
"""
index = torch.cat([indexA, indexB], dim=-1)
value = torch.cat([valueA, valueB], dim=0)
return coalesce(index=index, value=value, m=m, n=n, op='add')
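# --- Example usage (illustrative sketch, not part of the original file) ---
# `spadd` concatenates both index/value lists and lets `coalesce` sum up
# duplicate coordinates. Here the entry (0, 1) appears in both inputs, so
# the result holds (0, 1) -> 4., (1, 0) -> 2. and (1, 1) -> 4.
def _spadd_example():
    indexA = torch.tensor([[0, 1], [1, 0]])
    valueA = torch.tensor([1., 2.])
    indexB = torch.tensor([[0, 1], [1, 1]])
    valueB = torch.tensor([3., 4.])
    return spadd(indexA, valueA, indexB, valueB, m=2, n=2)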
from torch import Tensor
from torch_scatter import scatter_add
def spmm(index: Tensor, value: Tensor, m: int, n: int,
matrix: Tensor) -> Tensor:
"""Matrix product of sparse matrix with dense matrix.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix.
:rtype: :class:`Tensor`
"""
assert n == matrix.size(-2)
row, col = index[0], index[1]
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix.index_select(-2, col)
out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=-2, dim_size=m)
return out
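# --- Example usage (illustrative sketch, not part of the original file) ---
# The index/value-based `spmm` above gathers the dense rows selected by
# `col`, scales them by `value` and scatter-adds them into the output rows.
# For A = [[1., 2.], [0., 3.]], the expected result is [[5., 14.], [6., 15.]].
def _spmm_example():
    index = torch.tensor([[0, 0, 1], [0, 1, 1]])
    value = torch.tensor([1., 2., 3.])
    matrix = torch.tensor([[1., 4.], [2., 5.]])
    return spmm(index, value, m=2, n=2, matrix=matrix)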
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.matmul import matmul
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""Matrix product of two sparse tensors. Both input sparse matrices need to
be coalesced (use the :obj:`coalesced` attribute to force).
Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of first sparse matrix.
k (int): The second dimension of first sparse matrix and first
dimension of second sparse matrix.
n (int): The second dimension of second sparse matrix.
coalesced (bool, optional): If set to :obj:`True`, will coalesce both
input sparse matrices. (default: :obj:`False`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
sparse_sizes=(m, k), is_sorted=not coalesced)
B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
sparse_sizes=(k, n), is_sorted=not coalesced)
C = matmul(A, B)
row, col, value = C.coo()
return torch.stack([row, col], dim=0), value
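# --- Example usage (illustrative sketch, not part of the original file) ---
# The index/value-based `spspmm` above routes through `SparseTensor` and the
# compiled `matmul` kernel, so it assumes the extension is built. With the
# toy inputs below, the product C = A @ B has nonzeros
# C[0, 0] = 8., C[1, 1] = 6. and C[2, 1] = 8.
def _spspmm_example():
    indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
    valueA = torch.tensor([1., 2., 3., 4., 5.])
    indexB = torch.tensor([[0, 2], [1, 0]])
    valueB = torch.tensor([2., 4.])
    return spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)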
import warnings
from typing import Optional, List, Tuple
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
def get_layout(layout: Optional[str] = None) -> str:
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout == 'coo' or layout == 'csr' or layout == 'csc'
return layout
@torch.jit.script
class SparseStorage(object):
_row: Optional[torch.Tensor]
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_sizes: Tuple[int, int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False,
trust_data: bool = False):
assert row is not None or rowptr is not None
assert col is not None
assert col.dtype == torch.long
assert col.dim() == 1
col = col.contiguous()
M: int = 0
if sparse_sizes is None or sparse_sizes[0] is None:
if rowptr is not None:
M = rowptr.numel() - 1
elif row is not None and row.numel() > 0:
M = int(row.max()) + 1
else:
_M = sparse_sizes[0]
assert _M is not None
M = _M
if rowptr is not None:
assert rowptr.numel() - 1 == M
elif row is not None and row.numel() > 0:
assert trust_data or int(row.max()) < M
N: int = 0
if sparse_sizes is None or sparse_sizes[1] is None:
if col.numel() > 0:
N = int(col.max()) + 1
else:
_N = sparse_sizes[1]
assert _N is not None
N = _N
if col.numel() > 0:
assert trust_data or int(col.max()) < N
sparse_sizes = (M, N)
if row is not None:
assert row.dtype == torch.long
assert row.device == col.device
assert row.dim() == 1
assert row.numel() == col.numel()
row = row.contiguous()
if rowptr is not None:
assert rowptr.dtype == torch.long
assert rowptr.device == col.device
assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_sizes[0]
rowptr = rowptr.contiguous()
if value is not None:
assert value.device == col.device
assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_sizes[0]
rowcount = rowcount.contiguous()
if colptr is not None:
assert colptr.dtype == torch.long
assert colptr.device == col.device
assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_sizes[1]
colptr = colptr.contiguous()
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_sizes[1]
colcount = colcount.contiguous()
if csr2csc is not None:
assert csr2csc.dtype == torch.long
assert csr2csc.device == col.device
assert csr2csc.dim() == 1
assert csr2csc.numel() == col.size(0)
csr2csc = csr2csc.contiguous()
if csc2csr is not None:
assert csc2csr.dtype == torch.long
assert csc2csr.device == col.device
assert csc2csr.dim() == 1
assert csc2csr.numel() == col.size(0)
csc2csr = csc2csr.contiguous()
self._row = row
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_sizes = tuple(sparse_sizes)
self._rowcount = rowcount
self._colptr = colptr
self._colcount = colcount
self._csr2csc = csr2csc
self._csc2csr = csc2csr
if not is_sorted:
idx = self._col.new_zeros(self._col.numel() + 1)
idx[1:] = self.row()
idx[1:] *= self._sparse_sizes[1]
idx[1:] += self._col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
self._row = self.row()[perm]
self._col = self._col[perm]
if value is not None:
self._value = value[perm]
self._csr2csc = None
self._csc2csr = None
@classmethod
def empty(self):
row = torch.tensor([], dtype=torch.long)
col = torch.tensor([], dtype=torch.long)
return SparseStorage(row=row, rowptr=None, col=col, value=None,
sparse_sizes=(0, 0), rowcount=None, colptr=None,
colcount=None, csr2csc=None, csc2csr=None,
is_sorted=True, trust_data=True)
def has_row(self) -> bool:
return self._row is not None
def row(self):
row = self._row
if row is not None:
return row
rowptr = self._rowptr
if rowptr is not None:
row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
self._row = row
return row
raise ValueError
def has_rowptr(self) -> bool:
return self._rowptr is not None
def rowptr(self) -> torch.Tensor:
rowptr = self._rowptr
if rowptr is not None:
return rowptr
row = self._row
if row is not None:
rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
self._rowptr = rowptr
return rowptr
raise ValueError
def col(self) -> torch.Tensor:
return self._col
def has_value(self) -> bool:
return self._value is not None
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
self._value = value
return self
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes
def sparse_size(self, dim: int) -> int:
return self._sparse_sizes[dim]
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if rowptr is not None:
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
elif diff_0 < 0:
if rowptr is not None:
rowptr = rowptr[:-diff_0]
if rowcount is not None:
rowcount = rowcount[:-diff_0]
diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if colptr is not None:
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
elif diff_1 < 0:
if colptr is not None:
colptr = colptr[:-diff_1]
if colcount is not None:
colcount = colcount[:-diff_1]
return SparseStorage(
row=self._row,
rowptr=rowptr,
col=self._col,
value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def sparse_reshape(self, num_rows: int, num_cols: int):
assert num_rows > 0 or num_rows == -1
assert num_cols > 0 or num_cols == -1
assert num_rows > 0 or num_cols > 0
total = self.sparse_size(0) * self.sparse_size(1)
if num_rows == -1:
num_rows = total // num_cols
if num_cols == -1:
num_cols = total // num_rows
assert num_rows * num_cols == total
idx = self.sparse_size(1) * self.row() + self.col()
row = torch.div(idx, num_cols, rounding_mode='floor')
col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
sparse_sizes=(num_rows, num_cols), rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True, trust_data=True)
def has_rowcount(self) -> bool:
return self._rowcount is not None
def rowcount(self) -> torch.Tensor:
rowcount = self._rowcount
if rowcount is not None:
return rowcount
rowptr = self.rowptr()
rowcount = rowptr[1:] - rowptr[:-1]
self._rowcount = rowcount
return rowcount
def has_colptr(self) -> bool:
return self._colptr is not None
def colptr(self) -> torch.Tensor:
colptr = self._colptr
if colptr is not None:
return colptr
csr2csc = self._csr2csc
if csr2csc is not None:
colptr = torch.ops.torch_sparse.ind2ptr(self._col[csr2csc],
self._sparse_sizes[1])
else:
colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
self._colptr = colptr
return colptr
def has_colcount(self) -> bool:
return self._colcount is not None
def colcount(self) -> torch.Tensor:
colcount = self._colcount
if colcount is not None:
return colcount
colptr = self._colptr
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
def has_csr2csc(self) -> bool:
return self._csr2csc is not None
def csr2csc(self) -> torch.Tensor:
csr2csc = self._csr2csc
if csr2csc is not None:
return csr2csc
idx = self._sparse_sizes[0] * self._col + self.row()
csr2csc = idx.argsort()
self._csr2csc = csr2csc
return csr2csc
def has_csc2csr(self) -> bool:
return self._csc2csr is not None
def csc2csr(self) -> torch.Tensor:
csc2csr = self._csc2csr
if csc2csr is not None:
return csc2csr
csc2csr = self.csr2csc().argsort()
self._csc2csr = csc2csr
return csc2csr
def is_coalesced(self) -> bool:
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
return bool((idx[1:] > idx[:-1]).all())
def coalesce(self, reduce: str = "add"):
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced.
return self
row = self.row()[mask]
col = self._col[mask]
value = self._value
if value is not None:
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True, trust_data=True)
def fill_cache_(self):
self.row()
self.rowptr()
self.rowcount()
self.colptr()
self.colcount()
self.csr2csc()
self.csc2csr()
return self
def clear_cache_(self):
self._rowcount = None
self._colptr = None
self._colcount = None
self._csr2csc = None
self._csc2csr = None
return self
def cached_keys(self) -> List[str]:
keys: List[str] = []
if self.has_rowcount():
keys.append('rowcount')
if self.has_colptr():
keys.append('colptr')
if self.has_colcount():
keys.append('colcount')
if self.has_csr2csc():
keys.append('csr2csc')
if self.has_csc2csr():
keys.append('csc2csr')
return keys
def num_cached_keys(self) -> int:
return len(self.cached_keys())
def copy(self):
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True,
trust_data=True)
def clone(self):
row = self._row
if row is not None:
row = row.clone()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.clone()
col = self._col.clone()
value = self._value
if value is not None:
value = value.clone()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.clone()
colptr = self._colptr
if colptr is not None:
colptr = colptr.clone()
colcount = self._colcount
if colcount is not None:
colcount = colcount.clone()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.clone()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self._value
if value is not None:
if dtype == value.dtype:
return self
else:
return self.set_value(
value.to(
dtype=dtype,
non_blocking=non_blocking),
layout='coo')
else:
return self
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self._col.device:
return self
row = self._row
if row is not None:
row = row.to(device, non_blocking=non_blocking)
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.to(device, non_blocking=non_blocking)
col = self._col.to(device, non_blocking=non_blocking)
value = self._value
if value is not None:
value = value.to(device, non_blocking=non_blocking)
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.to(device, non_blocking=non_blocking)
colptr = self._colptr
if colptr is not None:
colptr = colptr.to(device, non_blocking=non_blocking)
colcount = self._colcount
if colcount is not None:
colcount = colcount.to(device, non_blocking=non_blocking)
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.to(device, non_blocking=non_blocking)
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
def cuda(self):
new_col = self._col.cuda()
if new_col.device == self._col.device:
return self
row = self._row
if row is not None:
row = row.cuda()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.cuda()
value = self._value
if value is not None:
value = value.cuda()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.cuda()
colptr = self._colptr
if colptr is not None:
colptr = colptr.cuda()
colcount = self._colcount
if colcount is not None:
colcount = colcount.cuda()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.cuda()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.cuda()
return SparseStorage(row=row, rowptr=rowptr, col=new_col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
def pin_memory(self):
row = self._row
if row is not None:
row = row.pin_memory()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.pin_memory()
col = self._col.pin_memory()
value = self._value
if value is not None:
value = value.pin_memory()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.pin_memory()
colptr = self._colptr
if colptr is not None:
colptr = colptr.pin_memory()
colcount = self._colcount
if colcount is not None:
colcount = colcount.pin_memory()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.pin_memory()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True, trust_data=True)
def is_pinned(self) -> bool:
is_pinned = True
row = self._row
if row is not None:
is_pinned = is_pinned and row.is_pinned()
rowptr = self._rowptr
if rowptr is not None:
is_pinned = is_pinned and rowptr.is_pinned()
        is_pinned = is_pinned and self._col.is_pinned()
value = self._value
if value is not None:
is_pinned = is_pinned and value.is_pinned()
rowcount = self._rowcount
if rowcount is not None:
is_pinned = is_pinned and rowcount.is_pinned()
colptr = self._colptr
if colptr is not None:
is_pinned = is_pinned and colptr.is_pinned()
colcount = self._colcount
if colcount is not None:
is_pinned = is_pinned and colcount.is_pinned()
csr2csc = self._csr2csc
if csr2csc is not None:
is_pinned = is_pinned and csr2csc.is_pinned()
csc2csr = self._csc2csr
if csc2csr is not None:
is_pinned = is_pinned and csc2csr.is_pinned()
return is_pinned
def share_memory_(self) -> SparseStorage:
row = self._row
if row is not None:
row.share_memory_()
rowptr = self._rowptr
if rowptr is not None:
rowptr.share_memory_()
self._col.share_memory_()
value = self._value
if value is not None:
value.share_memory_()
rowcount = self._rowcount
if rowcount is not None:
rowcount.share_memory_()
colptr = self._colptr
if colptr is not None:
colptr.share_memory_()
colcount = self._colcount
if colcount is not None:
colcount.share_memory_()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc.share_memory_()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        csc2csr.share_memory_()
    return self
def is_shared(self) -> bool:
is_shared = True
row = self._row
if row is not None:
is_shared = is_shared and row.is_shared()
rowptr = self._rowptr
if rowptr is not None:
is_shared = is_shared and rowptr.is_shared()
is_shared = is_shared and self._col.is_shared()
value = self._value
if value is not None:
is_shared = is_shared and value.is_shared()
rowcount = self._rowcount
if rowcount is not None:
is_shared = is_shared and rowcount.is_shared()
colptr = self._colptr
if colptr is not None:
is_shared = is_shared and colptr.is_shared()
colcount = self._colcount
if colcount is not None:
is_shared = is_shared and colcount.is_shared()
csr2csc = self._csr2csc
if csr2csc is not None:
is_shared = is_shared and csr2csc.is_shared()
csc2csr = self._csc2csr
if csc2csr is not None:
is_shared = is_shared and csc2csr.is_shared()
return is_shared
SparseStorage.share_memory_ = share_memory_
SparseStorage.is_shared = is_shared
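# --- Example usage (illustrative sketch, not part of the original file) ---
# `SparseStorage` lazily derives and caches the secondary layouts (CSR
# pointers, CSC permutations, counts). `fill_cache_` materializes them all;
# computing `rowptr` needs the compiled `ind2ptr` op and `colcount` needs
# torch_scatter, so this sketch assumes both are available.
def _storage_cache_example():
    storage = SparseStorage(row=torch.tensor([0, 0, 1]),
                            col=torch.tensor([0, 1, 1]),
                            sparse_sizes=(2, 2))
    assert storage.num_cached_keys() == 0
    storage.fill_cache_()
    assert storage.cached_keys() == [
        'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
    ]
    return storage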
from textwrap import indent
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import scipy.sparse
import torch
from torch_scatter import segment_csr
from torch_sparse.storage import SparseStorage, get_layout
@torch.jit.script
class SparseTensor(object):
storage: SparseStorage
def __init__(
self,
row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False,
):
self.storage = SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=is_sorted,
trust_data=trust_data,
)
@classmethod
def from_storage(self, storage: SparseStorage):
out = SparseTensor(
row=storage._row,
rowptr=storage._rowptr,
col=storage._col,
value=storage._value,
sparse_sizes=storage._sparse_sizes,
is_sorted=True,
trust_data=True,
)
out.storage._rowcount = storage._rowcount
out.storage._colptr = storage._colptr
out.storage._colcount = storage._colcount
out.storage._csr2csc = storage._csr2csc
out.storage._csc2csr = storage._csc2csr
return out
@classmethod
def from_edge_index(
self,
edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
is_sorted: bool = False,
trust_data: bool = False,
):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes,
is_sorted=is_sorted, trust_data=trust_data)
@classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
else:
index = mat.nonzero()
index = index.t()
row = index[0]
col = index[1]
value: Optional[torch.Tensor] = None
if has_value:
value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
has_value: bool = True):
mat = mat.coalesce()
index = mat._indices()
row, col = index[0], index[1]
value: Optional[torch.Tensor] = None
if has_value:
value = mat.values()
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
has_value: bool = True):
rowptr = mat.crow_indices()
col = mat.col_indices()
value: Optional[torch.Tensor] = None
if has_value:
value = mat.values()
return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True, trust_data=True)
@classmethod
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
dtype: Optional[int] = None, device: Optional[torch.device] = None,
fill_cache: bool = False):
N = M if N is None else N
row = torch.arange(min(M, N), device=device)
col = row
rowptr = torch.arange(M + 1, device=row.device)
if M > N:
rowptr[N + 1:] = N
value: Optional[torch.Tensor] = None
if has_value:
value = torch.ones(row.numel(), dtype=dtype, device=row.device)
rowcount: Optional[torch.Tensor] = None
colptr: Optional[torch.Tensor] = None
colcount: Optional[torch.Tensor] = None
csr2csc: Optional[torch.Tensor] = None
csc2csr: Optional[torch.Tensor] = None
if fill_cache:
rowcount = torch.ones(M, dtype=torch.long, device=row.device)
if M > N:
rowcount[N:] = 0
colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
colcount = torch.ones(N, dtype=torch.long, device=row.device)
if N > M:
colptr[M + 1:] = M
colcount[M:] = 0
csr2csc = csc2csr = row
out = SparseTensor(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=(M, N),
is_sorted=True,
trust_data=True,
)
out.storage._rowcount = rowcount
out.storage._colptr = colptr
out.storage._colcount = colcount
out.storage._csr2csc = csr2csc
out.storage._csc2csr = csc2csr
return out
def copy(self):
return self.from_storage(self.storage)
def clone(self):
return self.from_storage(self.storage.clone())
def type(self, dtype: torch.dtype, non_blocking: bool = False):
value = self.storage.value()
if value is None or dtype == value.dtype:
return self
return self.from_storage(
self.storage.type(dtype=dtype, non_blocking=non_blocking))
def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.type(dtype=tensor.dtype, non_blocking=non_blocking)
def to_device(self, device: torch.device, non_blocking: bool = False):
if device == self.device():
return self
return self.from_storage(
self.storage.to_device(device=device, non_blocking=non_blocking))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
return self.to_device(device=tensor.device, non_blocking=non_blocking)
# Formats #################################################################
def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.row(), self.storage.col(), self.storage.value()
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.rowptr(), self.storage.col(), self.storage.value()
def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
perm = self.storage.csr2csc()
value = self.storage.value()
if value is not None:
value = value[perm]
return self.storage.colptr(), self.storage.row()[perm], value
# Storage inheritance #####################################################
def has_value(self) -> bool:
return self.storage.has_value()
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
self.storage.set_value_(value, layout)
return self
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout))
def sparse_sizes(self) -> Tuple[int, int]:
return self.storage.sparse_sizes()
def sparse_size(self, dim: int) -> int:
return self.storage.sparse_sizes()[dim]
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def sparse_reshape(self, num_rows: int, num_cols: int):
return self.from_storage(
self.storage.sparse_reshape(num_rows, num_cols))
def is_coalesced(self) -> bool:
return self.storage.is_coalesced()
def coalesce(self, reduce: str = "sum"):
return self.from_storage(self.storage.coalesce(reduce))
def fill_cache_(self):
self.storage.fill_cache_()
return self
def clear_cache_(self):
self.storage.clear_cache_()
return self
def __eq__(self, other) -> bool:
if not isinstance(other, self.__class__):
return False
if self.sizes() != other.sizes():
return False
rowptrA, colA, valueA = self.csr()
rowptrB, colB, valueB = other.csr()
if valueA is None and valueB is not None:
return False
if valueA is not None and valueB is None:
return False
if not torch.equal(rowptrA, rowptrB):
return False
if not torch.equal(colA, colB):
return False
if valueA is None and valueB is None:
return True
return torch.equal(valueA, valueB)
# Utility functions #######################################################
def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
device=self.device())
return self.set_value_(value, layout='coo')
def fill_value(self, fill_value: float, dtype: Optional[int] = None):
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
device=self.device())
return self.set_value(value, layout='coo')
def sizes(self) -> List[int]:
sparse_sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
return list(sparse_sizes) + list(value.size())[1:]
else:
return list(sparse_sizes)
def size(self, dim: int) -> int:
return self.sizes()[dim]
def dim(self) -> int:
return len(self.sizes())
def nnz(self) -> int:
return self.storage.col().numel()
def numel(self) -> int:
value = self.storage.value()
if value is not None:
return value.numel()
else:
return self.nnz()
def density(self) -> float:
return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
def sparsity(self) -> float:
return 1 - self.density()
def avg_row_length(self) -> float:
return self.nnz() / self.sparse_size(0)
def avg_col_length(self) -> float:
return self.nnz() / self.sparse_size(1)
def bandwidth(self) -> int:
row, col, _ = self.coo()
return int((row - col).abs_().max())
def avg_bandwidth(self) -> float:
row, col, _ = self.coo()
return float((row - col).abs_().to(torch.float).mean())
def bandwidth_proportion(self, bandwidth: int) -> float:
row, col, _ = self.coo()
tmp = (row - col).abs_()
return int((tmp <= bandwidth).sum()) / self.nnz()
def is_quadratic(self) -> bool:
return self.sparse_size(0) == self.sparse_size(1)
def is_symmetric(self) -> bool:
if not self.is_quadratic():
return False
rowptr, col, value1 = self.csr()
colptr, row, value2 = self.csc()
if (rowptr != colptr).any() or (col != row).any():
return False
if value1 is None or value2 is None:
return True
else:
return bool((value1 == value2).all())
def to_symmetric(self, reduce: str = "sum"):
N = max(self.size(0), self.size(1))
row, col, value = self.coo()
idx = col.new_full((2 * col.numel() + 1, ), -1)
idx[1:row.numel() + 1] = row
idx[row.numel() + 1:] = col
idx[1:] *= N
idx[1:row.numel() + 1] += col
idx[row.numel() + 1:] += row
idx, perm = idx.sort()
mask = idx[1:] > idx[:-1]
perm = perm[1:].sub_(1)
idx = perm[mask]
if value is not None:
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))])
value = torch.cat([value, value])[perm]
value = segment_csr(value, ptr, reduce=reduce)
new_row = torch.cat([row, col], dim=0, out=perm)[idx]
new_col = torch.cat([col, row], dim=0, out=perm)[idx]
out = SparseTensor(
row=new_row,
rowptr=None,
col=new_col,
value=value,
sparse_sizes=(N, N),
is_sorted=True,
trust_data=True,
)
return out
def detach_(self):
value = self.storage.value()
if value is not None:
value.detach_()
return self
def detach(self):
value = self.storage.value()
if value is not None:
value = value.detach()
return self.set_value(value, layout='coo')
def requires_grad(self) -> bool:
value = self.storage.value()
if value is not None:
return value.requires_grad
else:
return False
def requires_grad_(self, requires_grad: bool = True,
dtype: Optional[int] = None):
if requires_grad and not self.has_value():
self.fill_value_(1., dtype)
value = self.storage.value()
if value is not None:
value.requires_grad_(requires_grad)
return self
def pin_memory(self):
return self.from_storage(self.storage.pin_memory())
def is_pinned(self) -> bool:
return self.storage.is_pinned()
def device(self):
return self.storage.col().device
def cpu(self):
return self.to_device(device=torch.device('cpu'), non_blocking=False)
def cuda(self):
return self.from_storage(self.storage.cuda())
def is_cuda(self) -> bool:
return self.storage.col().is_cuda
def dtype(self):
value = self.storage.value()
return value.dtype if value is not None else torch.float
def is_floating_point(self) -> bool:
value = self.storage.value()
return torch.is_floating_point(value) if value is not None else True
def bfloat16(self):
return self.type(dtype=torch.bfloat16, non_blocking=False)
def bool(self):
return self.type(dtype=torch.bool, non_blocking=False)
def byte(self):
return self.type(dtype=torch.uint8, non_blocking=False)
def char(self):
return self.type(dtype=torch.int8, non_blocking=False)
def half(self):
return self.type(dtype=torch.half, non_blocking=False)
def float(self):
return self.type(dtype=torch.float, non_blocking=False)
def double(self):
return self.type(dtype=torch.double, non_blocking=False)
def short(self):
return self.type(dtype=torch.short, non_blocking=False)
def int(self):
return self.type(dtype=torch.int, non_blocking=False)
def long(self):
return self.type(dtype=torch.long, non_blocking=False)
# Conversions #############################################################
def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
row, col, value = self.coo()
if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype,
device=self.device())
else:
mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())
if value is not None:
mat[row, col] = value
else:
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
device=mat.device)
return mat
def to_torch_sparse_coo_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
row, col, value = self.coo()
index = torch.stack([row, col], dim=0)
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
return torch.sparse_coo_tensor(index, value, self.sizes())
def to_torch_sparse_csr_tensor(
self, dtype: Optional[int] = None) -> torch.Tensor:
rowptr, col, value = self.csr()
if value is None:
value = torch.ones(self.nnz(), dtype=dtype, device=self.device())
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
# Python Bindings #############################################################
def share_memory_(self: SparseTensor) -> SparseTensor:
self.storage.share_memory_()
return self
def is_shared(self: SparseTensor) -> bool:
return self.storage.is_shared()
def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
if dtype is not None:
self = self.type(dtype=dtype, non_blocking=non_blocking)
if device is not None:
self = self.to_device(device=device, non_blocking=non_blocking)
return self
def cpu(self) -> SparseTensor:
return self.device_as(torch.tensor(0., device='cpu'))
def cuda(self, device: Optional[Union[int, str]] = None,
non_blocking: bool = False):
return self.device_as(torch.tensor(0., device=device or 'cuda'))
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed...
if len([
i for i in index
if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ...
]) > 1:
raise SyntaxError
dim = 0
out = self
while len(index) > 0:
item = index.pop(0)
if isinstance(item, (list, tuple)):
item = torch.tensor(item, device=self.device())
if isinstance(item, np.ndarray):
item = torch.from_numpy(item).to(self.device())
if isinstance(item, int):
out = out.select(dim, item)
dim += 1
elif isinstance(item, slice):
if item.step is not None:
raise ValueError('Step parameter not yet supported.')
start = 0 if item.start is None else item.start
start = self.size(dim) + start if start < 0 else start
stop = self.size(dim) if item.stop is None else item.stop
stop = self.size(dim) + stop if stop < 0 else stop
out = out.narrow(dim, start, max(stop - start, 0))
dim += 1
elif torch.is_tensor(item):
if item.dtype == torch.bool:
out = out.masked_select(dim, item)
dim += 1
elif item.dtype == torch.long:
out = out.index_select(dim, item)
dim += 1
elif item == Ellipsis:
if self.dim() - len(index) < dim:
raise SyntaxError
dim = self.dim() - len(index)
else:
raise SyntaxError
return out
def __repr__(self: SparseTensor) -> str:
i = ' ' * 6
row, col, value = self.coo()
infos = []
infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
if value is not None:
infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
infos += [
f'size={tuple(self.sizes())}, nnz={self.nnz()}, '
f'density={100 * self.density():.02f}%'
]
infos = ',\n'.join(infos)
i = ' ' * (len(self.__class__.__name__) + 1)
return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
SparseTensor.share_memory_ = share_memory_
SparseTensor.is_shared = is_shared
SparseTensor.to = to
SparseTensor.cpu = cpu
SparseTensor.cuda = cuda
SparseTensor.__getitem__ = __getitem__
SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix]
@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
colptr = None
if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr()
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long)
value = None
if has_value:
value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return SparseTensor.from_storage(storage)
@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
assert self.dim() == 2
layout = get_layout(layout)
if not self.has_value():
ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo':
row, col, value = self.coo()
row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.coo_matrix((value, (row, col)), self.sizes())
elif layout == 'csr':
rowptr, col, value = self.csr()
rowptr = rowptr.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csr_matrix((value, col, rowptr), self.sizes())
elif layout == 'csc':
colptr, row, value = self.csc()
colptr = colptr.detach().cpu().numpy()
row = row.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csc_matrix((value, row, colptr), self.sizes())
SparseTensor.from_scipy = from_scipy
SparseTensor.to_scipy = to_scipy
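# --- Example usage (illustrative sketch, not part of the original file) ---
# Round-trip through SciPy: `to_scipy` exports to the requested layout and
# `from_scipy` rebuilds a `SparseTensor`, caching the CSR pointers it gets
# for free. Assumes scipy and the compiled torch_sparse ops are installed.
def _scipy_roundtrip_example():
    dense = torch.tensor([[0., 1.], [2., 0.]])
    adj = SparseTensor.from_dense(dense)
    mat = adj.to_scipy(layout='csr')   # scipy.sparse.csr_matrix
    back = SparseTensor.from_scipy(mat)
    assert torch.equal(back.to_dense(), dense)
    return back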
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
if value is not None:
value = value[csr2csc]
sparse_sizes = src.storage.sparse_sizes()
storage = SparseStorage(
row=col[csr2csc],
rowptr=src.storage._colptr,
col=row[csr2csc],
value=value,
sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
csr2csc=src.storage._csc2csr,
csc2csr=csr2csc,
is_sorted=True,
)
return src.from_storage(storage)
SparseTensor.t = lambda self: t(self)
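# --- Example usage (illustrative sketch, not part of the original file) ---
# `t` transposes by reusing the CSR->CSC permutation, swapping the roles of
# all row/column metadata instead of re-sorting from scratch.
def _t_example():
    adj = SparseTensor.from_dense(torch.tensor([[0., 1.], [0., 2.]]))
    adj_t = adj.t()
    assert torch.equal(adj_t.to_dense(), torch.tensor([[0., 0.], [1., 2.]]))
    return adj_t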
###############################################################################
def transpose(index, value, m, n, coalesced=True):
"""Transposes dimensions 0 and 1 of a sparse tensor.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
coalesced (bool, optional): If set to :obj:`False`, will not coalesce
the output. (default: :obj:`True`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
row, col = index
row, col = col, row
if coalesced:
sparse_sizes = (n, m)
storage = SparseStorage(row=row, col=col, value=value,
sparse_sizes=sparse_sizes, is_sorted=False)
storage = storage.coalesce()
row, col, value = storage.row(), storage.col(), storage.value()
return torch.stack([row, col], dim=0), value
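# --- Example usage (illustrative sketch, not part of the original file) ---
# The index/value-based `transpose` swaps the coordinate rows and, when
# `coalesced=True`, re-sorts (summing any duplicates) via `SparseStorage`.
# For the input below the nonzeros (0, 1), (0, 2), (1, 0) map to
# (1, 0), (2, 0), (0, 1), returned in sorted row-major order.
def _transpose_example():
    index = torch.tensor([[0, 0, 1], [1, 2, 0]])
    value = torch.tensor([1., 2., 3.])
    return transpose(index, value, m=2, n=3)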
from typing import Any
try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa
def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)