# Tests for the torch_cluster.sample helpers imported below.
import pytest
import torch
import numpy as np

from torch_cluster.sample import (sample_farthest, batch_slices,
                                  radius_query_edges)

from .utils import tensor, grad_dtypes, devices

@pytest.mark.parametrize('device', devices)
def test_batch_slices(device):
    """Check ``batch_slices`` on a batch vector made of three groups.

    The batch vector assigns 100 elements to group 0, 50 to group 1 and 42
    to group 2 (192 elements in total). The test runs once for every entry
    in ``devices``.

    Args:
        device: Device on which the batch tensor is created.
    """
    # test simple case for correctness
    batch = tensor(
        [0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)

    slices, sizes = batch_slices(batch, sizes=True)
    # Move the results to the host so they compare against plain lists.
    slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()

    # Expected slices are the first and last index of each group, flattened:
    # group 0 -> 0..99, group 1 -> 100..149, group 2 -> 150..191.
    assert slices == [0, 99, 100, 149, 150, 191]
    assert sizes == [100, 50, 42]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_fps(dtype):
    """Exercise ``sample_farthest`` on CUDA for every dtype in ``grad_dtypes``.

    Two scenarios are covered:

    1. A single batch of ten planar points (the four corners of a square
       plus six random points, all with ``z == 0``). The sampler is called
       with ``index=True``; its result is currently not verified (see the
       TODO below).
    2. Two groups of 100 and 50 random 3-D points sampled with
       ``num_sampled=75`` and ``index=False``. The returned mask is
       expected to select 75 points of the larger group and all 50 points
       of the smaller one.

    Args:
        dtype: Floating point dtype used for the point coordinates.
    """
    # test simple case for correctness
    batch = [0] * 10
    points = [[-1, -1, 0], [-1, 1, 0], [1, 1, 0], [1, -1, 0]]
    random_points = np.random.uniform(-1, 1, size=(6, 3))
    random_points[:, 2] = 0  # keep the random points in the z == 0 plane
    random_points = random_points.tolist()

    batch = tensor(batch, dtype=torch.long, device='cuda')
    pos = tensor(points + random_points, dtype=dtype, device='cuda')

    # Only checks that the call runs; the returned indices are discarded.
    sample_farthest(batch, pos, num_sampled=4, index=True)

    # TODO: re-enable the check below. It needs an update since `isin` is
    # missing (sort indices, then compare?).
    # assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'),
    # False).all().cpu().item() == 1

    # test variable number of points for each element in a batch
    batch = [0] * 100 + [1] * 50
    points1 = np.random.uniform(-1, 1, size=(100, 3)).tolist()
    points2 = np.random.uniform(-1, 1, size=(50, 3)).tolist()

    batch = tensor(batch, dtype=torch.long, device='cuda')
    pos = tensor(points1 + points2, dtype=dtype, device='cuda')

    mask = sample_farthest(batch, pos, num_sampled=75, index=False)
    # Group 0 has 100 points, so exactly 75 are expected to be selected.
    assert mask[batch == 0].sum().item() == 75
    # Group 1 has only 50 points, fewer than requested, so all are expected.
    assert mask[batch == 1].sum().item() == 50


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype', grad_dtypes)
def test_radius_edges(dtype):
    """Check that ``radius_query_edges`` only returns pairs within the radius.

    192 random support points (groups of 100, 50 and 42) and 45 random query
    points (groups of 10, 15 and 20) are created on CUDA. For every returned
    edge the Euclidean distance between the query point and the support
    point must not exceed ``radius``.

    Args:
        dtype: Floating point dtype used for the point coordinates.
    """
    batch = [0] * 100 + [1] * 50 + [2] * 42
    points = np.random.uniform(-1, 1, size=(192, 3))

    query_batch = [0] * 10 + [1] * 15 + [2] * 20
    query_points = np.random.uniform(-1, 1, size=(45, 3))

    radius = 0.5

    batch = tensor(batch, dtype=torch.long, device='cuda')
    query_batch = tensor(query_batch, dtype=torch.long, device='cuda')
    pos = tensor(points, dtype=dtype, device='cuda')
    query_pos = tensor(query_points, dtype=dtype, device='cuda')

    edge_index = radius_query_edges(
        batch,
        pos,
        query_batch,
        query_pos,
        radius=radius,
        max_num_neighbors=128)

    # `row` indexes into the query points, `col` into the support points.
    row, col = edge_index
    dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
    assert (dist <= radius).all().item()