Unverified commit bfb571cb authored by Matthias Fey, committed by GitHub

Merge pull request #40 from rusty1s/metis

[WIP] Partition
parents e78637ea eee47eee
......@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
......@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value(value, layout='coo')
@torch.jit.script
def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
......@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return src.set_value_(value, layout='coo')
@torch.jit.script
def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return src.set_value(value, layout=layout)
@torch.jit.script
def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout: Optional[str] = None) -> SparseTensor:
value = src.storage.value()
......
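For context, the row-wise branch of mul() shown above can be exercised like this. A minimal sketch only: the torch_sparse.mul import path is an assumption based on the per-op file layout, and the exact scaling follows the `other.size(1) == 1` row-wise condition in the hunk.

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.mul import mul  # module path assumed from the file layout

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 2])
value = torch.ones(4)
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))

scale = torch.tensor([[1.0], [2.0], [3.0]])  # shape [3, 1] triggers the row-wise branch
out = mul(mat, scale)                        # each stored value is scaled by scale[row]
print(out.coo())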
from typing import Tuple
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
length: int) -> SparseTensor:
if dim < 0:
......@@ -31,7 +29,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if value is not None:
value = value.narrow(0, row_start, row_length)
sparse_sizes = torch.Size([length, src.sparse_size(1)])
sparse_sizes = (length, src.sparse_size(1))
rowcount = src.storage._rowcount
if rowcount is not None:
......@@ -54,7 +52,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if value is not None:
value = value[mask]
sparse_sizes = torch.Size([src.sparse_size(0), length])
sparse_sizes = (src.sparse_size(0), length)
colptr = src.storage._colptr
if colptr is not None:
......@@ -80,7 +78,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise ValueError
@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
length: Tuple[int, int]) -> SparseTensor:
# This function builds the inverse operation of `cat_diag` and should hence
......
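A hedged usage sketch for narrow() on the row dimension; the torch_sparse.narrow import path is the one used by select.py further down, and the printed shape reflects the change from torch.Size to a plain tuple in this hunk.

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow  # same import path as in select.py below

row = torch.tensor([0, 1, 1, 3])
col = torch.tensor([0, 0, 2, 1])
value = torch.arange(4, dtype=torch.float)
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 3))

sub = narrow(mat, 0, 1, 2)      # keep sparse rows 1 and 2 only
print(sub.sparse_sizes())       # (2, 3) -- a plain tuple instead of torch.Size
print(sub.coo())                # the two entries from original row 1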
import torch
from torch_sparse.tensor import SparseTensor
def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
assert src.is_quadratic()
return src.index_select(0, perm).index_select(1, perm)
SparseTensor.permute = lambda self, perm: permute(self, perm)
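permute() composes two index_select calls, so it relabels rows and columns of a quadratic matrix with the same permutation. A small sketch (the coo() print is only for inspection):

import torch
from torch_sparse.tensor import SparseTensor

row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])
value = torch.tensor([1., 2., 3.])
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))

perm = torch.tensor([2, 0, 1])   # row/col 0 of the result is row/col 2 of the input
out = mat.permute(perm)          # only defined for quadratic (square) matrices
print(out.coo())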
......@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def reduction(src: SparseTensor, dim: Optional[int] = None,
reduce: str = 'sum') -> torch.Tensor:
value = src.storage.value()
......@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise ValueError
@torch.jit.script
def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='sum')
@torch.jit.script
def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='mean')
@torch.jit.script
def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='min')
@torch.jit.script
def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
return reduction(src, dim, reduce='max')
......
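The four wrappers above all defer to reduction(). A hedged usage sketch follows; the torch_sparse.reduce module path is an assumption, and the functions are imported under aliases since they shadow Python's built-in sum/min/max.

import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.reduce import sum as sparse_sum, mean as sparse_mean  # path assumed

row = torch.tensor([0, 0, 1])
col = torch.tensor([0, 1, 1])
value = torch.tensor([1., 2., 3.])
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(2, 2))

print(sparse_sum(mat))        # dim=None reduces over all stored values
print(sparse_sum(mat, 1))     # reduce over columns -> one value per row
print(sparse_mean(mat, 1))    # per-row mean over the stored non-zeros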
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow
@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
return narrow(src, dim, start=idx, length=1)
......
......@@ -23,9 +23,9 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""
A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
sparse_sizes=torch.Size([m, k]), is_sorted=not coalesced)
sparse_sizes=(m, k), is_sorted=not coalesced)
B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
sparse_sizes=torch.Size([k, n]), is_sorted=not coalesced)
sparse_sizes=(k, n), is_sorted=not coalesced)
C = matmul(A, B)
row, col, value = C.coo()
......
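A usage sketch of spspmm() as its (unchanged) signature reads: two COO matrices given as index/value pairs are multiplied into C = A @ B of shape [m, n], returned again as an index/value pair.

import torch
from torch_sparse import spspmm

indexA = torch.tensor([[0, 0, 1, 2], [1, 2, 0, 2]])  # 3 x 3 matrix A
valueA = torch.ones(4)
indexB = torch.tensor([[0, 1, 2], [0, 1, 1]])        # 3 x 2 matrix B
valueB = torch.ones(3)

indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)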
import warnings
from typing import Optional, List
from typing import Optional, List, Tuple
import torch
from torch_scatter import segment_csr, scatter_add
......@@ -23,19 +23,18 @@ class SparseStorage(object):
_rowptr: Optional[torch.Tensor]
_col: torch.Tensor
_value: Optional[torch.Tensor]
_sparse_sizes: List[int]
_sparse_sizes: Tuple[int, int]
_rowcount: Optional[torch.Tensor]
_colptr: Optional[torch.Tensor]
_colcount: Optional[torch.Tensor]
_csr2csc: Optional[torch.Tensor]
_csc2csr: Optional[torch.Tensor]
def __init__(self,
row: Optional[torch.Tensor] = None,
def __init__(self, row: Optional[torch.Tensor] = None,
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[List[int]] = None,
sparse_sizes: Optional[Tuple[int, int]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
......@@ -57,7 +56,7 @@ class SparseStorage(object):
else:
raise ValueError
N = col.max().item() + 1
sparse_sizes = torch.Size([int(M), int(N)])
sparse_sizes = (int(M), int(N))
else:
assert len(sparse_sizes) == 2
......@@ -119,7 +118,7 @@ class SparseStorage(object):
self._rowptr = rowptr
self._col = col
self._value = value
self._sparse_sizes = sparse_sizes
self._sparse_sizes = tuple(sparse_sizes)
self._rowcount = rowcount
self._colptr = colptr
self._colcount = colcount
......@@ -192,8 +191,7 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]:
return self._value
def set_value_(self,
value: Optional[torch.Tensor],
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
......@@ -205,8 +203,7 @@ class SparseStorage(object):
self._value = value
return self
def set_value(self,
value: Optional[torch.Tensor],
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc':
......@@ -215,26 +212,19 @@ class SparseStorage(object):
assert value.device == self._col.device
assert value.size(0) == self._col.numel()
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
def sparse_sizes(self) -> List[int]:
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=value, sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def sparse_sizes(self) -> Tuple[int, int]:
return self._sparse_sizes
def sparse_size(self, dim: int) -> int:
return self._sparse_sizes[dim]
def sparse_resize(self, sparse_sizes: List[int]):
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
assert len(sparse_sizes) == 2
old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
......@@ -264,18 +254,11 @@ class SparseStorage(object):
if colcount is not None:
colcount = colcount[:-diff_1]
return SparseStorage(
row=self._row,
rowptr=rowptr,
col=self._col,
value=self._value,
sparse_sizes=sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
value=self._value, sparse_sizes=sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool:
return self._rowcount is not None
......@@ -320,10 +303,8 @@ class SparseStorage(object):
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(
torch.ones_like(self._col),
self._col,
dim_size=self._sparse_sizes[1])
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
self._colcount = colcount
return colcount
......@@ -375,18 +356,10 @@ class SparseStorage(object):
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True)
return SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=self._sparse_sizes, rowcount=None,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
def fill_cache_(self):
self.row()
......@@ -406,33 +379,30 @@ class SparseStorage(object):
self._csc2csr = None
return self
def num_cached_keys(self) -> int:
count = 0
def cached_keys(self) -> List[str]:
keys: List[str] = []
if self.has_rowcount():
count += 1
keys.append('rowcount')
if self.has_colptr():
count += 1
keys.append('colptr')
if self.has_colcount():
count += 1
keys.append('colcount')
if self.has_csr2csc():
count += 1
keys.append('csr2csc')
if self.has_csc2csr():
count += 1
return count
keys.append('csc2csr')
return keys
def num_cached_keys(self) -> int:
return len(self.cached_keys())
def copy(self):
return SparseStorage(
row=self._row,
rowptr=self._rowptr,
col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount,
colptr=self._colptr,
colcount=self._colcount,
csr2csc=self._csr2csc,
csc2csr=self._csc2csr,
is_sorted=True)
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value,
sparse_sizes=self._sparse_sizes,
rowcount=self._rowcount, colptr=self._colptr,
colcount=self._colcount, csr2csc=self._csr2csc,
csc2csr=self._csc2csr, is_sorted=True)
def clone(self):
row = self._row
......@@ -460,18 +430,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.clone()
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def type_as(self, tensor=torch.Tensor):
value = self._value
......@@ -512,18 +475,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def pin_memory(self):
row = self._row
......@@ -551,18 +507,11 @@ class SparseStorage(object):
csc2csr = self._csc2csr
if csc2csr is not None:
csc2csr = csc2csr.pin_memory()
return SparseStorage(
row=row,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount,
colptr=colptr,
colcount=colcount,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True)
return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=self._sparse_sizes,
rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
def is_pinned(self) -> bool:
is_pinned = True
......
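To see the two behavioural changes in this file together (the Tuple[int, int] sparse_sizes and the new cached_keys() helper), a small sketch that builds the storage through SparseTensor; it assumes fill_cache_() builds all five derived tensors listed above.

import torch
from torch_sparse.tensor import SparseTensor

row = torch.tensor([0, 1, 1])
col = torch.tensor([0, 0, 2])
mat = SparseTensor(row=row, col=col, sparse_sizes=(2, 3))
storage = mat.storage

print(storage.sparse_sizes())     # (2, 3) -- now a plain Tuple[int, int]
print(storage.cached_keys())      # typically [] right after construction
storage.fill_cache_()
print(storage.cached_keys())      # expected: rowcount, colptr, colcount, csr2csc, csc2csr
print(storage.num_cached_keys())  # 5 once everything is cached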
......@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: List[int] = None, is_sorted: bool = False):
sparse_sizes: Optional[Tuple[int, int]] = None,
is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_sizes=sparse_sizes,
rowcount=None, colptr=None, colcount=None,
......@@ -45,7 +46,8 @@ class SparseTensor(object):
value = mat[row, col]
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
......@@ -59,7 +61,8 @@ class SparseTensor(object):
value = mat._values()
return SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=mat.size()[:2], is_sorted=True)
sparse_sizes=(mat.size(0), mat.size(1)),
is_sorted=True)
@classmethod
def eye(self, M: int, N: Optional[int] = None,
......@@ -105,10 +108,9 @@ class SparseTensor(object):
csr2csc = csc2csr = row
storage: SparseStorage = SparseStorage(
row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
is_sorted=True)
row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
rowcount=rowcount, colptr=colptr, colcount=colcount,
csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)
self = SparseTensor.__new__(SparseTensor)
self.storage = storage
......@@ -160,13 +162,13 @@ class SparseTensor(object):
layout: Optional[str] = None):
return self.from_storage(self.storage.set_value(value, layout))
def sparse_sizes(self) -> List[int]:
def sparse_sizes(self) -> Tuple[int, int]:
return self.storage.sparse_sizes()
def sparse_size(self, dim: int) -> int:
return self.storage.sparse_sizes()[dim]
def sparse_resize(self, sparse_sizes: List[int]):
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
def is_coalesced(self) -> bool:
......@@ -206,11 +208,12 @@ class SparseTensor(object):
return self.set_value(value, layout='coo')
def sizes(self) -> List[int]:
sizes = self.sparse_sizes()
sparse_sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes = list(sizes) + list(value.size())[1:]
return sizes
return list(sparse_sizes) + list(value.size())[1:]
else:
return list(sparse_sizes)
def size(self, dim: int) -> int:
return self.sizes()[dim]
......@@ -268,7 +271,7 @@ class SparseTensor(object):
N = max(self.size(0), self.size(1))
out = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=torch.Size([N, N]), is_sorted=False)
sparse_sizes=(N, N), is_sorted=False)
out = out.coalesce(reduce)
return out
......
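The reworked sizes() above is easy to check against a value tensor that carries a trailing dense dimension; a minimal sketch:

import torch
from torch_sparse.tensor import SparseTensor

row = torch.tensor([0, 1])
col = torch.tensor([1, 0])
value = torch.randn(2, 4)        # one 4-dimensional feature per non-zero
mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(2, 2))

print(mat.sparse_sizes())  # (2, 2) -- the two sparse dimensions only
print(mat.sizes())         # [2, 2, 4] -- sparse dims plus the dense value dims
print(mat.size(2))         # 4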
......@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def t(src: SparseTensor) -> SparseTensor:
csr2csc = src.storage.csr2csc()
......@@ -20,7 +19,7 @@ def t(src: SparseTensor) -> SparseTensor:
rowptr=src.storage._colptr,
col=row[csr2csc],
value=value,
sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
rowcount=src.storage._colcount,
colptr=src.storage._rowptr,
colcount=src.storage._rowcount,
......@@ -54,7 +53,7 @@ def transpose(index, value, m, n, coalesced=True):
row, col = col, row
if coalesced:
sparse_sizes = torch.Size([n, m])
sparse_sizes = (n, m)
storage = SparseStorage(row=row, col=col, value=value,
sparse_sizes=sparse_sizes, is_sorted=False)
storage = storage.coalesce()
......
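Finally, a usage sketch of the functional transpose() whose hunk is shown above; index is a [2, nnz] COO index and the result describes the [n, m] transposed matrix. The top-level re-export from torch_sparse is assumed here; the scripted t() above does the same for SparseTensor by swapping the cached CSR/CSC data and the sparse_sizes tuple.

import torch
from torch_sparse import transpose  # top-level re-export assumed

index = torch.tensor([[0, 0, 1], [0, 2, 1]])  # a 2 x 3 matrix
value = torch.tensor([1., 2., 3.])

index_t, value_t = transpose(index, value, 2, 3)  # entries of the 3 x 2 transpose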