Unverified Commit 5bb8d17b authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

update (#154)

parent 6f222280
...@@ -4,8 +4,7 @@ import pytest ...@@ -4,8 +4,7 @@ import pytest
import torch import torch
from torch import Tensor from torch import Tensor
from torch_cluster import fps from torch_cluster import fps
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
@torch.jit.script @torch.jit.script
......
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_cluster import graclus_cluster from torch_cluster import graclus_cluster
from torch_cluster.testing import devices, dtypes, tensor
from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
...@@ -42,6 +41,9 @@ def assert_correct(row, col, cluster): ...@@ -42,6 +41,9 @@ def assert_correct(row, col, cluster):
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device): def test_graclus_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
row = tensor(test['row'], torch.long, device) row = tensor(test['row'], torch.long, device)
col = tensor(test['col'], torch.long, device) col = tensor(test['col'], torch.long, device)
weight = tensor(test.get('weight'), dtype, device) weight = tensor(test.get('weight'), dtype, device)
......
from itertools import product from itertools import product
import pytest import pytest
import torch
from torch_cluster import grid_cluster from torch_cluster import grid_cluster
from torch_cluster.testing import devices, dtypes, tensor
from .utils import dtypes, devices, tensor
tests = [{ tests = [{
'pos': [2, 6], 'pos': [2, 6],
...@@ -28,6 +28,9 @@ tests = [{ ...@@ -28,6 +28,9 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster(test, dtype, device): def test_grid_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
pos = tensor(test['pos'], dtype, device) pos = tensor(test['pos'], dtype, device)
size = tensor(test['size'], dtype, device) size = tensor(test['size'], dtype, device)
start = tensor(test.get('start'), dtype, device) start = tensor(test.get('start'), dtype, device)
......
from itertools import product from itertools import product
import pytest import pytest
import torch
import scipy.spatial import scipy.spatial
import torch
from torch_cluster import knn, knn_graph from torch_cluster import knn, knn_graph
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index): def to_set(edge_index):
......
...@@ -3,8 +3,7 @@ from itertools import product ...@@ -3,8 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_cluster import nearest from torch_cluster import nearest
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
from itertools import product from itertools import product
import pytest import pytest
import torch
import scipy.spatial import scipy.spatial
import torch
from torch_cluster import radius, radius_graph from torch_cluster import radius, radius_graph
from torch_cluster.testing import devices, grad_dtypes, tensor
from .utils import grad_dtypes, devices, tensor
def to_set(edge_index): def to_set(edge_index):
......
import pytest import pytest
import torch import torch
from torch_cluster import random_walk from torch_cluster import random_walk
from torch_cluster.testing import devices, tensor
from .utils import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
...@@ -41,7 +40,10 @@ def test_rw_large_with_edge_indices(device): ...@@ -41,7 +40,10 @@ def test_rw_large_with_edge_indices(device):
walk_length = 10 walk_length = 10
node_seq, edge_seq = random_walk( node_seq, edge_seq = random_walk(
row, col, start, walk_length, row,
col,
start,
walk_length,
return_edge_indices=True, return_edge_indices=True,
) )
assert node_seq[:, 0].tolist() == start.tolist() assert node_seq[:, 0].tolist() == start.tolist()
...@@ -63,7 +65,10 @@ def test_rw_small_with_edge_indices(device): ...@@ -63,7 +65,10 @@ def test_rw_small_with_edge_indices(device):
walk_length = 4 walk_length = 4
node_seq, edge_seq = random_walk( node_seq, edge_seq = random_walk(
row, col, start, walk_length, row,
col,
start,
walk_length,
num_nodes=3, num_nodes=3,
return_edge_indices=True, return_edge_indices=True,
) )
......
from typing import Any
import torch import torch
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double, dtypes = [
torch.int, torch.long] torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.half, torch.float, torch.double] grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')] devices = [torch.device('cpu')]
if torch.cuda.is_available(): if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')] devices += [torch.device('cuda:0')]
def tensor(x, dtype, device): def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device) return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment