radius.py 3.18 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
import torch

if torch.cuda.is_available():
    import radius_cuda


def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
    """Finds for each element in `y` all points in `x` within distance `r`.

    Only a CUDA implementation is dispatched to, so `x` must be a CUDA
    tensor.

    Args:
        x (Tensor): D-dimensional point features.
        y (Tensor): D-dimensional point features.
        r (float): The radius.
        batch_x (LongTensor, optional): Vector that maps each point to its
            example identifier. If :obj:`None`, all points belong to the same
            example. If not :obj:`None`, points in the same example need to
            have contiguous memory layout and :obj:`batch_x` needs to be
            ascending. (default: :obj:`None`)
        batch_y (LongTensor, optional): See `batch_x` (default: :obj:`None`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in `y`. (default: :obj:`32`)

    :rtype: :class:`LongTensor`

    Raises:
        AssertionError: If `x` is not a CUDA tensor, or if the shapes of
            `x`, `y`, `batch_x` and `batch_y` are inconsistent.

    Examples::

        >>> device = torch.device('cuda')
        >>> x = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]],
        ...                  device=device)
        >>> batch_x = torch.tensor([0, 0, 0, 0], device=device)
        >>> y = torch.tensor([[-1., 0.], [1., 0.]], device=device)
        >>> batch_y = torch.tensor([0, 0], device=device)
        >>> assign_index = radius(x, y, 1.5, batch_x, batch_y)
    """

    # Without explicit batch vectors every point belongs to example 0.
    if batch_x is None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_y is None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    # Treat 1-dimensional inputs as N points with a single feature each.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

    assert x.is_cuda, 'radius is only implemented for CUDA tensors'
    assert x.dim() == 2 and batch_x.dim() == 1, \
        'x must be 2-dimensional and batch_x 1-dimensional'
    assert y.dim() == 2 and batch_y.dim() == 1, \
        'y must be 2-dimensional and batch_y 1-dimensional'
    assert x.size(1) == y.size(1), \
        'x and y must have the same number of features'
    assert x.size(0) == batch_x.size(0), \
        'batch_x must hold one entry per point in x'
    assert y.size(0) == batch_y.size(0), \
        'batch_y must hold one entry per point in y'

    # x is known to be a CUDA tensor here, so dispatch directly to the
    # CUDA extension.
    return radius_cuda.radius(x, y, r, batch_x, batch_y, max_num_neighbors)


rusty1s's avatar
rusty1s committed
56
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
    """Finds for each element in `x` all points in `x` within distance `r`.

    The search is delegated to :func:`radius` with `x` used as both point
    sets, so the same input requirements apply (in particular, `x` must be
    a CUDA tensor).

    Args:
        x (Tensor): D-dimensional point features.
        r (float): The radius.
        batch (LongTensor, optional): Vector that maps each point to its
            example identifier. If :obj:`None`, all points belong to the same
            example. If not :obj:`None`, points in the same example need to
            have contiguous memory layout and :obj:`batch` needs to be
            ascending. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in `x`. (default: :obj:`32`)

    :rtype: :class:`LongTensor`

    Examples::

        >>> device = torch.device('cuda')
        >>> x = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]],
        ...                  device=device)
        >>> batch = torch.tensor([0, 0, 0, 0], device=device)
        >>> edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """

    # Every point matches itself. When self-loops are dropped below, one
    # extra neighbor is requested so that removing the self-match can still
    # leave up to `max_num_neighbors` entries. With `loop=True` nothing is
    # removed, so the limit is passed through unchanged.
    num_neighbors = max_num_neighbors if loop else max_num_neighbors + 1
    edge_index = radius(x, x, r, batch, batch, num_neighbors)

    if not loop:
        row, col = edge_index
        mask = row != col  # keep only pairs of distinct points
        edge_index = torch.stack([row[mask], col[mask]], dim=0)
    return edge_index