Unverified Commit 88c6ceb6 authored by bwdeng20's avatar bwdeng20 Committed by GitHub
Browse files

fix(`SparseTensor.__getitem__`): support `np.ndarray` and fix `List[b… (#194)



* fix(`SparseTensor.__getitem__`): support `np.ndarray` and fix `List[bool]`

support indexing with np.ndarray & fix bug merging from indexing with
List[bool]

* style(tensor, test_tensor): pep8 E501 too long

support indexing with np.ndarray & fix bug merging from indexing
with
List[bool]

* update

* typo

* typo
Co-authored-by: default avatartim <dbwtimteo@outlook.com>
Co-authored-by: default avatarrusty1s <matthias.fey@tu-dortmund.de>
parent efc98089
...@@ -9,16 +9,50 @@ from .utils import grad_dtypes, devices ...@@ -9,16 +9,50 @@ from .utils import grad_dtypes, devices
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device): def test_getitem(dtype, device):
mat = torch.randn(50, 40, dtype=dtype, device=device) m = 50
n = 40
k = 10
mat = torch.randn(m, n, dtype=dtype, device=device)
mat = SparseTensor.from_dense(mat) mat = SparseTensor.from_dense(mat)
idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device) idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device) idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
bool1 = torch.zeros(m, dtype=torch.bool, device=device)
bool2 = torch.zeros(n, dtype=torch.bool, device=device)
bool1.scatter_(0, idx1, 1)
bool2.scatter_(0, idx2, 1)
# idx1 and idx2 may have duplicates
k1_bool = bool1.nonzero().size(0)
k2_bool = bool2.nonzero().size(0)
assert mat[:10, :10].sizes() == [10, 10] idx1np = idx1.cpu().numpy()
assert mat[..., :10].sizes() == [50, 10] idx2np = idx2.cpu().numpy()
assert mat[idx1, idx2].sizes() == [10, 10] bool1np = bool1.cpu().numpy()
assert mat[idx1.tolist()].sizes() == [10, 40] bool2np = bool2.cpu().numpy()
idx1list = idx1np.tolist()
idx2list = idx2np.tolist()
bool1list = bool1np.tolist()
bool2list = bool2np.tolist()
assert mat[:k, :k].sizes() == [k, k]
assert mat[..., :k].sizes() == [m, k]
assert mat[idx1, idx2].sizes() == [k, k]
assert mat[idx1np, idx2np].sizes() == [k, k]
assert mat[idx1list, idx2list].sizes() == [k, k]
assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool]
assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool]
assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool]
assert mat[idx1].sizes() == [k, n]
assert mat[idx1np].sizes() == [k, n]
assert mat[idx1list].sizes() == [k, n]
assert mat[bool1].sizes() == [k1_bool, n]
assert mat[bool1np].sizes() == [k1_bool, n]
assert mat[bool1list].sizes() == [k1_bool, n]
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
......
...@@ -2,6 +2,7 @@ from textwrap import indent ...@@ -2,6 +2,7 @@ from textwrap import indent
from typing import Optional, List, Tuple, Dict, Union, Any from typing import Optional, List, Tuple, Dict, Union, Any
import torch import torch
import numpy as np
import scipy.sparse import scipy.sparse
from torch_scatter import segment_csr from torch_scatter import segment_csr
...@@ -468,7 +469,6 @@ def is_shared(self: SparseTensor) -> bool: ...@@ -468,7 +469,6 @@ def is_shared(self: SparseTensor) -> bool:
def to(self, *args: Optional[List[Any]], def to(self, *args: Optional[List[Any]],
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor: **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3] device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
if dtype is not None: if dtype is not None:
...@@ -491,7 +491,10 @@ def cuda(self, device: Optional[Union[int, str]] = None, ...@@ -491,7 +491,10 @@ def cuda(self, device: Optional[Union[int, str]] = None,
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor: def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
index = list(index) if isinstance(index, tuple) else [index] index = list(index) if isinstance(index, tuple) else [index]
# More than one `Ellipsis` is not allowed... # More than one `Ellipsis` is not allowed...
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1: if len([
i for i in index
if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ...
]) > 1:
raise SyntaxError raise SyntaxError
dim = 0 dim = 0
...@@ -499,7 +502,10 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor: ...@@ -499,7 +502,10 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
while len(index) > 0: while len(index) > 0:
item = index.pop(0) item = index.pop(0)
if isinstance(item, (list, tuple)): if isinstance(item, (list, tuple)):
item = torch.tensor(item, dtype=torch.long, device=self.device()) item = torch.tensor(item, device=self.device())
if isinstance(item, np.ndarray):
item = torch.from_numpy(item).to(self.device())
if isinstance(item, int): if isinstance(item, int):
out = out.select(dim, item) out = out.select(dim, item)
dim += 1 dim += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment