Unverified Commit 32bee646 authored by Piotr Chmiel's avatar Piotr Chmiel Committed by GitHub
Browse files

Add `batch_size` argument for `fps`, `knn`, `radius` functions (#175)



* Add batch_size argument for fps, knn, radius functions.

It can be used to avoid additional calculations if a user is using
fixed-size batch.

* update

* update

* update

* update

---------
Co-authored-by: default avatarrusty1s <matthias.fey@tu-dortmund.de>
parent 84bbb714
...@@ -31,10 +31,11 @@ jobs: ...@@ -31,10 +31,11 @@ jobs:
- name: Install main package - name: Install main package
run: | run: |
pip install -e .[test] python setup.py develop
- name: Run test-suite - name: Run test-suite
run: | run: |
pip install pytest pytest-cov
pytest --cov --cov-report=xml pytest --cov --cov-report=xml
- name: Upload coverage - name: Upload coverage
......
from typing import Optional from typing import Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover pass # pragma: no cover
@torch.jit._overload # noqa @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover pass # pragma: no cover
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[torch.Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -32,10 +38,11 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa ...@@ -32,10 +38,11 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`) (default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. code-block:: python .. code-block:: python
import torch import torch
...@@ -57,7 +64,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa ...@@ -57,7 +64,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
if batch is not None: if batch is not None:
assert src.size(0) == batch.numel() assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1 if batch_size is None:
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype=torch.long) deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch)) deg.scatter_add_(0, batch, torch.ones_like(batch))
......
...@@ -4,10 +4,16 @@ import torch ...@@ -4,10 +4,16 @@ import torch
@torch.jit.script @torch.jit.script
def knn(x: torch.Tensor, y: torch.Tensor, k: int, def knn(
batch_x: Optional[torch.Tensor] = None, x: torch.Tensor,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False, y: torch.Tensor,
num_workers: int = 1) -> torch.Tensor: k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`. :obj:`x`.
...@@ -31,6 +37,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -31,6 +37,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`) :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -52,13 +60,15 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -52,13 +60,15 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
batch_size = 1 if batch_size is None:
if batch_x is not None: batch_size = 1
assert x.size(0) == batch_x.numel() if batch_x is not None:
batch_size = int(batch_x.max()) + 1 assert x.size(0) == batch_x.numel()
if batch_y is not None: batch_size = int(batch_x.max()) + 1
assert y.size(0) == batch_y.numel() if batch_y is not None:
batch_size = max(batch_size, int(batch_y.max()) + 1) assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None
...@@ -74,9 +84,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, ...@@ -74,9 +84,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
@torch.jit.script @torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, def knn_graph(
loop: bool = False, flow: str = 'source_to_target', x: torch.Tensor,
cosine: bool = False, num_workers: int = 1) -> torch.Tensor: k: int,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
flow: str = 'source_to_target',
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points. r"""Computes graph edges to the nearest :obj:`k` points.
Args: Args:
...@@ -98,6 +115,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, ...@@ -98,6 +115,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`) on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -113,7 +132,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, ...@@ -113,7 +132,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers) num_workers, batch_size)
if flow == 'source_to_target': if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0] row, col = edge_index[1], edge_index[0]
......
...@@ -4,10 +4,16 @@ import torch ...@@ -4,10 +4,16 @@ import torch
@torch.jit.script @torch.jit.script
def radius(x: torch.Tensor, y: torch.Tensor, r: float, def radius(
batch_x: Optional[torch.Tensor] = None, x: torch.Tensor,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32, y: torch.Tensor,
num_workers: int = 1) -> torch.Tensor: r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`. distance :obj:`r`.
...@@ -33,6 +39,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -33,6 +39,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`) :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python .. code-block:: python
...@@ -52,16 +60,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -52,16 +60,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous() x, y = x.contiguous(), y.contiguous()
batch_size = 1 if batch_size is None:
if batch_x is not None: batch_size = 1
assert x.size(0) == batch_x.numel() if batch_x is not None:
batch_size = int(batch_x.max()) + 1 assert x.size(0) == batch_x.numel()
if batch_y is not None: batch_size = int(batch_x.max()) + 1
assert y.size(0) == batch_y.numel() if batch_y is not None:
batch_size = max(batch_size, int(batch_y.max()) + 1) assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0
ptr_x: Optional[torch.Tensor] = None ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None
if batch_size > 1: if batch_size > 1:
assert batch_x is not None assert batch_x is not None
assert batch_y is not None assert batch_y is not None
...@@ -74,10 +85,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -74,10 +85,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
@torch.jit.script @torch.jit.script
def radius_graph(x: torch.Tensor, r: float, def radius_graph(
batch: Optional[torch.Tensor] = None, loop: bool = False, x: torch.Tensor,
max_num_neighbors: int = 32, flow: str = 'source_to_target', r: float,
num_workers: int = 1) -> torch.Tensor: batch: Optional[torch.Tensor] = None,
loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance. r"""Computes graph edges to all points within a given distance.
Args: Args:
...@@ -101,6 +118,8 @@ def radius_graph(x: torch.Tensor, r: float, ...@@ -101,6 +118,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`) on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
...@@ -117,7 +136,7 @@ def radius_graph(x: torch.Tensor, r: float, ...@@ -117,7 +136,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source'] assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch, edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1, max_num_neighbors if loop else max_num_neighbors + 1,
num_workers) num_workers, batch_size)
if flow == 'source_to_target': if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0] row, col = edge_index[1], edge_index[0]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment