Commit 82968b99 authored by Duc Nguyen

accepted different ratios for different point clouds

parent 2bf5e763
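
In short, the scalar ratio argument of fps becomes a tensor, so one call can sample every point cloud in a batch at its own rate. A minimal usage sketch against the new signature (the input sizes and per-cloud rates below are illustrative, not taken from this commit):

    import torch
    from torch_cluster import fps

    # two point clouds of 4 points each, identified by a batch vector
    x = torch.rand(8, 3)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    # previous behaviour, now expressed as a 0-d tensor: one rate for all clouds
    idx_same = fps(x, batch, ratio=torch.tensor(0.5), random_start=False)

    # new behaviour: one rate per cloud -> ceil(4 * 0.5) = 2 samples from cloud 0,
    # ceil(4 * 0.25) = 1 sample from cloud 1
    idx_mixed = fps(x, batch, ratio=torch.tensor([0.5, 0.25]), random_start=False)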
@@ -6,20 +6,20 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
   return (x - x[idx]).norm(2, 1);
 }
 
-torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
+torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                       bool random_start) {
   CHECK_CPU(src);
   CHECK_CPU(ptr);
   CHECK_INPUT(ptr.dim() == 1);
-  AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
+  // AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
 
   src = src.view({src.size(0), -1}).contiguous();
   ptr = ptr.contiguous();
 
   auto batch_size = ptr.size(0) - 1;
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
-  auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
+  auto out_ptr = deg.toType(torch::kFloat) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
 
   auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());
...
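
Both the CPU path above and the CUDA path further down derive the per-cloud sample counts from the same three lines: deg (points per cloud) is multiplied by ratio, which now broadcasts whether it is a 0-d tensor or holds one entry per cloud, then ceiled and cumsum'd into out_ptr. A small Python sketch of that arithmetic with illustrative values:

    import torch

    ptr = torch.tensor([0, 4, 8])      # boundaries of two clouds, 4 points each
    deg = ptr[1:] - ptr[:-1]           # points per cloud: tensor([4, 4])

    ratio = torch.tensor([0.5, 0.25])  # one sampling rate per cloud
    out_ptr = (deg.float() * ratio).ceil().long().cumsum(0)
    print(out_ptr)                     # tensor([2, 3]) -> 2 + 1 sampled points in total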
@@ -2,5 +2,5 @@
 #include <torch/extension.h>
 
-torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
+torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                       bool random_start);
@@ -3,7 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 
 #include "utils.cuh"
+#include <stdio.h>
 
 #define THREADS 256
 
 template <typename scalar_t>
@@ -64,21 +64,22 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
   }
 }
 
-torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
+torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                        bool random_start) {
   CHECK_CUDA(src);
   CHECK_CUDA(ptr);
   CHECK_INPUT(ptr.dim() == 1);
-  AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
+  // AT_ASSERTM(at::all(at::__and__(at::gt(ratio, 0), at::lt(ratio, 1))), "Invalid input");
   cudaSetDevice(src.get_device());
 
   src = src.view({src.size(0), -1}).contiguous();
   ptr = ptr.contiguous();
+  ratio = ratio.contiguous();
 
   auto batch_size = ptr.size(0) - 1;
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
-  auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
+  auto out_ptr = deg.toType(torch::kFloat) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
   out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
...
@@ -2,5 +2,5 @@
 #include <torch/extension.h>
 
-torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
+torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                        bool random_start);
@@ -11,7 +11,7 @@
 PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
 #endif
 
-torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
+torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                   bool random_start) {
   if (src.device().is_cuda()) {
 #ifdef WITH_CUDA
...
@@ -84,7 +84,7 @@ setup(
     ext_modules=get_extensions() if not BUILD_DOCS else [],
     cmdclass={
         'build_ext':
-        BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
+        BuildExtension.with_options(no_python_abi_suffix=True)
     },
     packages=find_packages(),
 )
@@ -21,10 +21,10 @@ def test_fps(dtype, device):
     ], dtype, device)
     batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
 
-    out = fps(x, batch, ratio=0.5, random_start=False)
+    out = fps(x, batch, ratio=torch.tensor(0.5), random_start=False)
     assert out.tolist() == [0, 2, 4, 6]
 
-    out = fps(x, ratio=0.5, random_start=False)
+    out = fps(x, ratio=torch.tensor(0.5), random_start=False)
     assert out.sort()[0].tolist() == [0, 5, 6, 7]
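
The updated tests still pass only a 0-d ratio; a per-cloud ratio could be exercised in the same style. A self-contained sketch (not part of this commit; random coordinates stand in for the test fixture, so only counts and cloud membership are checked):

    import torch
    from torch_cluster import fps

    x = torch.rand(8, 2)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

    out = fps(x, batch, ratio=torch.tensor([0.5, 0.25]), random_start=False)
    assert out.numel() == 3                  # ceil(4 * 0.5) + ceil(4 * 0.25)
    assert batch[out].tolist() == [0, 0, 1]  # two picks from cloud 0, one from cloud 1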
@@ -36,5 +36,5 @@ def test_random_fps(device):
     batch_1 = torch.zeros(N, dtype=torch.long, device=device)
     batch_2 = torch.ones(N, dtype=torch.long, device=device)
     batch = torch.cat([batch_1, batch_2])
-    idx = fps(pos, batch, ratio=0.5)
+    idx = fps(pos, batch, ratio=torch.tensor(0.5))
     assert idx.min() >= 0 and idx.max() < 2 * N
@@ -5,7 +5,7 @@ import torch
 
 @torch.jit.script
 def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
-        ratio: float = 0.5, random_start: bool = True) -> torch.Tensor:
+        ratio: torch.Tensor = torch.tensor(0.5), random_start: bool = True) -> torch.Tensor:
     r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
     Learning on Point Sets in a Metric Space"
     <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
@@ -17,7 +17,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
         batch (LongTensor, optional): Batch vector
             :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
-        ratio (float, optional): Sampling ratio. (default: :obj:`0.5`)
+        ratio (Tensor, optional): Sampling ratio. (default: :obj:`0.5`)
         random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
@@ -33,6 +33,11 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
         index = fps(src, batch, ratio=0.5)
     """
+    assert len(ratio.shape) < 2, 'Invalid ratio'
+    ratio = ratio.to(src.device)
+    if len(ratio.shape) == 1:
+        assert ratio.shape[0] == int(batch.max()) + 1, 'Mismatched input and ratio numbers'
+
     if batch is not None:
         assert src.size(0) == batch.numel()
         batch_size = int(batch.max()) + 1
...
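
The new assertions in the wrapper accept a 0-d ratio (one rate for every cloud) or a 1-d ratio with exactly one entry per cloud, and move it to the device of src; other shapes are rejected. A quick sketch of which calls pass under the added checks (illustrative inputs):

    import torch
    from torch_cluster import fps

    x = torch.rand(8, 3)
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])   # batch_size = 2

    fps(x, batch, ratio=torch.tensor(0.5))             # 0-d: same rate everywhere
    fps(x, batch, ratio=torch.tensor([0.3, 0.8]))      # 1-d, one entry per cloud
    # fps(x, batch, ratio=torch.tensor([0.3, 0.8, 0.1]))  # wrong length: AssertionError
    # fps(x, batch, ratio=torch.tensor([[0.5]]))           # 2-d: AssertionError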