radius.py 5.45 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import scipy.spatial
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
@torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
    # Randomly keep at most `count` entries (along dim 0) of `col`.
    # Tensors that already fit are returned unchanged.
    if col.size(0) > count:
        # Index with only the first `count` permutation entries instead of
        # permuting the whole tensor and slicing afterwards (same result).
        col = col[torch.randperm(col.size(0))[:count]]
    return col


def _batch_to_ptr(batch: Optional[torch.Tensor], num_nodes: int,
                  device: torch.device) -> torch.Tensor:
    """Convert a batch vector into CSR-style pointers ``[0, c0, c0+c1, ...]``.

    ``ptr[b]:ptr[b + 1]`` only describes the nodes of example ``b`` if
    ``batch`` is sorted (nodes of one example stored contiguously).
    NOTE(review): sortedness is assumed, not checked.
    Without a batch vector, all ``num_nodes`` nodes form one example.
    """
    if batch is None:
        return torch.tensor([0, num_nodes], device=device)

    assert num_nodes == batch.numel()
    batch_size = int(batch.max()) + 1

    # Number of nodes per example, then prefix sums shifted by one slot.
    deg = torch.zeros(batch_size, dtype=torch.long, device=device)
    deg.scatter_add_(0, batch, torch.ones_like(batch))

    ptr = deg.new_zeros(batch_size + 1)
    deg.cumsum(0, out=ptr[1:])
    return ptr


def radius(x: torch.Tensor, y: torch.Tensor, r: float,
           batch_x: Optional[torch.Tensor] = None,
           batch_y: Optional[torch.Tensor] = None,
           max_num_neighbors: int = 32) -> torch.Tensor:
    r"""Finds for each element in :obj:`y` all points in :obj:`x` within
    distance :obj:`r`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        r (float): The radius.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in :obj:`y`. If more candidates are
            found, a random subset is kept. (default: :obj:`32`)

    Returns:
        LongTensor of shape :obj:`[2, E]`: row 0 holds indices into :obj:`y`,
        row 1 holds the matching indices into :obj:`x`.

    .. code-block:: python

        import torch
        from torch_cluster import radius

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = radius(x, y, 1.5, batch_x, batch_y)
    """

    # Treat 1-D inputs as N points with a single feature each.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

    if x.is_cuda:
        # The compiled kernel takes per-example CSR pointers rather than
        # batch vectors.
        ptr_x = _batch_to_ptr(batch_x, x.size(0), x.device)
        ptr_y = _batch_to_ptr(batch_y, y.size(0), y.device)
        return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                              max_num_neighbors)

    if batch_x is None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_y is None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    assert x.dim() == 2 and batch_x.dim() == 1
    assert y.dim() == 2 and batch_y.dim() == 1
    assert x.size(1) == y.size(1)
    assert x.size(0) == batch_x.size(0)
    assert y.size(0) == batch_y.size(0)

    # No queries: nothing to concatenate below, return an empty edge list.
    if y.size(0) == 0:
        return torch.empty(2, 0, dtype=torch.long)

    # Append the batch index scaled by 2r as an extra coordinate.
    # Points from different examples are then >= 2r apart, which exceeds r
    # (for r > 0). A single KD-tree query therefore never mixes examples.
    x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
    y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)

    tree = scipy.spatial.cKDTree(x.detach().numpy())
    neighbors = tree.query_ball_point(y.detach().numpy(), r)

    # Force long dtype: torch.tensor([]) would otherwise be float32, and an
    # all-empty result would then be returned as a float index tensor.
    col = [sample(torch.tensor(c, dtype=torch.long), max_num_neighbors)
           for c in neighbors]
    row = [torch.full_like(c, i) for i, c in enumerate(col)]
    row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)

    # Defensive: drop any index that does not refer to a point in the tree.
    mask = col < int(tree.n)
    return torch.stack([row[mask], col[mask]], dim=0)


def radius_graph(x: torch.Tensor, r: float,
                 batch: Optional[torch.Tensor] = None, loop: bool = False,
                 max_num_neighbors: int = 32,
                 flow: str = 'source_to_target') -> torch.Tensor:
    r"""Computes graph edges to all points within a given distance.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        r (float): The radius.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each node in :obj:`x`. (default: :obj:`32`)
        flow (str, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    :rtype: :class:`LongTensor` of shape :obj:`[2, E]`

    .. code-block:: python

        import torch
        from torch_cluster import radius_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """

    assert flow in ['source_to_target', 'target_to_source']

    # Every node is at distance 0 from itself, so it is a candidate neighbor
    # of itself. When self-loops are removed below, request one extra
    # neighbor so that slot does not count against max_num_neighbors.
    num_requested = max_num_neighbors if loop else max_num_neighbors + 1
    row, col = radius(x, x, r, batch, batch, num_requested)

    # radius() yields [query index, neighbor index]. For source_to_target
    # the neighbor is the message source, so it goes into row 0.
    if flow == 'source_to_target':
        row, col = col, row

    if not loop:
        mask = row != col
        row, col = row[mask], col[mask]

    return torch.stack([row, col], dim=0)