radius.py 5.48 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import scipy.spatial
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
@torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
    """Randomly subsamples :obj:`col` to at most :obj:`count` entries.

    Args:
        col (Tensor): Tensor whose first dimension enumerates the candidate
            entries.
        count (int): Maximum number of entries to keep.

    Returns:
        Tensor: :obj:`col` unchanged if its first dimension holds at most
        :obj:`count` entries, otherwise :obj:`count` entries picked through
        a random permutation (i.e. without replacement).
    """
    num_entries = col.size(0)
    if num_entries > count:
        # Create the permutation on the device of `col` and truncate it
        # *before* indexing, so that only `count` entries are gathered
        # instead of permuting the whole tensor and discarding the rest.
        perm = torch.randperm(num_entries, dtype=torch.long,
                              device=col.device)
        col = col[perm[:count]]
    return col


rusty1s's avatar
rusty1s committed
14
15
16
17
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
           batch_x: Optional[torch.Tensor] = None,
           batch_y: Optional[torch.Tensor] = None,
           max_num_neighbors: int = 32) -> torch.Tensor:
    r"""Finds for each element in :obj:`y` all points in :obj:`x` within
    distance :obj:`r`.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        r (float): The radius.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in :obj:`y`. (default: :obj:`32`)

    Returns:
        On the CPU path, a :class:`LongTensor` of shape :obj:`[2, E]` whose
        first row indexes elements of :obj:`y` and whose second row indexes
        the matching points of :obj:`x`. On the CUDA path the result of
        :obj:`torch.ops.torch_cluster.radius` is returned as is (presumably
        the same layout -- confirm against the extension).

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import radius

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = radius(x, y, 1.5, batch_x, batch_y)
    """

    # Treat 1-D inputs as a set of points with a single feature.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

    if x.is_cuda:
        # The CUDA operator receives per-example offsets rather than batch
        # vectors.
        if batch_x is not None:
            ptr_x = _batch_to_ptr(x, batch_x)
        else:
            ptr_x = torch.tensor([0, x.size(0)], device=x.device)

        if batch_y is not None:
            ptr_y = _batch_to_ptr(y, batch_y)
        else:
            ptr_y = torch.tensor([0, y.size(0)], device=y.device)

        return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                              max_num_neighbors)
    else:
        if batch_x is None:
            batch_x = x.new_zeros(x.size(0), dtype=torch.long)

        if batch_y is None:
            batch_y = y.new_zeros(y.size(0), dtype=torch.long)

        assert x.dim() == 2 and batch_x.dim() == 1
        assert y.dim() == 2 and batch_y.dim() == 1
        assert x.size(1) == y.size(1)
        assert x.size(0) == batch_x.size(0)
        assert y.size(0) == batch_y.size(0)

        # Nothing can match if either point set is empty. Returning here
        # also avoids calling `torch.cat` on an empty list further below.
        if x.size(0) == 0 or y.size(0) == 0:
            return torch.empty((2, 0), dtype=torch.long)

        # Append the batch index, scaled by `2 * r`, as an extra coordinate:
        # points of different examples then lie at least `2 * r` apart and
        # can never fall within distance `r` of each other.
        x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
        y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)

        tree = scipy.spatial.cKDTree(x.detach().numpy())
        neighbors = tree.query_ball_point(y.detach().numpy(), r)
        # Force an integer dtype: `torch.tensor([])` would otherwise produce
        # a float tensor for query points without any neighbor and the
        # concatenation below would promote every index to float.
        col = [
            sample(torch.tensor(c, dtype=torch.long), max_num_neighbors)
            for c in neighbors
        ]
        row = [torch.full_like(c, i) for i, c in enumerate(col)]
        row, col = torch.cat(row, dim=0), torch.cat(col, dim=0)
        mask = col < int(tree.n)
        return torch.stack([row[mask], col[mask]], dim=0)


def _batch_to_ptr(points: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Turns a batch vector into cumulative per-example offsets.

    Entry :obj:`i` of the result is the number of points whose batch index
    is smaller than :obj:`i`, so the result has :obj:`batch.max() + 2`
    entries and starts with :obj:`0`.
    NOTE(review): these offsets only describe example boundaries if points
    of the same example are stored contiguously in ascending batch order --
    confirm against the CUDA operator's expectations.
    """
    assert points.size(0) == batch.numel()
    batch_size = int(batch.max()) + 1

    # Number of points per example.
    deg = points.new_zeros(batch_size, dtype=torch.long)
    deg.scatter_add_(0, batch, torch.ones_like(batch))

    # Prefix sum written behind the leading zero.
    ptr = deg.new_zeros(batch_size + 1)
    torch.cumsum(deg, 0, out=ptr[1:])
    return ptr


def radius_graph(x: torch.Tensor, r: float,
                 batch: Optional[torch.Tensor] = None, loop: bool = False,
                 max_num_neighbors: int = 32,
                 flow: str = 'source_to_target') -> torch.Tensor:
    r"""Builds the edges that connect every point to all points within a
    given distance.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        r (float): The radius.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in :obj:`x`. (default: :obj:`32`)
        flow (string, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import radius_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """

    assert flow in ['source_to_target', 'target_to_source']

    # `x` is searched against itself, so when self-loops are going to be
    # removed one additional neighbor is requested per point.
    if loop:
        neighbor_budget = max_num_neighbors
    else:
        neighbor_budget = max_num_neighbors + 1

    query_idx, match_idx = radius(x, x, r, batch, batch, neighbor_budget)

    # The flow direction decides which of the two index rows comes first.
    if flow == 'source_to_target':
        row, col = match_idx, query_idx
    else:
        row, col = query_idx, match_idx

    if not loop:
        # Keep only the edges that connect two distinct points.
        keep = row != col
        row = row[keep]
        col = col[keep]

    return torch.stack([row, col], dim=0)