Unverified Commit c4c6db4a authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Remove `test/__init__.py` (#298)

* set version

* update

* update

* update

* fix test

* fix test
parent c86d777a
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor, add
from .utils import dtypes, devices, tensor
from torch_sparse import SparseTensor, add
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from .utils import devices, tensor
from torch_sparse.cat import cat
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor
@pytest.mark.parametrize('device', devices)
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices, tensor
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
from itertools import product
import pytest
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, dtypes
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
......@@ -6,8 +6,7 @@ import torch_scatter
from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor
from .utils import devices, grad_dtypes, reductions
from torch_sparse.testing import devices, grad_dtypes, reductions
@pytest.mark.parametrize('dtype,device,reduce',
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices
try:
rowptr = torch.tensor([0, 1])
......
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor
@pytest.mark.parametrize('device', devices)
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse import spmm
from .utils import dtypes, devices, tensor
from torch_sparse import spmm
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
......@@ -4,8 +4,7 @@ import pytest
import torch
from torch_sparse import SparseTensor, spspmm
from .utils import devices, grad_dtypes, tensor
from torch_sparse.testing import devices, grad_dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor
from torch_sparse.storage import SparseStorage
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('device', devices)
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse import SparseTensor
from .utils import grad_dtypes, devices
from torch_sparse import SparseTensor
from torch_sparse.testing import devices, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......@@ -15,8 +15,8 @@ def test_getitem(dtype, device):
mat = torch.randn(m, n, dtype=dtype, device=device)
mat = SparseTensor.from_dense(mat)
idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
idx1 = torch.randint(0, m, (k, ), dtype=torch.long, device=device)
idx2 = torch.randint(0, n, (k, ), dtype=torch.long, device=device)
bool1 = torch.zeros(m, dtype=torch.bool, device=device)
bool2 = torch.zeros(n, dtype=torch.bool, device=device)
bool1.scatter_(0, idx1, 1)
......
......@@ -2,9 +2,9 @@ from itertools import product
import pytest
import torch
from torch_sparse import transpose
from .utils import dtypes, devices, tensor
from torch_sparse import transpose
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
from typing import Any
import torch
import torch_scatter
from packaging import version
......@@ -13,8 +15,8 @@ if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment