knn.py 5.76 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import scipy.spatial
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
10
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
        batch_x: Optional[torch.Tensor] = None,
        batch_y: Optional[torch.Tensor] = None,
        cosine: bool = False) -> torch.Tensor:
rusty1s's avatar
rusty1s committed
11
12
    r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
    :obj:`x`.
rusty1s's avatar
rusty1s committed
13
14

    Args:
rusty1s's avatar
rusty1s committed
15
16
17
18
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{M \times F}`.
rusty1s's avatar
rusty1s committed
19
        k (int): The number of neighbors.
rusty1s's avatar
rusty1s committed
20
21
22
23
24
25
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
rusty1s's avatar
rusty1s committed
26
27
28
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)
rusty1s's avatar
rusty1s committed
29
30
31

    :rtype: :class:`LongTensor`

rusty1s's avatar
rusty1s committed
32
    .. code-block:: python
rusty1s's avatar
rusty1s committed
33
34
35
36

        import torch
        from torch_cluster import knn

rusty1s's avatar
rusty1s committed
37
38
39
40
41
        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_x = torch.tensor([0, 0])
        assign_index = knn(x, y, 2, batch_x, batch_y)
rusty1s's avatar
rusty1s committed
42
43
44
45
46
    """

    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

rusty1s's avatar
rusty1s committed
47
    if x.is_cuda:
rusty1s's avatar
rusty1s committed
48
49
50
51
52
53
54
55
        if batch_x is not None:
            assert x.size(0) == batch_x.numel()
            batch_size = int(batch_x.max()) + 1

            deg = x.new_zeros(batch_size, dtype=torch.long)
            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))

            ptr_x = deg.new_zeros(batch_size + 1)
rusty1s's avatar
fix  
rusty1s committed
56
            torch.cumsum(deg, 0, out=ptr_x[1:])
rusty1s's avatar
rusty1s committed
57
58
59
60
61
        else:
            ptr_x = torch.tensor([0, x.size(0)], device=x.device)

        if batch_y is not None:
            assert y.size(0) == batch_y.numel()
rusty1s's avatar
fix  
rusty1s committed
62
            batch_size = int(batch_y.max()) + 1
rusty1s's avatar
rusty1s committed
63
64
65
66
67

            deg = y.new_zeros(batch_size, dtype=torch.long)
            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))

            ptr_y = deg.new_zeros(batch_size + 1)
rusty1s's avatar
fix  
rusty1s committed
68
            torch.cumsum(deg, 0, out=ptr_y[1:])
rusty1s's avatar
rusty1s committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
        else:
            ptr_y = torch.tensor([0, y.size(0)], device=y.device)

        return torch.ops.torch_cluster.knn(x, y, ptr_x, ptr_y, k, cosine)
    else:
        if batch_x is None:
            batch_x = x.new_zeros(x.size(0), dtype=torch.long)

        if batch_y is None:
            batch_y = y.new_zeros(y.size(0), dtype=torch.long)

        assert x.dim() == 2 and batch_x.dim() == 1
        assert y.dim() == 2 and batch_y.dim() == 1
        assert x.size(1) == y.size(1)
        assert x.size(0) == batch_x.size(0)
        assert y.size(0) == batch_y.size(0)

        if cosine:
            raise NotImplementedError('`cosine` argument not supported on CPU')

        # Translate and rescale x and y to [0, 1].
        min_xy = min(x.min().item(), y.min().item())
        x, y = x - min_xy, y - min_xy

        max_xy = max(x.max().item(), y.max().item())
        x.div_(max_xy)
        y.div_(max_xy)

        # Concat batch/features to ensure no cross-links between examples.
        x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
        y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)

        tree = scipy.spatial.cKDTree(x.detach().numpy())
        dist, col = tree.query(y.detach().cpu(), k=k,
                               distance_upper_bound=x.size(1))
        dist = torch.from_numpy(dist).to(x.dtype)
        col = torch.from_numpy(col).to(torch.long)
        row = torch.arange(col.size(0), dtype=torch.long)
        row = row.view(-1, 1).repeat(1, k)
        mask = ~torch.isinf(dist).view(-1)
        row, col = row.view(-1)[mask], col.view(-1)[mask]

        return torch.stack([row, col], dim=0)


def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
              loop: bool = False, flow: str = 'source_to_target',
              cosine: bool = False) -> torch.Tensor:
    r"""Builds the edges that link every point to its :obj:`k` nearest
    points.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        k (int): The number of neighbors.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        flow (string, optional): The flow direction when using in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
        cosine (boolean, optional): If :obj:`True`, will use the cosine
            distance instead of euclidean distance to find nearest neighbors.
            (default: :obj:`False`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import knn_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = knn_graph(x, k=2, batch=batch, loop=False)
    """

    assert flow in ['source_to_target', 'target_to_source']

    # Without self-loops, one extra neighbor is requested, and pairs that
    # match a point with itself are filtered out below.
    num_neighbors = k if loop else k + 1
    query_idx, neighbor_idx = knn(x, x, num_neighbors, batch, batch,
                                  cosine=cosine)

    # The flow direction decides which side of each pair goes first.
    if flow == 'source_to_target':
        first, second = neighbor_idx, query_idx
    else:
        first, second = query_idx, neighbor_idx

    if loop:
        return torch.stack([first, second], dim=0)

    keep = first != second
    return torch.stack([first[keep], second[keep]], dim=0)