Unverified Commit 32bee646 authored by Piotr Chmiel's avatar Piotr Chmiel Committed by GitHub
Browse files

Add `batch_size` argument for `fps`, `knn`, `radius` functions (#175)



* Add batch_size argument for fps, knn, radius functions.

It can be used to avoid additional calculations if a user is using
fixed-size batch.

* update

* update

* update

* update

---------
Co-authored-by: default avatarrusty1s <matthias.fey@tu-dortmund.de>
parent 84bbb714
......@@ -31,10 +31,11 @@ jobs:
- name: Install main package
run: |
pip install -e .[test]
python setup.py develop
- name: Run test-suite
run: |
pip install pytest pytest-cov
pytest --cov --cov-report=xml
- name: Upload coverage
......
from typing import Optional
from typing import Optional, Union
import torch
from torch import Tensor
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[torch.Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -32,10 +38,11 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
......@@ -57,6 +64,7 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
if batch is not None:
assert src.size(0) == batch.numel()
if batch_size is None:
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype=torch.long)
......
......@@ -4,10 +4,16 @@ import torch
@torch.jit.script
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
def knn(
x: torch.Tensor,
y: torch.Tensor,
k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
num_workers: int = 1) -> torch.Tensor:
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
......@@ -31,6 +37,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -52,6 +60,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
......@@ -59,6 +68,7 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
......@@ -74,9 +84,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
@torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
def knn_graph(
x: torch.Tensor,
k: int,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
flow: str = 'source_to_target',
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
......@@ -98,6 +115,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -113,7 +132,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
......
......@@ -4,10 +4,16 @@ import torch
@torch.jit.script
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
def radius(
x: torch.Tensor,
y: torch.Tensor,
r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32,
num_workers: int = 1) -> torch.Tensor:
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
......@@ -33,6 +39,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python
......@@ -52,6 +60,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
......@@ -59,9 +68,11 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
......@@ -74,10 +85,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
@torch.jit.script
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, flow: str = 'source_to_target',
num_workers: int = 1) -> torch.Tensor:
def radius_graph(
x: torch.Tensor,
r: float,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance.
Args:
......@@ -101,6 +118,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
......@@ -117,7 +136,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment