Unverified Commit 28f12953 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #176 from shi27feng/patch-1

Update storage.py
parents 23709f94 9b5d3c79
......@@ -34,7 +34,8 @@ class SparseStorage(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[int, int]] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
rowcount: Optional[torch.Tensor] = None,
colptr: Optional[torch.Tensor] = None,
colcount: Optional[torch.Tensor] = None,
......@@ -48,26 +49,33 @@ class SparseStorage(object):
assert col.dim() == 1
col = col.contiguous()
if sparse_sizes is None:
M: int = 0
if sparse_sizes is None or sparse_sizes[0] is None:
if rowptr is not None:
M = rowptr.numel() - 1
elif row is not None and row.numel() > 0:
M = row.max().item() + 1
elif row is not None and row.numel() == 0:
M = 0
else:
raise ValueError
M = int(row.max()) + 1
else:
_M = sparse_sizes[0]
assert _M is not None
M = _M
if rowptr is not None:
assert rowptr.numel() - 1 == M
elif row is not None and row.numel() > 0:
assert int(row.max()) < M
N: int = 0
if sparse_sizes is None or sparse_sizes[1] is None:
if col.numel() > 0:
N = col.max().item() + 1
else:
N = 0
sparse_sizes = (int(M), int(N))
N = int(col.max()) + 1
else:
assert len(sparse_sizes) == 2
if row is not None and row.numel() > 0:
assert row.max().item() < sparse_sizes[0]
_N = sparse_sizes[1]
assert _N is not None
N = _N
if col.numel() > 0:
assert col.max().item() < sparse_sizes[1]
assert int(col.max()) < N
sparse_sizes = (M, N)
if row is not None:
assert row.dtype == torch.long
......
......@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr: Optional[torch.Tensor] = None,
col: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[int, int]] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
is_sorted: bool = False):
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
value=value, sparse_sizes=sparse_sizes,
......@@ -39,7 +40,8 @@ class SparseTensor(object):
@classmethod
def from_edge_index(self, edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None,
sparse_sizes: Optional[Tuple[int, int]] = None,
sparse_sizes: Optional[Tuple[Optional[int],
Optional[int]]] = None,
is_sorted: bool = False):
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
value=edge_attr, sparse_sizes=sparse_sizes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment