Unverified Commit 5bb8d17b authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

update (#154)

parent 6f222280
......@@ -4,8 +4,7 @@ import pytest
import torch
from torch import Tensor
from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
@torch.jit.script
......
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import graclus_cluster
from .utils import dtypes, devices, tensor
from torch_cluster.testing import devices, dtypes, tensor
tests = [{
'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
......@@ -42,6 +41,9 @@ def assert_correct(row, col, cluster):
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_graclus_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
row = tensor(test['row'], torch.long, device)
col = tensor(test['col'], torch.long, device)
weight = tensor(test.get('weight'), dtype, device)
......
from itertools import product
import pytest
import torch
from torch_cluster import grid_cluster
from .utils import dtypes, devices, tensor
from torch_cluster.testing import devices, dtypes, tensor
tests = [{
'pos': [2, 6],
......@@ -28,6 +28,9 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_grid_cluster(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
pos = tensor(test['pos'], dtype, device)
size = tensor(test['size'], dtype, device)
start = tensor(test.get('start'), dtype, device)
......
from itertools import product
import pytest
import torch
import scipy.spatial
import torch
from torch_cluster import knn, knn_graph
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
def to_set(edge_index):
......
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_cluster import nearest
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
......
from itertools import product
import pytest
import torch
import scipy.spatial
import torch
from torch_cluster import radius, radius_graph
from .utils import grad_dtypes, devices, tensor
from torch_cluster.testing import devices, grad_dtypes, tensor
def to_set(edge_index):
......
import pytest
import torch
from torch_cluster import random_walk
from .utils import devices, tensor
from torch_cluster.testing import devices, tensor
@pytest.mark.parametrize('device', devices)
......@@ -41,7 +40,10 @@ def test_rw_large_with_edge_indices(device):
walk_length = 10
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
row,
col,
start,
walk_length,
return_edge_indices=True,
)
assert node_seq[:, 0].tolist() == start.tolist()
......@@ -63,7 +65,10 @@ def test_rw_small_with_edge_indices(device):
walk_length = 4
node_seq, edge_seq = random_walk(
row, col, start, walk_length,
row,
col,
start,
walk_length,
num_nodes=3,
return_edge_indices=True,
)
......
from typing import Any
import torch
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
torch.int, torch.long]
dtypes = [
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment