Unverified Commit c4c6db4a authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Remove `test/__init__.py` (#298)

* set version

* update

* update

* update

* fix test

* fix test
parent c86d777a
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import SparseTensor, add
from .utils import dtypes, devices, tensor from torch_sparse import SparseTensor, add
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from .utils import devices, tensor from torch_sparse.cat import cat
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices, tensor from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
from itertools import product from itertools import product
import pytest import pytest
from torch_sparse.tensor import SparseTensor
from .utils import dtypes, devices from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, dtypes
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
...@@ -6,8 +6,7 @@ import torch_scatter ...@@ -6,8 +6,7 @@ import torch_scatter
from torch_sparse.matmul import matmul from torch_sparse.matmul import matmul
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, grad_dtypes, reductions
from .utils import devices, grad_dtypes, reductions
@pytest.mark.parametrize('dtype,device,reduce', @pytest.mark.parametrize('dtype,device,reduce',
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices
try: try:
rowptr = torch.tensor([0, 1]) rowptr = torch.tensor([0, 1])
......
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor
from .utils import devices, tensor from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import spmm
from .utils import dtypes, devices, tensor from torch_sparse import spmm
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
...@@ -4,8 +4,7 @@ import pytest ...@@ -4,8 +4,7 @@ import pytest
import torch import torch
from torch_sparse import SparseTensor, spspmm from torch_sparse import SparseTensor, spspmm
from torch_sparse.testing import devices, grad_dtypes, tensor
from .utils import devices, grad_dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor from torch_sparse.storage import SparseStorage
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import SparseTensor
from .utils import grad_dtypes, devices from torch_sparse import SparseTensor
from torch_sparse.testing import devices, grad_dtypes
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...@@ -15,8 +15,8 @@ def test_getitem(dtype, device): ...@@ -15,8 +15,8 @@ def test_getitem(dtype, device):
mat = torch.randn(m, n, dtype=dtype, device=device) mat = torch.randn(m, n, dtype=dtype, device=device)
mat = SparseTensor.from_dense(mat) mat = SparseTensor.from_dense(mat)
idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device) idx1 = torch.randint(0, m, (k, ), dtype=torch.long, device=device)
idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device) idx2 = torch.randint(0, n, (k, ), dtype=torch.long, device=device)
bool1 = torch.zeros(m, dtype=torch.bool, device=device) bool1 = torch.zeros(m, dtype=torch.bool, device=device)
bool2 = torch.zeros(n, dtype=torch.bool, device=device) bool2 = torch.zeros(n, dtype=torch.bool, device=device)
bool1.scatter_(0, idx1, 1) bool1.scatter_(0, idx1, 1)
......
...@@ -2,9 +2,9 @@ from itertools import product ...@@ -2,9 +2,9 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse import transpose
from .utils import dtypes, devices, tensor from torch_sparse import transpose
from torch_sparse.testing import devices, dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......
from typing import Any
import torch import torch
import torch_scatter import torch_scatter
from packaging import version from packaging import version
...@@ -13,8 +15,8 @@ if version.parse(torch_scatter.__version__) > version.parse("2.0.9"): ...@@ -13,8 +15,8 @@ if version.parse(torch_scatter.__version__) > version.parse("2.0.9"):
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] devices += [torch.device('cuda:0')]
def tensor(x, dtype, device): def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device) return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment