"docker/Dockerfile" did not exist on "1045640da499b92ddc4c0f75c7c1ca5fba8929ce"
nearest.py 3.32 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import scipy.cluster
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
8
9
def nearest(x: torch.Tensor, y: torch.Tensor,
            batch_x: Optional[torch.Tensor] = None,
            batch_y: Optional[torch.Tensor] = None) -> torch.Tensor:
rusty1s's avatar
typo  
rusty1s committed
10
    r"""Clusters points in :obj:`x` together which are nearest to a given query
rusty1s's avatar
rusty1s committed
11
    point in :obj:`y`.
rusty1s's avatar
docs  
rusty1s committed
12
13

    Args:
rusty1s's avatar
rusty1s committed
14
15
16
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
Vadim Bereznyuk's avatar
typos  
Vadim Bereznyuk committed
17
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
rusty1s's avatar
rusty1s committed
18
19
20
21
22
23
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. (default: :obj:`None`)
rusty1s's avatar
docs  
rusty1s committed
24

rusty1s's avatar
rusty1s committed
25
26
27
    :rtype: :class:`LongTensor`

    .. code-block:: python
rusty1s's avatar
rusty1s committed
28
29
30
31

        import torch
        from torch_cluster import nearest

rusty1s's avatar
rusty1s committed
32
33
34
35
36
        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        cluster = nearest(x, y, batch_x, batch_y)
rusty1s's avatar
docs  
rusty1s committed
37
38
    """

rusty1s's avatar
rusty1s committed
39
40
41
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y

rusty1s's avatar
rusty1s committed
42
    if x.is_cuda:
rusty1s's avatar
rusty1s committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
        if batch_x is not None:
            assert x.size(0) == batch_x.numel()
            batch_size = int(batch_x.max()) + 1

            deg = x.new_zeros(batch_size, dtype=torch.long)
            deg.scatter_add_(0, batch_x, torch.ones_like(batch_x))

            ptr_x = deg.new_zeros(batch_size + 1)
            deg.cumsum(0, out=ptr_x[1:])
        else:
            ptr_x = torch.tensor([0, x.size(0)], device=x.device)

        if batch_y is not None:
            assert y.size(0) == batch_y.numel()
            batch_size = int(batch_y.may()) + 1

            deg = y.new_zeros(batch_size, dtype=torch.long)
            deg.scatter_add_(0, batch_y, torch.ones_like(batch_y))

            ptr_y = deg.new_zeros(batch_size + 1)
            deg.cumsum(0, out=ptr_y[1:])
        else:
            ptr_y = torch.tensor([0, y.size(0)], device=y.device)

        return torch.ops.torch_cluster.nearest(x, y, ptr_x, ptr_y)
    else:
        if batch_x is None:
            batch_x = x.new_zeros(x.size(0), dtype=torch.long)

        if batch_y is None:
            batch_y = y.new_zeros(y.size(0), dtype=torch.long)

        assert x.dim() == 2 and batch_x.dim() == 1
        assert y.dim() == 2 and batch_y.dim() == 1
        assert x.size(1) == y.size(1)
        assert x.size(0) == batch_x.size(0)
        assert y.size(0) == batch_y.size(0)

        # Translate and rescale x and y to [0, 1].
        min_xy = min(x.min().item(), y.min().item())
        x, y = x - min_xy, y - min_xy

        max_xy = max(x.max().item(), y.max().item())
        x.div_(max_xy)
        y.div_(max_xy)

        # Concat batch/features to ensure no cross-links between examples.
        x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1)
        y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1)

        return torch.from_numpy(
            scipy.cluster.vq.vq(x.detach().cpu(),
                                y.detach().cpu())[0]).to(torch.long)