Commit 53959eee authored by rusty1s
Browse files

clean up

parent 5f1939fd
...@@ -13,12 +13,12 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, ...@@ -13,12 +13,12 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
CHECK_CPU(src); CHECK_CPU(src);
CHECK_CPU(ptr); CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_INPUT(ptr.dim() == 1); CHECK_INPUT(ptr.dim() == 1);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
src = src.view({src.size(0), -1}).contiguous(); src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous(); ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1; auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio; auto out_ptr = deg.toType(torch::kFloat) * ratio;
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "utils.cuh" #include "utils.cuh"
#include <stdio.h>
#define THREADS 256 #define THREADS 256
template <typename scalar_t> template <typename scalar_t>
...@@ -64,19 +64,18 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, ...@@ -64,19 +64,18 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
} }
} }
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
bool random_start) { torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src); CHECK_CUDA(src);
CHECK_CUDA(ptr); CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1); CHECK_INPUT(ptr.dim() == 1);
// AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
cudaSetDevice(src.get_device()); cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous(); src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous(); ptr = ptr.contiguous();
ratio = ratio.contiguous(); auto batch_size = ptr.numel() - 1;
auto batch_size = ptr.size(0) - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio; auto out_ptr = deg.toType(torch::kFloat) * ratio;
......
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
bool random_start); torch::Tensor ratio, bool random_start);
...@@ -98,7 +98,7 @@ setup( ...@@ -98,7 +98,7 @@ setup(
ext_modules=get_extensions() if not BUILD_DOCS else [], ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={ cmdclass={
'build_ext': 'build_ext':
BuildExtension.with_options(no_python_abi_suffix=True) BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
}, },
packages=find_packages(), packages=find_packages(),
) )
...@@ -21,16 +21,30 @@ def test_fps(dtype, device): ...@@ -21,16 +21,30 @@ def test_fps(dtype, device):
], dtype, device) ], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=0.5, random_start=False) out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5), random_start=False) out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5]), random_start=False) out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, ratio=torch.tensor(0.5), random_start=False) out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7] assert out.sort()[0].tolist() == [0, 5, 6, 7]
...@@ -42,5 +56,5 @@ def test_random_fps(device): ...@@ -42,5 +56,5 @@ def test_random_fps(device):
batch_1 = torch.zeros(N, dtype=torch.long, device=device) batch_1 = torch.zeros(N, dtype=torch.long, device=device)
batch_2 = torch.ones(N, dtype=torch.long, device=device) batch_2 = torch.ones(N, dtype=torch.long, device=device)
batch = torch.cat([batch_1, batch_2]) batch = torch.cat([batch_1, batch_2])
idx = fps(pos, batch, ratio=torch.tensor(0.5)) idx = fps(pos, batch, ratio=0.5)
assert idx.min() >= 0 and idx.max() < 2 * N assert idx.min() >= 0 and idx.max() < 2 * N
...@@ -3,19 +3,19 @@ from torch import Tensor ...@@ -3,19 +3,19 @@ from torch import Tensor
import torch import torch
@torch.jit._overload @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): def fps(src, batch, ratio, random_start):
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tensor
pass pass
@torch.jit._overload @torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): def fps(src, batch, ratio, random_start):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor # type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): def fps(src: torch.Tensor, batch=None, ratio=0.5, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -27,12 +27,14 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): ...@@ -27,12 +27,14 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
batch (LongTensor, optional): Batch vector batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`) node to a specific example. (default: :obj:`None`)
ratio (Tensor, optional): Sampling ratio. (default: :obj:`0.5`) ratio (float or Tensor, optional): Sampling ratio.
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. code-block:: python .. code-block:: python
import torch import torch
...@@ -44,10 +46,7 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): ...@@ -44,10 +46,7 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True):
""" """
if not isinstance(ratio, Tensor): if not isinstance(ratio, Tensor):
ratio = torch.tensor(ratio) ratio = torch.tensor(ratio, device=src.device)
assert len(ratio.shape) < 2, f'ratio should be a scalar or a vector, received a tensor rank {len(ratio.shape)}'
ratio = ratio.to(src.device)
if batch is not None: if batch is not None:
assert src.size(0) == batch.numel() assert src.size(0) == batch.numel()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment