Commit 7671fcb0, authored by Matthias Fey, committed by GitHub

Merge pull request #33 from rusty1s/adj

[WIP] SparseTensor Format

Parents: 1fb5fa4f 704ad420
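# torch_sparse/storage.py ####################################################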
import warnings
import os.path as osp
from typing import Optional, List
import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final
try:
    torch.ops.load_library(
        osp.join(osp.dirname(osp.abspath(__file__)), '_convert.so'))
except OSError:
    warnings.warn('Failed to load `convert` binaries.')

    def ind2ptr_placeholder(ind: torch.Tensor, M: int) -> torch.Tensor:
        raise ImportError
        return ind  # Unreachable, but satisfies TorchScript's type checker.

    def ptr2ind_placeholder(ptr: torch.Tensor, E: int) -> torch.Tensor:
        raise ImportError
        return ptr  # Unreachable, but satisfies TorchScript's type checker.

    # Register the placeholders only when loading the extension failed:
    torch.ops.torch_sparse.ind2ptr = ind2ptr_placeholder
    torch.ops.torch_sparse.ptr2ind = ptr2ind_placeholder
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
def get_layout(layout: Optional[str] = None) -> str:
if layout is None:
layout = 'coo'
warnings.warn('`layout` argument unset, using default layout '
'"coo". This may lead to unexpected behaviour.')
assert layout == 'coo' or layout == 'csr' or layout == 'csc'
return layout
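# Example: `get_layout('csr')` returns 'csr' unchanged, while `get_layout()`
# emits the warning above and falls back to 'coo'.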
@torch.jit.script
class SparseStorage(object):
_row: Optional[torch.Tensor]
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_sizes: List[int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[List[int]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
csr2csc: Optional[torch.Tensor] = None,
csc2csr: Optional[torch.Tensor] = None,
is_sorted: bool = False):
assert row is not None or rowptr is not None
assert col is not None
assert col.dtype == torch.long
assert col.dim() == 1
col = col.contiguous()
if sparse_sizes is None:
if rowptr is not None:
M = rowptr.numel() - 1
elif row is not None:
M = row.max().item() + 1
else:
raise ValueError
N = col.max().item() + 1
sparse_sizes = torch.Size([int(M), int(N)])
else:
assert len(sparse_sizes) == 2
if row is not None:
assert row.dtype == torch.long
assert row.device == col.device
assert row.dim() == 1
assert row.numel() == col.numel()
row = row.contiguous()
if rowptr is not None:
assert rowptr.dtype == torch.long
assert rowptr.device == col.device
assert rowptr.dim() == 1
assert rowptr.numel() - 1 == sparse_sizes[0]
rowptr = rowptr.contiguous()
if value is not None:
assert value.device == col.device
assert value.size(0) == col.size(0)
value = value.contiguous()
if rowcount is not None:
assert rowcount.dtype == torch.long
assert rowcount.device == col.device
assert rowcount.dim() == 1
assert rowcount.numel() == sparse_sizes[0]
rowcount = rowcount.contiguous()
if colptr is not None:
assert colptr.dtype == torch.long
assert colptr.device == col.device
assert colptr.dim() == 1
assert colptr.numel() - 1 == sparse_sizes[1]
colptr = colptr.contiguous()
if colcount is not None:
assert colcount.dtype == torch.long
assert colcount.device == col.device
assert colcount.dim() == 1
assert colcount.numel() == sparse_sizes[1]
colcount = colcount.contiguous()
if csr2csc is not None:
assert csr2csc.dtype == torch.long
assert csr2csc.device == col.device
assert csr2csc.dim() == 1
assert csr2csc.numel() == col.size(0)
csr2csc = csr2csc.contiguous()
if csc2csr is not None:
assert csc2csr.dtype == torch.long
assert csc2csr.device == col.device
assert csc2csr.dim() == 1
assert csc2csr.numel() == col.size(0)
csc2csr = csc2csr.contiguous()
self._row = row
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_sizes = sparse_sizes
self._rowcount = rowcount
self._colptr = colptr
self._colcount = colcount
self._csr2csc = csr2csc
self._csc2csr = csc2csr
if not is_sorted:
idx = self._col.new_zeros(self._col.numel() + 1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
self._row = self.row()[perm]
self._col = self._col[perm]
if value is not None:
self._value = value[perm]
self._csr2csc = None
self._csc2csr = None
    @classmethod
    def empty(self):
        self = SparseStorage.__new__(SparseStorage)
        self._row = None
        self._rowptr = None
        # `_col` and `_sparse_sizes` are non-optional, so give them empty
        # defaults as well:
        self._col = torch.tensor([], dtype=torch.long)
        self._value = None
        self._sparse_sizes = torch.Size([0, 0])
        self._rowcount = None
        self._colptr = None
        self._colcount = None
        self._csr2csc = None
        self._csc2csr = None
        return self
def has_row(self) -> bool:
return self._row is not None
    def row(self) -> torch.Tensor:
row = self._row
if row is not None:
return row
rowptr = self._rowptr
if rowptr is not None:
row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
self._row = row
return row
raise ValueError
def has_rowptr(self) -> bool:
return self._rowptr is not None
def rowptr(self) -> torch.Tensor:
rowptr = self._rowptr
if rowptr is not None:
return rowptr
row = self._row
if row is not None:
rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
self._rowptr = rowptr
return rowptr
raise ValueError
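    # Worked example of the two conversions (hypothetical values): for
    # row = [0, 0, 1, 3] and M = 4, ind2ptr yields rowptr = [0, 2, 3, 3, 4]
    # (row i owns the slice rowptr[i]:rowptr[i + 1]); ptr2ind with E = 4
    # recovers row = [0, 0, 1, 3].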
def col(self) -> torch.Tensor:
return self._col
def has_value(self) -> bool:
return self._value is not None
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
self._value = value
return self
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def sparse_sizes(self) -> List[int]:
return self._sparse_sizes
def sparse_size(self, dim: int) -> int:
return self._sparse_sizes[dim]
def sparse_resize(self, sparse_sizes: List[int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
rowcount, rowptr = self._rowcount, self._rowptr
if diff_0 > 0:
if rowptr is not None:
rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
if rowcount is not None:
rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
        elif diff_0 < 0:
            # Shrinking: drop the trailing entries. Note that `[:diff_0]`
            # with negative `diff_0` keeps all but the last `-diff_0` items.
            if rowptr is not None:
                rowptr = rowptr[:diff_0]
            if rowcount is not None:
                rowcount = rowcount[:diff_0]
diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
colcount, colptr = self._colcount, self._colptr
if diff_1 > 0:
if colptr is not None:
colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
if colcount is not None:
colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
        elif diff_1 < 0:
            if colptr is not None:
                colptr = colptr[:diff_1]
            if colcount is not None:
                colcount = colcount[:diff_1]
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool:
return self._rowcount is not None
def rowcount(self) -> torch.Tensor:
rowcount = self._rowcount
if rowcount is not None:
return rowcount
rowptr = self.rowptr()
rowcount = rowptr[1:] - rowptr[:-1]
self._rowcount = rowcount
return rowcount
def has_colptr(self) -> bool:
return self._colptr is not None
def colptr(self) -> torch.Tensor:
colptr = self._colptr
if colptr is not None:
return colptr
csr2csc = self._csr2csc
if csr2csc is not None:
colptr = torch.ops.torch_sparse.ind2ptr(self._col[csr2csc],
self._sparse_sizes[1])
else:
colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
self._colptr = colptr
return colptr
def has_colcount(self) -> bool:
return self._colcount is not None
def colcount(self) -> torch.Tensor:
colcount = self._colcount
if colcount is not None:
return colcount
colptr = self._colptr
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
def has_csr2csc(self) -> bool:
return self._csr2csc is not None
def csr2csc(self) -> torch.Tensor:
csr2csc = self._csr2csc
if csr2csc is not None:
return csr2csc
idx = self._sparse_sizes[0] * self._col + self.row()
csr2csc = idx.argsort()
self._csr2csc = csr2csc
return csr2csc
def has_csc2csr(self) -> bool:
return self._csc2csr is not None
def csc2csr(self) -> torch.Tensor:
csc2csr = self._csc2csr
if csc2csr is not None:
return csc2csr
csc2csr = self.csr2csc().argsort()
self._csc2csr = csc2csr
return csc2csr
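    # Worked example (hypothetical values): for a 2 x 2 matrix with entries
    # (0, 1), (1, 0), (1, 1) stored in CSR order, the sort keys
    # M * col + row are [2, 1, 3], so csr2csc = [1, 0, 2] reorders the
    # entries column-major; csc2csr is its inverse permutation.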
def is_coalesced(self) -> bool:
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
return bool((idx[1:] > idx[:-1]).all())
def coalesce(self, reduce: str = "add"):
idx = self._col.new_full((self._col.numel() + 1, ), -1)
idx[1:] = self._sparse_sizes[1] * self.row() + self._col
mask = idx[1:] > idx[:-1]
if mask.all(): # Skip if indices are already coalesced.
return self
row = self.row()[mask]
col = self._col[mask]
value = self._value
if value is not None:
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
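    # Sketch (hypothetical values): for row = [0, 0, 1], col = [1, 1, 0] and
    # value = [1., 2., 3.], the duplicated (0, 1) entry collapses and
    # coalesce("add") returns row = [0, 1], col = [1, 0], value = [3., 3.].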
def fill_cache_(self):
self.row()
self.rowptr()
self.rowcount()
self.colptr()
self.colcount()
self.csr2csc()
self.csc2csr()
return self
def clear_cache_(self):
self._rowcount = None
self._colptr = None
self._colcount = None
self._csr2csc = None
self._csc2csr = None
return self
def num_cached_keys(self) -> int:
count = 0
if self.has_rowcount():
count += 1
if self.has_colptr():
count += 1
if self.has_colcount():
count += 1
if self.has_csr2csc():
count += 1
if self.has_csc2csr():
count += 1
return count
def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def clone(self):
row = self._row
if row is not None:
row = row.clone()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.clone()
col = self._col.clone()
value = self._value
if value is not None:
value = value.clone()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.clone()
colptr = self._colptr
if colptr is not None:
colptr = colptr.clone()
colcount = self._colcount
if colcount is not None:
colcount = colcount.clone()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.clone()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
    def type_as(self, tensor: torch.Tensor):
value = self._value
if value is not None:
if tensor.dtype == value.dtype:
return self
else:
return self.set_value(value.type_as(tensor), layout='coo')
else:
return self
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self._col.device:
return self
row = self._row
if row is not None:
row = row.to(tensor.device, non_blocking=non_blocking)
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.to(tensor.device, non_blocking=non_blocking)
col = self._col.to(tensor.device, non_blocking=non_blocking)
value = self._value
if value is not None:
value = value.to(tensor.device, non_blocking=non_blocking)
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.to(tensor.device, non_blocking=non_blocking)
colptr = self._colptr
if colptr is not None:
colptr = colptr.to(tensor.device, non_blocking=non_blocking)
colcount = self._colcount
if colcount is not None:
colcount = colcount.to(tensor.device, non_blocking=non_blocking)
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.to(tensor.device, non_blocking=non_blocking)
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def pin_memory(self):
row = self._row
if row is not None:
row = row.pin_memory()
rowptr = self._rowptr
if rowptr is not None:
rowptr = rowptr.pin_memory()
col = self._col.pin_memory()
value = self._value
if value is not None:
value = value.pin_memory()
rowcount = self._rowcount
if rowcount is not None:
rowcount = rowcount.pin_memory()
colptr = self._colptr
if colptr is not None:
colptr = colptr.pin_memory()
colcount = self._colcount
if colcount is not None:
colcount = colcount.pin_memory()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc = csr2csc.pin_memory()
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def is_pinned(self) -> bool:
is_pinned = True
row = self._row
if row is not None:
is_pinned = is_pinned and row.is_pinned()
rowptr = self._rowptr
if rowptr is not None:
is_pinned = is_pinned and rowptr.is_pinned()
        is_pinned = is_pinned and self._col.is_pinned()
value = self._value
if value is not None:
is_pinned = is_pinned and value.is_pinned()
rowcount = self._rowcount
if rowcount is not None:
is_pinned = is_pinned and rowcount.is_pinned()
colptr = self._colptr
if colptr is not None:
is_pinned = is_pinned and colptr.is_pinned()
colcount = self._colcount
if colcount is not None:
is_pinned = is_pinned and colcount.is_pinned()
csr2csc = self._csr2csc
if csr2csc is not None:
is_pinned = is_pinned and csr2csc.is_pinned()
csc2csr = self._csc2csr
if csc2csr is not None:
is_pinned = is_pinned and csc2csr.is_pinned()
return is_pinned
def share_memory_(self: SparseStorage) -> SparseStorage:
row = self._row
if row is not None:
row.share_memory_()
rowptr = self._rowptr
if rowptr is not None:
rowptr.share_memory_()
self._col.share_memory_()
value = self._value
if value is not None:
value.share_memory_()
rowcount = self._rowcount
if rowcount is not None:
rowcount.share_memory_()
colptr = self._colptr
if colptr is not None:
colptr.share_memory_()
colcount = self._colcount
if colcount is not None:
colcount.share_memory_()
csr2csc = self._csr2csc
if csr2csc is not None:
csr2csc.share_memory_()
    csc2csr = self._csc2csr
    if csc2csr is not None:
        csc2csr.share_memory_()
    return self
def is_shared(self) -> bool:
is_shared = True
row = self._row
if row is not None:
is_shared = is_shared and row.is_shared()
rowptr = self._rowptr
if rowptr is not None:
is_shared = is_shared and rowptr.is_shared()
is_shared = is_shared and self._col.is_shared()
value = self._value
if value is not None:
is_shared = is_shared and value.is_shared()
rowcount = self._rowcount
if rowcount is not None:
is_shared = is_shared and rowcount.is_shared()
colptr = self._colptr
if colptr is not None:
is_shared = is_shared and colptr.is_shared()
colcount = self._colcount
if colcount is not None:
is_shared = is_shared and colcount.is_shared()
csr2csc = self._csr2csc
if csr2csc is not None:
is_shared = is_shared and csr2csc.is_shared()
csc2csr = self._csc2csr
if csc2csr is not None:
is_shared = is_shared and csc2csr.is_shared()
return is_shared
SparseStorage.share_memory_ = share_memory_
SparseStorage.is_shared = is_shared
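# A minimal usage sketch (hypothetical values, assuming the `_convert`
# binaries loaded so that `ind2ptr` is available):
#
#   >>> row = torch.tensor([0, 0, 1, 2])
#   >>> col = torch.tensor([1, 2, 0, 1])
#   >>> storage = SparseStorage(row=row, col=col)
#   >>> storage.sparse_sizes()      # inferred from the maximum indices
#   torch.Size([3, 3])
#   >>> storage.rowptr()            # converted lazily, then cached
#   tensor([0, 2, 3, 4])

# torch_sparse/tensor.py #####################################################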
from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any
import torch
import scipy.sparse
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.utils import is_scalar
@torch.jit.script
class SparseTensor(object):
storage: SparseStorage
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
                 sparse_sizes: Optional[List[int]] = None,
                 is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_sizes=sparse_sizes,
rowcount=None, colptr=None, colcount=None,
csr2csc=None, csc2csr=None,
is_sorted=is_sorted)
@classmethod
def from_storage(self, storage: SparseStorage):
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
return self
@classmethod
def from_dense(self, mat: torch.Tensor, has_value: bool = True):
if mat.dim() > 2:
index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero()
else:
index = mat.nonzero()
index = index.t()
row = index[0]
col = index[1]
value: Optional[torch.Tensor] = None
if has_value:
value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
has_value: bool = True):
mat = mat.coalesce()
index = mat._indices()
row, col = index[0], index[1]
value: Optional[torch.Tensor] = None
if has_value:
value = mat._values()
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
@classmethod
def eye(self, M: int, N: Optional[int] = None,
options: Optional[torch.Tensor] = None, has_value: bool = True,
fill_cache: bool = False):
N = M if N is None else N
if options is not None:
row = torch.arange(min(M, N), device=options.device)
else:
row = torch.arange(min(M, N))
col = row
rowptr = torch.arange(M + 1, dtype=torch.long, device=row.device)
if M > N:
rowptr[N + 1:] = N
value: Optional[torch.Tensor] = None
if has_value:
if options is not None:
value = torch.ones(row.numel(), dtype=options.dtype,
device=row.device)
else:
value = torch.ones(row.numel(), device=row.device)
rowcount: Optional[torch.Tensor] = None
colptr: Optional[torch.Tensor] = None
colcount: Optional[torch.Tensor] = None
csr2csc: Optional[torch.Tensor] = None
csc2csr: Optional[torch.Tensor] = None
if fill_cache:
rowcount = torch.ones(M, dtype=torch.long, device=row.device)
if M > N:
rowcount[N:] = 0
colptr = torch.arange(N + 1, dtype=torch.long, device=row.device)
colcount = torch.ones(N, dtype=torch.long, device=row.device)
if N > M:
colptr[M + 1:] = M
colcount[M:] = 0
csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage(
row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
is_sorted=True)
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
return self
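    # Sketch: `SparseTensor.eye(3)` (hypothetical call) yields the 3 x 3
    # identity with row = col = [0, 1, 2], rowptr = [0, 1, 2, 3] and unit
    # values; with `fill_cache=True` the CSC cache comes for free since the
    # matrix is symmetric.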
def copy(self):
return self.from_storage(self.storage)
def clone(self):
return self.from_storage(self.storage.clone())
    def type_as(self, tensor: torch.Tensor):
value = self.storage._value
if value is None or tensor.dtype == value.dtype:
return self
return self.from_storage(self.storage.type_as(tensor))
def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
if tensor.device == self.device():
return self
return self.from_storage(self.storage.device_as(tensor, non_blocking))
# Formats #################################################################
def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.row(), self.storage.col(), self.storage.value()
def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.storage.rowptr(), self.storage.col(), self.storage.value()
def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
perm = self.storage.csr2csc()
value = self.storage.value()
if value is not None:
value = value[perm]
return self.storage.colptr(), self.storage.row()[perm], value
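    # For a 2 x 2 matrix with entries (0, 1), (1, 0) and value = [10., 20.]
    # (hypothetical), the three views are:
    #   coo(): row = [0, 1],       col = [1, 0], value = [10., 20.]
    #   csr(): rowptr = [0, 1, 2], col = [1, 0], value = [10., 20.]
    #   csc(): colptr = [0, 1, 2], row = [1, 0], value = [20., 10.]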
# Storage inheritance #####################################################
def has_value(self) -> bool:
return self.storage.has_value()
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
self.storage.set_value_(value, layout)
return self
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout))
def sparse_sizes(self) -> List[int]:
return self.storage.sparse_sizes()
def sparse_size(self, dim: int) -> int:
return self.storage.sparse_sizes()[dim]
def sparse_resize(self, sparse_sizes: List[int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def is_coalesced(self) -> bool:
return self.storage.is_coalesced()
def coalesce(self, reduce: str = "sum"):
return self.from_storage(self.storage.coalesce(reduce))
def fill_cache_(self):
self.storage.fill_cache_()
return self
def clear_cache_(self):
self.storage.clear_cache_()
return self
# Utility functions #######################################################
def fill_value_(self, fill_value: float,
options: Optional[torch.Tensor] = None):
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device())
return self.set_value_(value, layout='coo')
def fill_value(self, fill_value: float,
options: Optional[torch.Tensor] = None):
if options is not None:
value = torch.full((self.nnz(), ), fill_value, dtype=options.dtype,
device=self.device())
else:
value = torch.full((self.nnz(), ), fill_value,
device=self.device())
return self.set_value(value, layout='coo')
def sizes(self) -> List[int]:
sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes = list(sizes) + list(value.size())[1:]
return sizes
def size(self, dim: int) -> int:
return self.sizes()[dim]
def dim(self) -> int:
return len(self.sizes())
def nnz(self) -> int:
return self.storage.col().numel()
def numel(self) -> int:
value = self.storage.value()
if value is not None:
return value.numel()
else:
return self.nnz()
def density(self) -> float:
return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))
def sparsity(self) -> float:
return 1 - self.density()
def avg_row_length(self) -> float:
return self.nnz() / self.sparse_size(0)
def avg_col_length(self) -> float:
return self.nnz() / self.sparse_size(1)
def is_quadratic(self) -> bool:
return self.sparse_size(0) == self.sparse_size(1)
def is_symmetric(self) -> bool:
if not self.is_quadratic():
return False
rowptr, col, value1 = self.csr()
colptr, row, value2 = self.csc()
if (rowptr != colptr).any() or (col != row).any():
return False
if value1 is None or value2 is None:
return True
else:
return bool((value1 == value2).all())
def to_symmetric(self, reduce: str = "sum"):
row, col, value = self.coo()
row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
if value is not None:
value = torch.cat([value, value], dim=0)
N = max(self.size(0), self.size(1))
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=torch.Size([N, N]), is_sorted=False)
out = out.coalesce(reduce)
return out
def detach_(self):
value = self.storage.value()
if value is not None:
value.detach_()
return self
def detach(self):
value = self.storage.value()
if value is not None:
value = value.detach()
return self.set_value(value, layout='coo')
def requires_grad(self) -> bool:
value = self.storage.value()
if value is not None:
return value.requires_grad
else:
return False
def requires_grad_(self, requires_grad: bool = True,
options: Optional[torch.Tensor] = None):
if requires_grad and not self.has_value():
self.fill_value_(1., options=options)
value = self.storage.value()
if value is not None:
value.requires_grad_(requires_grad)
return self
def pin_memory(self):
return self.from_storage(self.storage.pin_memory())
def is_pinned(self) -> bool:
return self.storage.is_pinned()
    def options(self) -> torch.Tensor:
        # A tensor whose dtype and device stand in for `torch.TensorOptions`
        # when creating new values.
        value = self.storage.value()
        if value is not None:
            return value
        else:
            return torch.tensor(0., dtype=torch.float,
                                device=self.storage.col().device)
def device(self):
return self.storage.col().device
def cpu(self):
return self.device_as(torch.tensor(0.), non_blocking=False)
def cuda(self, options: Optional[torch.Tensor] = None,
non_blocking: bool = False):
if options is not None:
return self.device_as(options, non_blocking)
else:
options = torch.tensor(0.).cuda()
return self.device_as(options, non_blocking)
def is_cuda(self) -> bool:
return self.storage.col().is_cuda
def dtype(self):
return self.options().dtype
def is_floating_point(self) -> bool:
return torch.is_floating_point(self.options())
def bfloat16(self):
return self.type_as(
torch.tensor(0, dtype=torch.bfloat16, device=self.device()))
def bool(self):
return self.type_as(
torch.tensor(0, dtype=torch.bool, device=self.device()))
def byte(self):
return self.type_as(
torch.tensor(0, dtype=torch.uint8, device=self.device()))
def char(self):
return self.type_as(
torch.tensor(0, dtype=torch.int8, device=self.device()))
def half(self):
return self.type_as(
torch.tensor(0, dtype=torch.half, device=self.device()))
def float(self):
return self.type_as(
torch.tensor(0, dtype=torch.float, device=self.device()))
def double(self):
return self.type_as(
torch.tensor(0, dtype=torch.double, device=self.device()))
def short(self):
return self.type_as(
torch.tensor(0, dtype=torch.short, device=self.device()))
def int(self):
return self.type_as(
torch.tensor(0, dtype=torch.int, device=self.device()))
def long(self):
return self.type_as(
torch.tensor(0, dtype=torch.long, device=self.device()))
# Conversions #############################################################
def to_dense(self, options: Optional[torch.Tensor] = None):
row, col, value = self.coo()
if value is not None:
mat = torch.zeros(self.sizes(), dtype=value.dtype,
device=self.device())
elif options is not None:
mat = torch.zeros(self.sizes(), dtype=options.dtype,
device=self.device())
else:
mat = torch.zeros(self.sizes(), device=self.device())
if value is not None:
mat[row, col] = value
else:
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
device=mat.device)
return mat
def to_torch_sparse_coo_tensor(self,
options: Optional[torch.Tensor] = None):
row, col, value = self.coo()
index = torch.stack([row, col], dim=0)
if value is None:
if options is not None:
value = torch.ones(self.nnz(), dtype=options.dtype,
device=self.device())
else:
value = torch.ones(self.nnz(), device=self.device())
return torch.sparse_coo_tensor(index, value, self.sizes())
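    # Round-trip sketch (hypothetical input):
    #   >>> mat = torch.tensor([[0., 1.], [2., 0.]])
    #   >>> SparseTensor.from_dense(mat).to_dense().equal(mat)
    #   True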
# Python Bindings #############################################################
Dtype = Optional[torch.dtype]
Device = Optional[Union[torch.device, str]]
def share_memory_(self: SparseTensor) -> SparseTensor:
    self.storage.share_memory_()
    return self
def is_shared(self: SparseTensor) -> bool:
return self.storage.is_shared()
def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
if dtype is not None:
self = self.type_as(torch.tensor(0., dtype=dtype))
if device is not None:
self = self.device_as(torch.tensor(0., device=device), non_blocking)
return self
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
raise SyntaxError
dim = 0
out = self
while len(index) > 0:
item = index.pop(0)
if isinstance(item, int):
out = out.select(dim, item)
dim += 1
elif isinstance(item, slice):
if item.step is not None:
raise ValueError('Step parameter not yet supported.')
start = 0 if item.start is None else item.start
start = self.size(dim) + start if start < 0 else start
stop = self.size(dim) if item.stop is None else item.stop
stop = self.size(dim) + stop if stop < 0 else stop
out = out.narrow(dim, start, max(stop - start, 0))
dim += 1
elif torch.is_tensor(item):
if item.dtype == torch.bool:
out = out.masked_select(dim, item)
dim += 1
elif item.dtype == torch.long:
out = out.index_select(dim, item)
dim += 1
elif item == Ellipsis:
if self.dim() - len(index) < dim:
raise SyntaxError
dim = self.dim() - len(index)
else:
raise SyntaxError
return out
def __repr__(self: SparseTensor) -> str:
i = ' ' * 6
row, col, value = self.coo()
infos = []
infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
if value is not None:
infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
infos += [
f'size={tuple(self.sizes())}, '
f'nnz={self.nnz()}, '
f'density={100 * self.density():.02f}%'
]
infos = ',\n'.join(infos)
i = ' ' * (len(self.__class__.__name__) + 1)
return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
SparseTensor.share_memory_ = share_memory_
SparseTensor.is_shared = is_shared
SparseTensor.to = to
SparseTensor.__getitem__ = __getitem__
SparseTensor.__repr__ = __repr__
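# Indexing sketch: `select`, `narrow`, `masked_select` and `index_select` are
# defined elsewhere in this pull request; assuming a 100 x 100 `adj`:
#   adj[:50]                   # rows 0..49 via narrow(0, 0, 50)
#   adj[torch.tensor([0, 2])]  # rows 0 and 2 via index_select
#   adj[..., 7]                # column 7; the Ellipsis skips leading dims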
# Scipy Conversions ###########################################################
ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
                          scipy.sparse.csc_matrix]
@torch.jit.ignore
def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
colptr = None
if isinstance(mat, scipy.sparse.csc_matrix):
colptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocsr()
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
mat = mat.tocoo()
row = torch.from_numpy(mat.row).to(torch.long)
col = torch.from_numpy(mat.col).to(torch.long)
value = None
if has_value:
value = torch.from_numpy(mat.data)
sparse_sizes = mat.shape[:2]
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
colptr=colptr, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return SparseTensor.from_storage(storage)
@torch.jit.ignore
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
assert self.dim() == 2
layout = get_layout(layout)
if not self.has_value():
ones = torch.ones(self.nnz(), dtype=dtype).numpy()
if layout == 'coo':
row, col, value = self.coo()
row = row.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.coo_matrix((value, (row, col)), self.sizes())
elif layout == 'csr':
rowptr, col, value = self.csr()
rowptr = rowptr.detach().cpu().numpy()
col = col.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csr_matrix((value, col, rowptr), self.sizes())
elif layout == 'csc':
colptr, row, value = self.csc()
colptr = colptr.detach().cpu().numpy()
row = row.detach().cpu().numpy()
value = value.detach().cpu().numpy() if self.has_value() else ones
return scipy.sparse.csc_matrix((value, row, colptr), self.sizes())
SparseTensor.from_scipy = from_scipy
SparseTensor.to_scipy = to_scipy
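# Round-trip sketch (hypothetical matrix):
#   >>> mat = scipy.sparse.random(10, 20, density=0.1, format='csr')
#   >>> SparseTensor.from_scipy(mat).to_scipy(layout='csr')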
# Hacky fixes #################################################################
# Fix standard operators of `torch.Tensor` for PyTorch<=1.4.
# https://github.com/pytorch/pytorch/pull/31769
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR <= 4):
def add(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.add(other)
return NotImplemented
def mul(self, other):
if torch.is_tensor(other) or is_scalar(other):
return self.mul(other)
return NotImplemented
torch.Tensor.__add__ = add
torch.Tensor.__mul__ = mul
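# torch_sparse/transpose.py (file name inferred) #############################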
import torch

from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
row, col, value = src.coo()
if value is not None:
value = value[csr2csc]
sparse_sizes = src.storage.sparse_sizes()
storage = SparseStorage(
row=col[csr2csc],
rowptr=src.storage._colptr,
col=row[csr2csc],
value=value,
sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
csr2csc=src.storage._csc2csr,
csc2csr=csr2csc,
is_sorted=True,
)
return src.from_storage(storage)
SparseTensor.t = lambda self: t(self)
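# Sketch: transposition reuses whatever the source has cached, e.g. the CSC
# cache of `src` becomes the CSR data of the transpose, so a warm cache makes
# `t` essentially free.
#   >>> adj_t = t(adj)  # hypothetical `adj: SparseTensor`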
###############################################################################
def transpose(index, value, m, n, coalesced=True):
    """Transposes dimensions 0 and 1 of a sparse matrix.

    ...

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    row, col = index
    row, col = col, row

    if coalesced:
        sparse_sizes = torch.Size([n, m])
        storage = SparseStorage(row=row, col=col, value=value,
                                sparse_sizes=sparse_sizes, is_sorted=False)
        storage = storage.coalesce()
        row, col, value = storage.row(), storage.col(), storage.value()

    return torch.stack([row, col], dim=0), value
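# Worked example (hypothetical 2 x 3 input with entries (0, 1) and (1, 2)):
#   >>> index = torch.tensor([[0, 1], [1, 2]])
#   >>> value = torch.ones(2)
#   >>> transpose(index, value, 2, 3)
#   (tensor([[1, 2], [0, 1]]), tensor([1., 1.]))

# torch_sparse/utils.py ######################################################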
from typing import Any
try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa
def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)
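# torch_sparse/unique.py (file name inferred) ################################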
import torch
import numpy as np
if torch.cuda.is_available():
import torch_sparse.unique_cuda
def unique(src):
src = src.contiguous().view(-1)
if src.is_cuda:
out, perm = torch_sparse.unique_cuda.unique(src)
else:
out, perm = np.unique(src.numpy(), return_index=True)
out, perm = torch.from_numpy(out), torch.from_numpy(perm)
return out, perm
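# Sketch (CPU path, hypothetical input): np.unique returns sorted values and
# first-occurrence indices, so
#   >>> unique(torch.tensor([3, 1, 3, 2]))
#   (tensor([1, 2, 3]), tensor([1, 3, 0]))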