Unverified commit bfb571cb, authored by Matthias Fey and committed by GitHub

Merge pull request #40 from rusty1s/metis

[WIP] Partition
parents e78637ea eee47eee
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     rowptr, col, value = src.csr()
     if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value(value, layout='coo')
 
 
-@torch.jit.script
 def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     rowptr, col, value = src.csr()
     if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value_(value, layout='coo')
 
 
-@torch.jit.script
 def mul_nnz(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
     value = src.storage.value()
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
     return src.set_value(value, layout=layout)
 
 
-@torch.jit.script
 def mul_nnz_(src: SparseTensor, other: torch.Tensor,
              layout: Optional[str] = None) -> SparseTensor:
     value = src.storage.value()
...
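The functions above broadcast a dense tensor against the sparse matrix and pick the row-wise or column-wise branch from the shape of `other`. A minimal usage sketch, not part of the diff; it assumes these helpers live in a `torch_sparse.mul` module, which the extract does not name explicitly:

    import torch
    from torch_sparse.tensor import SparseTensor
    from torch_sparse.mul import mul  # assumed module path

    # Hypothetical 3x3 sparse matrix with three non-zeros.
    adj = SparseTensor(row=torch.tensor([0, 1, 2]),
                       col=torch.tensor([1, 2, 0]),
                       value=torch.ones(3), sparse_sizes=(3, 3))

    # A [3, 1] tensor matches the row-wise branch above: every non-zero
    # gets scaled by the weight of its row.
    row_weights = torch.tensor([[1.0], [2.0], [3.0]])
    out = mul(adj, row_weights)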
 from typing import Tuple
 
-import torch
 from torch_sparse.storage import SparseStorage
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def narrow(src: SparseTensor, dim: int, start: int,
            length: int) -> SparseTensor:
     if dim < 0:
@@ -31,7 +29,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
         if value is not None:
             value = value.narrow(0, row_start, row_length)
 
-        sparse_sizes = torch.Size([length, src.sparse_size(1)])
+        sparse_sizes = (length, src.sparse_size(1))
 
         rowcount = src.storage._rowcount
         if rowcount is not None:
@@ -54,7 +52,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
         if value is not None:
             value = value[mask]
 
-        sparse_sizes = torch.Size([src.sparse_size(0), length])
+        sparse_sizes = (src.sparse_size(0), length)
 
         colptr = src.storage._colptr
         if colptr is not None:
@@ -80,7 +78,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
         raise ValueError
 
 
-@torch.jit.script
 def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                     length: Tuple[int, int]) -> SparseTensor:
     # This function builds the inverse operation of `cat_diag` and should hence
...
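For reference, a small usage sketch of `narrow` after this change (not part of the diff; the import path is taken from the `select` diff further below, which pulls `narrow` from `torch_sparse.narrow`):

    import torch
    from torch_sparse.tensor import SparseTensor
    from torch_sparse.narrow import narrow

    adj = SparseTensor(row=torch.tensor([0, 0, 1, 2]),
                       col=torch.tensor([0, 2, 1, 2]), sparse_sizes=(3, 3))

    # Keep rows 1..2; the narrowed tensor now reports a plain tuple of sizes.
    sub = narrow(adj, dim=0, start=1, length=2)
    assert sub.sparse_sizes() == (2, 3)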
+import torch
+
+from torch_sparse.tensor import SparseTensor
+
+
+def permute(src: SparseTensor, perm: torch.Tensor) -> SparseTensor:
+    assert src.is_quadratic()
+    return src.index_select(0, perm).index_select(1, perm)
+
+
+SparseTensor.permute = lambda self, perm: permute(self, perm)
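The new `permute` helper only applies to square matrices and reorders rows and columns symmetrically, which is the kind of reordering needed after a graph partitioning. A usage sketch, assuming the new module is imported somewhere (e.g. by the package `__init__`) so the bound `SparseTensor.permute` is available:

    import torch
    from torch_sparse.tensor import SparseTensor

    adj = SparseTensor(row=torch.tensor([0, 1, 2]),
                       col=torch.tensor([1, 2, 0]), sparse_sizes=(3, 3))

    perm = torch.tensor([2, 0, 1])   # new ordering of the 3 nodes
    adj_perm = adj.permute(perm)     # same as calling permute(adj, perm)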
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def reduction(src: SparseTensor, dim: Optional[int] = None,
               reduce: str = 'sum') -> torch.Tensor:
     value = src.storage.value()
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
         raise ValueError
 
 
-@torch.jit.script
 def sum(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
     return reduction(src, dim, reduce='sum')
 
 
-@torch.jit.script
 def mean(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
     return reduction(src, dim, reduce='mean')
 
 
-@torch.jit.script
 def min(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
     return reduction(src, dim, reduce='min')
 
 
-@torch.jit.script
 def max(src: SparseTensor, dim: Optional[int] = None) -> torch.Tensor:
     return reduction(src, dim, reduce='max')
...
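A short sketch of the reductions above (not part of the diff; it assumes they live in a `torch_sparse.reduce` module, which the extract does not name):

    import torch
    from torch_sparse.tensor import SparseTensor
    from torch_sparse.reduce import sum as sparse_sum  # assumed module path

    adj = SparseTensor(row=torch.tensor([0, 0, 1]),
                       col=torch.tensor([0, 1, 2]),
                       value=torch.tensor([1.0, 2.0, 3.0]), sparse_sizes=(2, 3))

    total = sparse_sum(adj)            # dim=None: reduce over all non-zeros -> 6.0
    row_sums = sparse_sum(adj, dim=1)  # per-row sums -> tensor([3., 3.])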
-import torch
 from torch_sparse.tensor import SparseTensor
 from torch_sparse.narrow import narrow
 
 
-@torch.jit.script
 def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
     return narrow(src, dim, start=idx, length=1)
...
@@ -23,9 +23,9 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
     """
     A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
-                     sparse_sizes=torch.Size([m, k]), is_sorted=not coalesced)
+                     sparse_sizes=(m, k), is_sorted=not coalesced)
     B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
-                     sparse_sizes=torch.Size([k, n]), is_sorted=not coalesced)
+                     sparse_sizes=(k, n), is_sorted=not coalesced)
     C = matmul(A, B)
     row, col, value = C.coo()
...
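`spspmm` keeps its public signature; only the internal construction now passes plain `(m, k)` / `(k, n)` tuples. A usage sketch of the sparse-sparse matrix product (the long-standing top-level `torch_sparse.spspmm` export is assumed here):

    import torch
    from torch_sparse import spspmm

    # A: 2x3 and B: 3x2, both given as COO index/value pairs.
    indexA = torch.tensor([[0, 0, 1], [0, 2, 1]])
    valueA = torch.tensor([1.0, 2.0, 3.0])
    indexB = torch.tensor([[0, 1, 2], [1, 0, 1]])
    valueB = torch.tensor([4.0, 5.0, 6.0])

    # C = A @ B has non-zeros (0, 1) -> 16.0 and (1, 0) -> 15.0.
    indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 2, 3, 2)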
 import warnings
-from typing import Optional, List
+from typing import Optional, List, Tuple
 
 import torch
 from torch_scatter import segment_csr, scatter_add
@@ -23,19 +23,18 @@ class SparseStorage(object):
     _rowptr: Optional[torch.Tensor]
     _col: torch.Tensor
     _value: Optional[torch.Tensor]
-    _sparse_sizes: List[int]
+    _sparse_sizes: Tuple[int, int]
     _rowcount: Optional[torch.Tensor]
     _colptr: Optional[torch.Tensor]
     _colcount: Optional[torch.Tensor]
     _csr2csc: Optional[torch.Tensor]
     _csc2csr: Optional[torch.Tensor]
 
-    def __init__(self,
-                 row: Optional[torch.Tensor] = None,
+    def __init__(self, row: Optional[torch.Tensor] = None,
                  rowptr: Optional[torch.Tensor] = None,
                  col: Optional[torch.Tensor] = None,
                  value: Optional[torch.Tensor] = None,
-                 sparse_sizes: Optional[List[int]] = None,
+                 sparse_sizes: Optional[Tuple[int, int]] = None,
                  rowcount: Optional[torch.Tensor] = None,
                  colptr: Optional[torch.Tensor] = None,
                  colcount: Optional[torch.Tensor] = None,
@@ -57,7 +56,7 @@ class SparseStorage(object):
             else:
                 raise ValueError
             N = col.max().item() + 1
-            sparse_sizes = torch.Size([int(M), int(N)])
+            sparse_sizes = (int(M), int(N))
         else:
             assert len(sparse_sizes) == 2
@@ -119,7 +118,7 @@ class SparseStorage(object):
         self._rowptr = rowptr
         self._col = col
         self._value = value
-        self._sparse_sizes = sparse_sizes
+        self._sparse_sizes = tuple(sparse_sizes)
         self._rowcount = rowcount
         self._colptr = colptr
         self._colcount = colcount
@@ -192,8 +191,7 @@ class SparseStorage(object):
     def value(self) -> Optional[torch.Tensor]:
         return self._value
 
-    def set_value_(self,
-                   value: Optional[torch.Tensor],
+    def set_value_(self, value: Optional[torch.Tensor],
                    layout: Optional[str] = None):
         if value is not None:
             if get_layout(layout) == 'csc':
@@ -205,8 +203,7 @@ class SparseStorage(object):
         self._value = value
         return self
 
-    def set_value(self,
-                  value: Optional[torch.Tensor],
+    def set_value(self, value: Optional[torch.Tensor],
                   layout: Optional[str] = None):
         if value is not None:
             if get_layout(layout) == 'csc':
@@ -215,26 +212,19 @@ class SparseStorage(object):
             assert value.device == self._col.device
             assert value.size(0) == self._col.numel()
 
-        return SparseStorage(
-            row=self._row,
-            rowptr=self._rowptr,
-            col=self._col,
-            value=value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=self._rowcount,
-            colptr=self._colptr,
-            colcount=self._colcount,
-            csr2csc=self._csr2csc,
-            csc2csr=self._csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
+                             value=value, sparse_sizes=self._sparse_sizes,
+                             rowcount=self._rowcount, colptr=self._colptr,
+                             colcount=self._colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
 
-    def sparse_sizes(self) -> List[int]:
+    def sparse_sizes(self) -> Tuple[int, int]:
         return self._sparse_sizes
 
     def sparse_size(self, dim: int) -> int:
         return self._sparse_sizes[dim]
 
-    def sparse_resize(self, sparse_sizes: List[int]):
+    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
         assert len(sparse_sizes) == 2
         old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()
@@ -264,18 +254,11 @@ class SparseStorage(object):
         if colcount is not None:
             colcount = colcount[:-diff_1]
 
-        return SparseStorage(
-            row=self._row,
-            rowptr=rowptr,
-            col=self._col,
-            value=self._value,
-            sparse_sizes=sparse_sizes,
-            rowcount=rowcount,
-            colptr=colptr,
-            colcount=colcount,
-            csr2csc=self._csr2csc,
-            csc2csr=self._csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
+                             value=self._value, sparse_sizes=sparse_sizes,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
 
     def has_rowcount(self) -> bool:
         return self._rowcount is not None
@@ -320,10 +303,8 @@ class SparseStorage(object):
         if colptr is not None:
             colcount = colptr[1:] - colptr[:-1]
         else:
-            colcount = scatter_add(
-                torch.ones_like(self._col),
-                self._col,
-                dim_size=self._sparse_sizes[1])
+            colcount = scatter_add(torch.ones_like(self._col), self._col,
+                                   dim_size=self._sparse_sizes[1])
         self._colcount = colcount
         return colcount
@@ -375,18 +356,10 @@ class SparseStorage(object):
         value = segment_csr(value, ptr, reduce=reduce)
         value = value[0] if isinstance(value, tuple) else value
 
-        return SparseStorage(
-            row=row,
-            rowptr=None,
-            col=col,
-            value=value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=None,
-            colptr=None,
-            colcount=None,
-            csr2csc=None,
-            csc2csr=None,
-            is_sorted=True)
+        return SparseStorage(row=row, rowptr=None, col=col, value=value,
+                             sparse_sizes=self._sparse_sizes, rowcount=None,
+                             colptr=None, colcount=None, csr2csc=None,
+                             csc2csr=None, is_sorted=True)
 
     def fill_cache_(self):
         self.row()
@@ -406,33 +379,30 @@ class SparseStorage(object):
         self._csc2csr = None
         return self
 
-    def num_cached_keys(self) -> int:
-        count = 0
+    def cached_keys(self) -> List[str]:
+        keys: List[str] = []
         if self.has_rowcount():
-            count += 1
+            keys.append('rowcount')
         if self.has_colptr():
-            count += 1
+            keys.append('colptr')
         if self.has_colcount():
-            count += 1
+            keys.append('colcount')
         if self.has_csr2csc():
-            count += 1
+            keys.append('csr2csc')
         if self.has_csc2csr():
-            count += 1
-        return count
+            keys.append('csc2csr')
+        return keys
+
+    def num_cached_keys(self) -> int:
+        return len(self.cached_keys())
 
     def copy(self):
-        return SparseStorage(
-            row=self._row,
-            rowptr=self._rowptr,
-            col=self._col,
-            value=self._value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=self._rowcount,
-            colptr=self._colptr,
-            colcount=self._colcount,
-            csr2csc=self._csr2csc,
-            csc2csr=self._csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
+                             value=self._value,
+                             sparse_sizes=self._sparse_sizes,
+                             rowcount=self._rowcount, colptr=self._colptr,
+                             colcount=self._colcount, csr2csc=self._csr2csc,
+                             csc2csr=self._csc2csr, is_sorted=True)
 
     def clone(self):
         row = self._row
@@ -460,18 +430,11 @@ class SparseStorage(object):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.clone()
-        return SparseStorage(
-            row=row,
-            rowptr=rowptr,
-            col=col,
-            value=value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=rowcount,
-            colptr=colptr,
-            colcount=colcount,
-            csr2csc=csr2csc,
-            csc2csr=csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
+                             sparse_sizes=self._sparse_sizes,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=csr2csc,
+                             csc2csr=csc2csr, is_sorted=True)
 
     def type_as(self, tensor=torch.Tensor):
         value = self._value
@@ -512,18 +475,11 @@ class SparseStorage(object):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.to(tensor.device, non_blocking=non_blocking)
-        return SparseStorage(
-            row=row,
-            rowptr=rowptr,
-            col=col,
-            value=value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=rowcount,
-            colptr=colptr,
-            colcount=colcount,
-            csr2csc=csr2csc,
-            csc2csr=csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
+                             sparse_sizes=self._sparse_sizes,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=csr2csc,
+                             csc2csr=csc2csr, is_sorted=True)
 
     def pin_memory(self):
         row = self._row
@@ -551,18 +507,11 @@ class SparseStorage(object):
         csc2csr = self._csc2csr
         if csc2csr is not None:
             csc2csr = csc2csr.pin_memory()
-        return SparseStorage(
-            row=row,
-            rowptr=rowptr,
-            col=col,
-            value=value,
-            sparse_sizes=self._sparse_sizes,
-            rowcount=rowcount,
-            colptr=colptr,
-            colcount=colcount,
-            csr2csc=csr2csc,
-            csc2csr=csc2csr,
-            is_sorted=True)
+        return SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
+                             sparse_sizes=self._sparse_sizes,
+                             rowcount=rowcount, colptr=colptr,
+                             colcount=colcount, csr2csc=csr2csc,
+                             csc2csr=csc2csr, is_sorted=True)
 
     def is_pinned(self) -> bool:
         is_pinned = True
...
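The main API change in the storage class: `sparse_sizes` is now a plain `Tuple[int, int]` rather than a `torch.Size`/`List[int]`, and the cache bookkeeping gained a `cached_keys()` method with `num_cached_keys()` derived from it. A small sketch against the class as shown above:

    import torch
    from torch_sparse.storage import SparseStorage

    storage = SparseStorage(row=torch.tensor([0, 0, 1]),
                            col=torch.tensor([1, 2, 0]), sparse_sizes=(2, 3))

    assert storage.sparse_sizes() == (2, 3)   # plain tuple now
    storage.fill_cache_()                     # materialize the derived layouts
    print(storage.cached_keys())       # e.g. ['rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr']
    print(storage.num_cached_keys())   # 5, derived from cached_keys()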
@@ -16,7 +16,8 @@ class SparseTensor(object):
                  rowptr: Optional[torch.Tensor] = None,
                  col: Optional[torch.Tensor] = None,
                  value: Optional[torch.Tensor] = None,
-                 sparse_sizes: List[int] = None, is_sorted: bool = False):
+                 sparse_sizes: Optional[Tuple[int, int]] = None,
+                 is_sorted: bool = False):
         self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
                                      value=value, sparse_sizes=sparse_sizes,
                                      rowcount=None, colptr=None, colcount=None,
@@ -45,7 +46,8 @@ class SparseTensor(object):
         value = mat[row, col]
         return SparseTensor(row=row, rowptr=None, col=col, value=value,
-                            sparse_sizes=mat.size()[:2], is_sorted=True)
+                            sparse_sizes=(mat.size(0), mat.size(1)),
+                            is_sorted=True)
 
     @classmethod
     def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
@@ -59,7 +61,8 @@ class SparseTensor(object):
         value = mat._values()
         return SparseTensor(row=row, rowptr=None, col=col, value=value,
-                            sparse_sizes=mat.size()[:2], is_sorted=True)
+                            sparse_sizes=(mat.size(0), mat.size(1)),
+                            is_sorted=True)
 
     @classmethod
     def eye(self, M: int, N: Optional[int] = None,
@@ -105,10 +108,9 @@ class SparseTensor(object):
         csr2csc = csc2csr = row
 
         storage: SparseStorage = SparseStorage(
-            row=row, rowptr=rowptr, col=col, value=value,
-            sparse_sizes=torch.Size([M, N]), rowcount=rowcount, colptr=colptr,
-            colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
-            is_sorted=True)
+            row=row, rowptr=rowptr, col=col, value=value, sparse_sizes=(M, N),
+            rowcount=rowcount, colptr=colptr, colcount=colcount,
+            csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True)
 
         self = SparseTensor.__new__(SparseTensor)
         self.storage = storage
@@ -160,13 +162,13 @@ class SparseTensor(object):
                   layout: Optional[str] = None):
         return self.from_storage(self.storage.set_value(value, layout))
 
-    def sparse_sizes(self) -> List[int]:
+    def sparse_sizes(self) -> Tuple[int, int]:
         return self.storage.sparse_sizes()
 
     def sparse_size(self, dim: int) -> int:
         return self.storage.sparse_sizes()[dim]
 
-    def sparse_resize(self, sparse_sizes: List[int]):
+    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
         return self.from_storage(self.storage.sparse_resize(sparse_sizes))
 
     def is_coalesced(self) -> bool:
@@ -206,11 +208,12 @@ class SparseTensor(object):
         return self.set_value(value, layout='coo')
 
     def sizes(self) -> List[int]:
-        sizes = self.sparse_sizes()
+        sparse_sizes = self.sparse_sizes()
         value = self.storage.value()
         if value is not None:
-            sizes = list(sizes) + list(value.size())[1:]
-        return sizes
+            return list(sparse_sizes) + list(value.size())[1:]
+        else:
+            return list(sparse_sizes)
 
     def size(self, dim: int) -> int:
         return self.sizes()[dim]
@@ -268,7 +271,7 @@ class SparseTensor(object):
         N = max(self.size(0), self.size(1))
 
         out = SparseTensor(row=row, rowptr=None, col=col, value=value,
-                           sparse_sizes=torch.Size([N, N]), is_sorted=False)
+                           sparse_sizes=(N, N), is_sorted=False)
         out = out.coalesce(reduce)
         return out
...
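`SparseTensor.sizes()` still returns a `List[int]` (the sparse sizes plus any trailing value dimensions), while `sparse_sizes()` now returns a `Tuple[int, int]`. A short sketch:

    import torch
    from torch_sparse.tensor import SparseTensor

    value = torch.rand(2, 8)   # one 8-dimensional feature per non-zero
    adj = SparseTensor(row=torch.tensor([0, 1]), col=torch.tensor([1, 0]),
                       value=value, sparse_sizes=(2, 2))

    print(adj.sparse_sizes())   # (2, 2)
    print(adj.sizes())          # [2, 2, 8]
    print(adj.size(2))          # 8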
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
 from torch_sparse.tensor import SparseTensor
 
 
-@torch.jit.script
 def t(src: SparseTensor) -> SparseTensor:
     csr2csc = src.storage.csr2csc()
@@ -20,7 +19,7 @@ def t(src: SparseTensor) -> SparseTensor:
         rowptr=src.storage._colptr,
         col=row[csr2csc],
         value=value,
-        sparse_sizes=torch.Size([sparse_sizes[1], sparse_sizes[0]]),
+        sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
         rowcount=src.storage._colcount,
         colptr=src.storage._rowptr,
         colcount=src.storage._rowcount,
@@ -54,7 +53,7 @@ def transpose(index, value, m, n, coalesced=True):
     row, col = col, row
 
     if coalesced:
-        sparse_sizes = torch.Size([n, m])
+        sparse_sizes = (n, m)
         storage = SparseStorage(row=row, col=col, value=value,
                                 sparse_sizes=sparse_sizes, is_sorted=False)
         storage = storage.coalesce()
...
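The legacy `transpose(index, value, m, n)` helper is unchanged apart from the tuple sizes. For completeness, a usage sketch (assuming the long-standing top-level `torch_sparse.transpose` export):

    import torch
    from torch_sparse import transpose

    # 2x3 matrix with non-zeros (0, 2) -> 1.0 and (1, 0) -> 2.0.
    index = torch.tensor([[0, 1], [2, 0]])
    value = torch.tensor([1.0, 2.0])

    # Transposed 3x2 result: non-zeros (0, 1) -> 2.0 and (2, 0) -> 1.0.
    index_t, value_t = transpose(index, value, 2, 3)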