Unverified commit 88c6ceb6 authored by bwdeng20, committed by GitHub

fix(`SparseTensor.__getitem__`): support `np.ndarray` and fix `List[bool]` (#194)



* fix(`SparseTensor.__getitem__`): support `np.ndarray` and fix `List[bool]`

support indexing with np.ndarray and fix a bug arising from indexing with List[bool]

* style(tensor, test_tensor): fix pep8 E501 (line too long)

* update

* typo

* typo

Co-authored-by: tim <dbwtimteo@outlook.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
parent efc98089
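
In short: `__getitem__` used to coerce every `list`/`tuple` index to `torch.long`, so a `List[bool]` mask was silently reinterpreted as the integer indices 1/0, and `np.ndarray` indices were not handled at all. A minimal sketch of the behavior this patch enables (assuming a torch_sparse build with this fix; the small shapes are illustrative, not from the diff):

```python
import numpy as np
import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.randn(5, 4))

# numpy integer indices are now accepted and converted on the fly:
assert mat[np.array([0, 2], dtype=np.int64)].sizes() == [2, 4]

# a bool list now acts as a mask; before the fix it was cast to long,
# so [True, False, ...] selected rows 1 and 0 instead of masking:
assert mat[[True, False, True, False, True]].sizes() == [3, 4]
```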
@@ -9,16 +9,50 @@ from .utils import grad_dtypes, devices
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_getitem(dtype, device):
-    mat = torch.randn(50, 40, dtype=dtype, device=device)
+    m = 50
+    n = 40
+    k = 10
+    mat = torch.randn(m, n, dtype=dtype, device=device)
     mat = SparseTensor.from_dense(mat)
 
-    idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device)
-    idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device)
+    idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
+    idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
+    bool1 = torch.zeros(m, dtype=torch.bool, device=device)
+    bool2 = torch.zeros(n, dtype=torch.bool, device=device)
+    bool1.scatter_(0, idx1, 1)
+    bool2.scatter_(0, idx2, 1)
+    # idx1 and idx2 may have duplicates
+    k1_bool = bool1.nonzero().size(0)
+    k2_bool = bool2.nonzero().size(0)
+
+    idx1np = idx1.cpu().numpy()
+    idx2np = idx2.cpu().numpy()
+    bool1np = bool1.cpu().numpy()
+    bool2np = bool2.cpu().numpy()
+
+    idx1list = idx1np.tolist()
+    idx2list = idx2np.tolist()
+    bool1list = bool1np.tolist()
+    bool2list = bool2np.tolist()
 
-    assert mat[:10, :10].sizes() == [10, 10]
-    assert mat[..., :10].sizes() == [50, 10]
-    assert mat[idx1, idx2].sizes() == [10, 10]
-    assert mat[idx1.tolist()].sizes() == [10, 40]
+    assert mat[:k, :k].sizes() == [k, k]
+    assert mat[..., :k].sizes() == [m, k]
+
+    assert mat[idx1, idx2].sizes() == [k, k]
+    assert mat[idx1np, idx2np].sizes() == [k, k]
+    assert mat[idx1list, idx2list].sizes() == [k, k]
+
+    assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool]
+    assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool]
+    assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool]
+
+    assert mat[idx1].sizes() == [k, n]
+    assert mat[idx1np].sizes() == [k, n]
+    assert mat[idx1list].sizes() == [k, n]
+
+    assert mat[bool1].sizes() == [k1_bool, n]
+    assert mat[bool1np].sizes() == [k1_bool, n]
+    assert mat[bool1list].sizes() == [k1_bool, n]
 
 
 @pytest.mark.parametrize('device', devices)
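One detail worth noting in the test above: `torch.randint` samples with replacement, so `idx1`/`idx2` can contain repeated values, and scattering them into a boolean mask collapses the duplicates. That is why the expected sizes use `k1_bool`/`k2_bool` rather than `k`. A small illustration (not part of the diff):

```python
import torch

idx = torch.tensor([3, 3, 7])             # duplicate index 3
mask = torch.zeros(10, dtype=torch.bool)
mask.scatter_(0, idx, 1)                  # both writes to 3 yield one True
assert mask.nonzero().size(0) == 2        # 2 distinct indices, len(idx) == 3
```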
@@ -2,6 +2,7 @@ from textwrap import indent
 from typing import Optional, List, Tuple, Dict, Union, Any
 
 import torch
+import numpy as np
 import scipy.sparse
 from torch_scatter import segment_csr

@@ -468,7 +469,6 @@ def is_shared(self: SparseTensor) -> bool:
 def to(self, *args: Optional[List[Any]],
        **kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
     device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
     if dtype is not None:

@@ -491,7 +491,10 @@ def cuda(self, device: Optional[Union[int, str]] = None,
 def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
     index = list(index) if isinstance(index, tuple) else [index]
     # More than one `Ellipsis` is not allowed...
-    if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
+    if len([
+            i for i in index
+            if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ...
+    ]) > 1:
         raise SyntaxError
 
     dim = 0
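The widened `isinstance` guard matters because `==` on a `Tensor` or `ndarray` is overloaded: depending on the numpy/torch version, `i == ...` may warn, raise, or return something other than a plain bool, which would break the comprehension's filter. The patch therefore skips array-likes before the comparison; a sketch of an equivalent identity-based count (illustrative only, not the patched code):

```python
import numpy as np
import torch

index = [np.array([0, 1]), torch.arange(2), ...]

# Filtering with `i == ...` would evaluate an overloaded elementwise
# comparison on the array entries; counting by identity avoids that:
n_ellipsis = len([i for i in index if i is ...])
assert n_ellipsis == 1
```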

@@ -499,7 +502,10 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
     while len(index) > 0:
         item = index.pop(0)
         if isinstance(item, (list, tuple)):
-            item = torch.tensor(item, dtype=torch.long, device=self.device())
+            item = torch.tensor(item, device=self.device())
+        if isinstance(item, np.ndarray):
+            item = torch.from_numpy(item).to(self.device())
+
         if isinstance(item, int):
             out = out.select(dim, item)
             dim += 1
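A note on the new `np.ndarray` branch: `torch.from_numpy` wraps the array without copying and preserves its dtype, so a numpy bool mask stays `torch.bool` (unlike the old list path, which forced `torch.long`), and `.to(...)` only copies when the target device differs. A quick sketch (CPU-only, with explicit dtypes to stay platform-independent):

```python
import numpy as np
import torch

mask_np = np.array([True, False, True])
mask_t = torch.from_numpy(mask_np)           # zero-copy view, torch.bool
assert mask_t.dtype == torch.bool

idx_np = np.array([0, 2], dtype=np.int64)
idx_t = torch.from_numpy(idx_np).to('cpu')   # no-op move: already on CPU
assert idx_t.dtype == torch.int64
```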