nearest.py 4.64 KB
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
from typing import Optional

import scipy.cluster
limm's avatar
limm committed
4
import torch
quyuanhao123's avatar
quyuanhao123 committed
5
6


limm's avatar
limm committed
7
8
9
10
11
12
def _batch_to_ptr(src: torch.Tensor,
                  batch: Optional[torch.Tensor]) -> torch.Tensor:
    """Convert a sorted batch vector into a CSR-style pointer vector.

    Entry ``i`` of the result is the index of the first row of :obj:`src`
    that belongs to example ``i``; the last entry equals ``src.size(0)``.
    Without a batch vector all rows are treated as a single example.
    """
    if batch is None:
        return torch.tensor([0, src.size(0)], device=src.device)

    assert src.size(0) == batch.numel()
    batch_size = int(batch.max()) + 1

    # Count the rows per example, then prefix-sum the counts into offsets.
    deg = src.new_zeros(batch_size, dtype=torch.long)
    deg.scatter_add_(0, batch, torch.ones_like(batch))

    ptr = deg.new_zeros(batch_size + 1)
    torch.cumsum(deg, 0, out=ptr[1:])
    return ptr


def nearest(
    x: torch.Tensor,
    y: torch.Tensor,
    batch_x: Optional[torch.Tensor] = None,
    batch_y: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""Clusters points in :obj:`x` together which are nearest to a given query
    point in :obj:`y`.

    For every row of :obj:`x`, the index of the closest row of :obj:`y` is
    returned. When batch vectors are given, a point in :obj:`x` is only
    matched against points in :obj:`y` of the same example.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. :obj:`batch_x` needs to be sorted.
            (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. :obj:`batch_y` needs to be sorted.
            (default: :obj:`None`)

    Returns:
        A :class:`LongTensor` with one entry per row of :obj:`x`, holding the
        index into :obj:`y` of its nearest point.

    Raises:
        ValueError: If :obj:`batch_x` or :obj:`batch_y` is not sorted, or if
            the two batch vectors do not cover the same examples.
        AssertionError: If the feature dimensions of :obj:`x` and :obj:`y`
            differ, or a batch vector does not match its feature matrix.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import nearest

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        cluster = nearest(x, y, batch_x, batch_y)
    """

    # Treat 1-D inputs as a column of single-feature points.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    assert x.size(1) == y.size(1)

    if batch_x is not None and (batch_x[1:] - batch_x[:-1] < 0).any():
        raise ValueError("'batch_x' is not sorted")
    if batch_y is not None and (batch_y[1:] - batch_y[:-1] < 0).any():
        raise ValueError("'batch_y' is not sorted")

    if x.is_cuda:
        ptr_x = _batch_to_ptr(x, batch_x)
        ptr_y = _batch_to_ptr(y, batch_y)

        # If an instance in `batch_x` is non-empty, it must be non-empty in
        # `batch_y` as well:
        nonempty_ptr_x = (ptr_x[1:] - ptr_x[:-1]) > 0
        nonempty_ptr_y = (ptr_y[1:] - ptr_y[:-1]) > 0
        if not torch.equal(nonempty_ptr_x, nonempty_ptr_y):
            raise ValueError("Some batch indices occur in 'batch_x' "
                             "that do not occur in 'batch_y'")

        return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)

    # Non-CUDA path: if only one batch vector is given, the other side is
    # treated as belonging entirely to example 0.
    if batch_x is None and batch_y is not None:
        batch_x = x.new_zeros(x.size(0), dtype=torch.long)
    if batch_y is None and batch_x is not None:
        batch_y = y.new_zeros(y.size(0), dtype=torch.long)

    if batch_x is not None and batch_y is not None:
        # If an instance in `batch_x` is non-empty, it must be non-empty in
        # `batch_y` as well:
        unique_batch_x = batch_x.unique_consecutive()
        unique_batch_y = batch_y.unique_consecutive()
        if not torch.equal(unique_batch_x, unique_batch_y):
            raise ValueError("Some batch indices occur in 'batch_x' "
                             "that do not occur in 'batch_y'")

        assert x.dim() == 2 and batch_x.dim() == 1
        assert y.dim() == 2 and batch_y.dim() == 1
        assert x.size(0) == batch_x.size(0)
        assert y.size(0) == batch_y.size(0)

        # Translate and rescale x and y to [0, 1].
        min_xy = min(x.min().item(), y.min().item())
        x, y = x - min_xy, y - min_xy

        max_xy = max(x.max().item(), y.max().item())
        # When every coordinate is identical the range is zero; dividing
        # would produce NaNs, and the points already lie in [0, 1].
        # Out-of-place true division also keeps integer inputs working.
        if max_xy > 0:
            x, y = x / max_xy, y / max_xy

        # Concat batch/features to ensure no cross-links between examples.
        D = x.size(-1)
        x = torch.cat([x, 2 * D * batch_x.view(-1, 1).to(x.dtype)], -1)
        y = torch.cat([y, 2 * D * batch_y.view(-1, 1).to(y.dtype)], -1)

    # `vq` returns (codes, distances); only the code-book indices are needed.
    codes = scipy.cluster.vq.vq(x.detach().cpu().numpy(),
                                y.detach().cpu().numpy())[0]
    return torch.from_numpy(codes).to(torch.long)