from typing import Any, Tuple

import torch
from torch.autograd import Function

from ..utils import ext_loader

# Handle to the compiled ``_ext`` extension; the only op requested from it
# is ``three_nn_forward``, which ``ThreeNN.forward`` calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['three_nn_forward'],
)


class ThreeNN(Function):
    """Find the top-3 nearest neighbors of the target set from the source set.

    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
    for more details.
    """

    @staticmethod
    def forward(ctx: Any, target: torch.Tensor,
                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            target (torch.Tensor): shape (B, N, 3), points set that needs to
                find the nearest neighbors.
            source (torch.Tensor): shape (B, M, 3), points set that is used
                to find the nearest neighbors of points in target set.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - dist (torch.Tensor): shape (B, N, 3), float32 L2 distance of
              each point in target set to their corresponding top three
              nearest neighbors.
            - idx (torch.Tensor): shape (B, N, 3), int32 indices of those
              neighbors as filled in by the extension kernel. On PyTorch
              (i.e. not parrots) it is marked non-differentiable.

        Raises:
            ValueError: if ``target`` or ``source`` is not a 3-D tensor whose
                last dimension is 3, or if their batch sizes differ.
        """
        # The extension kernel only receives (b, n, m), so it cannot detect
        # mismatched shapes itself; reject them here instead of letting it
        # read past the end of ``source``.
        if target.dim() != 3 or target.size(2) != 3:
            raise ValueError('target must have shape (B, N, 3), got {}'.format(
                tuple(target.shape)))
        if source.dim() != 3 or source.size(2) != 3:
            raise ValueError('source must have shape (B, M, 3), got {}'.format(
                tuple(source.shape)))
        if source.size(0) != target.size(0):
            raise ValueError(
                'target and source must have the same batch size, '
                'got {} and {}'.format(target.size(0), source.size(0)))

        # Presumably the kernel assumes densely packed inputs, since it is
        # given no strides.
        target = target.contiguous()
        source = source.contiguous()

        B, N, _ = target.size()
        m = source.size(1)

        # Allocate the outputs directly on the target's device. The old
        # ``torch.FloatTensor(...).to(device)`` form built a CPU tensor first
        # and then copied it. ``dist2`` receives squared distances (hence the
        # sqrt on return).
        dist2 = torch.empty((B, N, 3),
                            dtype=torch.float32,
                            device=target.device)
        idx = torch.empty((B, N, 3), dtype=torch.int32, device=target.device)

        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
        if torch.__version__ != 'parrots':
            ctx.mark_non_differentiable(idx)

        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx: Any, a=None, b=None) -> Tuple[None, None]:
        """Propagate no gradient to either ``target`` or ``source``."""
        return None, None


three_nn = ThreeNN.apply