"official/vision/beta/evaluation/coco_utils.py" did not exist on "a5b38a7209ea962b03510c6f2540dcc6601a51e0"
fps.py 3.97 KB
Newer Older
1
from typing import List, Optional, Union
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
update  
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch import Tensor
rusty1s's avatar
rusty1s committed
5

6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch_cluster.typing


# TorchScript overload declarations for ``fps``: ``ratio`` may be a float or a
# Tensor, and ``ptr`` may be a Tensor or a list of ints. The ``# type:``
# comments carry the overload signatures (TorchScript type-comment syntax);
# the stub bodies only declare signatures, and the undecorated ``fps`` defined
# after them is the actual implementation.
@torch.jit._overload  # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor  # noqa
    pass  # pragma: no cover


@torch.jit._overload  # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor  # noqa
    pass  # pragma: no cover


@torch.jit._overload  # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor  # noqa
    pass  # pragma: no cover


@torch.jit._overload  # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr):  # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor  # noqa
    pass  # pragma: no cover


def fps(  # noqa
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
    most distant point with regard to the rest of the points.

    Args:
        src (Tensor): Point feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Ignored if :obj:`ptr` is given.
            (default: :obj:`None`)
        ratio (float or Tensor, optional): Sampling ratio.
            (default: :obj:`0.5`)
        random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node. (default: :obj:`True`)
        batch_size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. Only used together with
            :obj:`batch`. (default: :obj:`None`)
        ptr (torch.Tensor or [int], optional): If given, batch assignment will
            be determined based on boundaries in CSR representation, *e.g.*,
            :obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
            Takes precedence over :obj:`batch`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import fps

        src = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        index = fps(src, batch, ratio=0.5)
    """
    # Normalize ``ratio`` to a tensor on ``src``'s device and dtype, falling
    # back to 0.5 when it is not given. A Tensor ratio is passed through as-is.
    r: Optional[Tensor] = None
    if ratio is None:
        r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
    elif isinstance(ratio, float):
        r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
    else:
        r = ratio
    # Presumably here to refine ``r`` from Optional[Tensor] to Tensor for
    # TorchScript; every branch above assigns it.
    assert r is not None

    # An explicit CSR ``ptr`` takes precedence over ``batch``.
    if ptr is not None:
        # When ``WITH_PTR_LIST`` is set (presumably: the installed extension
        # exposes ``fps_ptr_list``), pass the list straight through; otherwise
        # convert it to a tensor for the generic kernel.
        if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
            return torch.ops.torch_cluster.fps_ptr_list(
                src, ptr, r, random_start)

        if isinstance(ptr, list):
            return torch.ops.torch_cluster.fps(
                src, torch.tensor(ptr, device=src.device), r, random_start)
        else:
            return torch.ops.torch_cluster.fps(src, ptr, r, random_start)

    if batch is not None:
        assert src.size(0) == batch.numel()
        if batch_size is None:
            batch_size = int(batch.max()) + 1

        # Count points per example, then prefix-sum the counts into CSR
        # boundaries with ``ptr_vec[0] == 0``.
        # NOTE(review): these boundaries only line up with the rows of ``src``
        # if ``batch`` is sorted (examples stored contiguously) — confirm that
        # callers guarantee this.
        deg = src.new_zeros(batch_size, dtype=torch.long)
        deg.scatter_add_(0, batch, torch.ones_like(batch))

        ptr_vec = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_vec[1:])
    else:
        # No batch information: treat all rows of ``src`` as a single example.
        ptr_vec = torch.tensor([0, src.size(0)], device=src.device)

    return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)