Commit ed3de958 authored by rusty1s's avatar rusty1s
Browse files

Merge branch 'master' of github.com:rusty1s/pytorch_sparse

parents 3679da2b 266dffd3
...@@ -9,6 +9,7 @@ if(WITH_CUDA) ...@@ -9,6 +9,7 @@ if(WITH_CUDA)
enable_language(CUDA) enable_language(CUDA)
add_definitions(-D__CUDA_NO_HALF_OPERATORS__) add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
add_definitions(-DWITH_CUDA) add_definitions(-DWITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=sm_35 --expt-relaxed-constexpr")
endif() endif()
find_package(Python3 COMPONENTS Development) find_package(Python3 COMPONENTS Development)
......
from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor
from .utils import grad_dtypes, devices
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_getitem(dtype, device):
mat = torch.randn(50, 40, dtype=dtype, device=device)
mat = SparseTensor.from_dense(mat)
idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device)
idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device)
assert mat[:10, :10].sizes() == [10, 10]
assert mat[..., :10].sizes() == [50, 10]
assert mat[idx1, idx2].sizes() == [10, 10]
assert mat[idx1.tolist()].sizes() == [10, 40]
...@@ -442,6 +442,8 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor: ...@@ -442,6 +442,8 @@ def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
out = self out = self
while len(index) > 0: while len(index) > 0:
item = index.pop(0) item = index.pop(0)
if isinstance(item, (list, tuple)):
item = torch.tensor(item, dtype=torch.long, device=self.device())
if isinstance(item, int): if isinstance(item, int):
out = out.select(dim, item) out = out.select(dim, item)
dim += 1 dim += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment