Unverified Commit 63a6cbe9 authored by dingchang's avatar dingchang Committed by GitHub
Browse files

[Feature] Add fps op from mmdet3d (#1337)

* add ops (furthest point sample) in mmdet3d

* refactor code

* refactor code

* fix typo

* fix typo

* fix typo

* refactor code

* fix typo

* define DIVUP in common_cuda_helper.hpp
parent 599163e6
...@@ -10,6 +10,8 @@ We implement common CUDA ops used in detection, segmentation, etc. ...@@ -10,6 +10,8 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool - CornerPool
- Deformable Convolution v1/v2 - Deformable Convolution v1/v2
- Deformable RoIPool - Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention - GeneralizedAttention
- MaskedConv - MaskedConv
- NMS - NMS
......
...@@ -10,6 +10,8 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 ...@@ -10,6 +10,8 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool - CornerPool
- Deformable Convolution v1/v2 - Deformable Convolution v1/v2
- Deformable RoIPool - Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention - GeneralizedAttention
- MaskedConv - MaskedConv
- NMS - NMS
......
...@@ -17,6 +17,8 @@ from .deprecated_wrappers import Linear_deprecated as Linear ...@@ -17,6 +17,8 @@ from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss) sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version, from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path) get_onnxruntime_op_path)
...@@ -29,6 +31,7 @@ from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms ...@@ -29,6 +31,7 @@ from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample, from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point) rel_roi_point_to_rel_img_point)
from .points_sampler import PointsSampler
from .psa_mask import PSAMask from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
...@@ -55,5 +58,6 @@ __all__ = [ ...@@ -55,5 +58,6 @@ __all__ = [
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand', 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align', 'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'Correlation' 'furthest_point_sample', 'furthest_point_sample_with_dist',
'PointsSampler', 'Correlation'
] ]
...@@ -10,8 +10,6 @@ ...@@ -10,8 +10,6 @@
#include "pytorch_cuda_helper.hpp" #include "pytorch_cuda_helper.hpp"
#endif #endif
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
template <typename T> template <typename T>
__global__ void ball_query_forward_cuda_kernel(int b, int n, int m, __global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
float min_radius, float min_radius,
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#define THREADS_PER_BLOCK 512 #define THREADS_PER_BLOCK 512
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
inline int GET_BLOCKS(const int N) { inline int GET_BLOCKS(const int N) {
int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
int max_block_num = 4096; int max_block_num = 4096;
......
// Copyright (c) OpenMMLab. All rights reserved
#ifndef FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH
#define FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// Merge step of the block-wide argmax reduction: keep the larger of the two
// candidate distances (and the index of its point) in slot idx1.
__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                         int idx1, int idx2) {
  const float d_a = dists[idx1], d_b = dists[idx2];
  const int i_a = dists_i[idx1], i_b = dists_i[idx2];
  dists[idx1] = max(d_a, d_b);
  dists_i[idx1] = (d_b > d_a) ? i_b : i_a;
}
template <unsigned int block_size>
__global__ void furthest_point_sampling_forward_cuda_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // Iterative farthest point sampling over xyz coordinates.
  // One thread block handles one batch element.
  //   dataset: (B, N, 3) point coordinates
  //   temp:    (B, N) running min squared distance to the selected set;
  //            must be pre-filled with a large sentinel (e.g. 1e10)
  //   idxs:    (B, M) output indices of the sampled points
  // Launch: grid = B blocks, blockDim.x == block_size (power of two <= 1024).
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;  // index of the most recently selected point
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    float x1 = dataset[old * 3 + 0];
    float y1 = dataset[old * 3 + 1];
    float z1 = dataset[old * 3 + 2];
    // Each thread scans a strided slice of the points, updating the running
    // min distance in temp and tracking its local farthest candidate.
    for (int k = tid; k < n; k += stride) {
      float x2 = dataset[k * 3 + 0];
      float y2 = dataset[k * 3 + 1];
      float z2 = dataset[k * 3 + 2];
      float d =
          (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    // Block-wide argmax reduction in shared memory. The `tid < tid_thres`
    // guard is required: without it, threads with tid >= tid_thres would
    // access dists[tid + tid_thres] past the end of the shared arrays.
    for (int block_size_thres = 1024; block_size_thres >= 2;
         block_size_thres /= 2) {
      const int tid_thres = block_size_thres / 2;
      if (block_size >= block_size_thres && tid < tid_thres) {
        __update(dists, dists_i, tid, tid + tid_thres);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
    // Barrier so no thread overwrites the shared buffers for the next
    // iteration before every thread has read dists_i[0] above.
    __syncthreads();
  }
}
// Modified from
// https://github.com/qiqihaer/3DSSD-pytorch/blob/master/lib/pointnet2/src/sampling_gpu.cu
template <unsigned int block_size>
__global__ void furthest_point_sampling_with_dist_forward_cuda_kernel(
    int b, int n, int m, const float *__restrict__ dataset,
    float *__restrict__ temp, int *__restrict__ idxs) {
  // Farthest point sampling driven by a precomputed pairwise distance matrix.
  // One thread block handles one batch element.
  //   dataset: (B, N, N) pairwise distances between points
  //   temp:    (B, N) running min distance to the selected set; must be
  //            pre-filled with a large sentinel (e.g. 1e10)
  //   idxs:    (B, M) output indices of the sampled points
  // Launch: grid = B blocks, blockDim.x == block_size (power of two <= 1024).
  if (m <= 0) return;
  __shared__ float dists[block_size];
  __shared__ int dists_i[block_size];

  int batch_index = blockIdx.x;
  dataset += batch_index * n * n;
  temp += batch_index * n;
  idxs += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  int old = 0;  // index of the most recently selected point
  if (threadIdx.x == 0) idxs[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    float best = -1;
    // Each thread scans a strided slice of the distance row for point `old`.
    for (int k = tid; k < n; k += stride) {
      float d = dataset[old * n + k];
      float d2 = min(d, temp[k]);
      temp[k] = d2;
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    // Block-wide argmax reduction; the `tid < tid_thres` guard prevents
    // out-of-bounds accesses of dists[tid + tid_thres].
    for (int block_size_thres = 1024; block_size_thres >= 2;
         block_size_thres /= 2) {
      const int tid_thres = block_size_thres / 2;
      if (block_size >= block_size_thres && tid < tid_thres) {
        __update(dists, dists_i, tid, tid + tid_thres);
      }
      __syncthreads();
    }

    old = dists_i[0];
    if (tid == 0) idxs[j] = old;
    // Barrier so no thread overwrites the shared buffers for the next
    // iteration before every thread has read dists_i[0] above.
    __syncthreads();
  }
}
#endif // FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
#endif // MMCV_USE_PARROTS #endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT #endif // MMCV_WITH_TRT
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long int) * 8; int const threadsPerBlock = sizeof(unsigned long long int) * 8;
__device__ inline bool devIoU(float const *const a, float const *const b, __device__ inline bool devIoU(float const *const a, float const *const b,
......
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling_gpu.cu
#include <stdio.h>
#include <stdlib.h>
#include "furthest_point_sample_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
// Largest power of two <= work_size, clamped to [1, 1024]; used as the
// per-block thread count for the FPS kernels.
inline int opt_n_threads(int work_size) {
  // floor(log2(work_size)) via double-precision log, matching upstream.
  const double log2_ws =
      std::log(static_cast<double>(work_size)) / std::log(2.0);
  const int pow_2 = static_cast<int>(log2_ws);  // truncates toward zero
  int threads = 1 << pow_2;
  if (threads > 1024) threads = 1024;
  if (threads < 1) threads = 1;
  return threads;
}
void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m,
                                                    const float *dataset,
                                                    float *temp, int *idxs) {
  // dataset: (B, N, 3); temp: (B, N), pre-filled with a large sentinel;
  // idxs: (B, M) output.
  // One CUDA block per batch element. The block size is the largest power of
  // two <= n (capped at 1024) and is dispatched at runtime to the matching
  // compile-time template instantiation.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  unsigned int n_threads = opt_n_threads(n);

#define FPS_LAUNCH(BS)                                  \
  furthest_point_sampling_forward_cuda_kernel<BS>       \
      <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs)

  switch (n_threads) {
    case 1024:
      FPS_LAUNCH(1024);
      break;
    case 512:
      FPS_LAUNCH(512);
      break;
    case 256:
      FPS_LAUNCH(256);
      break;
    case 128:
      FPS_LAUNCH(128);
      break;
    case 64:
      FPS_LAUNCH(64);
      break;
    case 32:
      FPS_LAUNCH(32);
      break;
    case 16:
      FPS_LAUNCH(16);
      break;
    case 8:
      FPS_LAUNCH(8);
      break;
    case 4:
      FPS_LAUNCH(4);
      break;
    case 2:
      FPS_LAUNCH(2);
      break;
    case 1:
      FPS_LAUNCH(1);
      break;
    default:
      // Unreachable in practice: opt_n_threads only returns the powers of
      // two handled above. Kept to mirror the original fallback.
      FPS_LAUNCH(512);
  }
#undef FPS_LAUNCH

  AT_CUDA_CHECK(cudaGetLastError());
}
void FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
    int b, int n, int m, const float *dataset, float *temp, int *idxs) {
  // dataset: (B, N, N) pairwise distance matrix; temp: (B, N), pre-filled
  // with a large sentinel; idxs: (B, M) output.
  // One CUDA block per batch element; runtime block size dispatched to the
  // matching compile-time template instantiation.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  unsigned int n_threads = opt_n_threads(n);

#define FPS_DIST_LAUNCH(BS)                                 \
  furthest_point_sampling_with_dist_forward_cuda_kernel<BS> \
      <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs)

  switch (n_threads) {
    case 1024:
      FPS_DIST_LAUNCH(1024);
      break;
    case 512:
      FPS_DIST_LAUNCH(512);
      break;
    case 256:
      FPS_DIST_LAUNCH(256);
      break;
    case 128:
      FPS_DIST_LAUNCH(128);
      break;
    case 64:
      FPS_DIST_LAUNCH(64);
      break;
    case 32:
      FPS_DIST_LAUNCH(32);
      break;
    case 16:
      FPS_DIST_LAUNCH(16);
      break;
    case 8:
      FPS_DIST_LAUNCH(8);
      break;
    case 4:
      FPS_DIST_LAUNCH(4);
      break;
    case 2:
      FPS_DIST_LAUNCH(2);
      break;
    case 1:
      FPS_DIST_LAUNCH(1);
      break;
    default:
      // Unreachable in practice (opt_n_threads yields the cases above);
      // kept to mirror the original fallback.
      FPS_DIST_LAUNCH(512);
  }
#undef FPS_DIST_LAUNCH

  AT_CUDA_CHECK(cudaGetLastError());
}
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
// CUDA launcher implemented in furthest_point_sample_cuda.cu.
void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m,
                                                    const float *dataset,
                                                    float *temp, int *idxs);

// Thin host-side forwarder to the CUDA launcher.
//   dataset: (B, N, 3) xyz coordinates; temp: (B, N) scratch pre-filled with
//   a large value; idxs: (B, M) output indices.
void furthest_point_sampling_forward_cuda(int b, int n, int m,
                                          const float *dataset, float *temp,
                                          int *idxs) {
  FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs);
}
// CUDA launcher implemented in furthest_point_sample_cuda.cu.
void FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
    int b, int n, int m, const float *dataset, float *temp, int *idxs);

// Thin host-side forwarder to the CUDA launcher.
//   dataset: (B, N, N) pairwise distance matrix; temp: (B, N) scratch
//   pre-filled with a large value; idxs: (B, M) output indices.
void furthest_point_sampling_with_dist_forward_cuda(int b, int n, int m,
                                                    const float *dataset,
                                                    float *temp, int *idxs) {
  FurthestPointSamplingWithDistForwardCUDAKernelLauncher(b, n, m, dataset, temp,
                                                         idxs);
}
#endif
void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
                                     Tensor temp_tensor, Tensor idx_tensor) {
  // Dispatch furthest point sampling to the CUDA implementation.
  //   points_tensor: (B, N, 3) float xyz coordinates (CUDA tensor)
  //   temp_tensor:   (B, N) float scratch, pre-filled with a large value
  //   idx_tensor:    (B, M) int output indices
  // AT_ERROR throws, so the guard below never falls through.
  if (!points_tensor.device().is_cuda()) {
    AT_ERROR("furthest_point_sampling is not implemented on CPU");
  }
#ifdef MMCV_WITH_CUDA
  const float *points = points_tensor.data_ptr<float>();
  float *temp = temp_tensor.data_ptr<float>();
  int *idx = idx_tensor.data_ptr<int>();
  furthest_point_sampling_forward_cuda(b, n, m, points, temp, idx);
#else
  AT_ERROR("furthest_point_sampling is not compiled with GPU support");
#endif
}
void furthest_point_sampling_with_dist_forward(int b, int n, int m,
                                               Tensor points_tensor,
                                               Tensor temp_tensor,
                                               Tensor idx_tensor) {
  // Dispatch distance-matrix furthest point sampling to CUDA.
  //   points_tensor: (B, N, N) float pairwise distance matrix (CUDA tensor)
  //   temp_tensor:   (B, N) float scratch, pre-filled with a large value
  //   idx_tensor:    (B, M) int output indices
  if (points_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
    // data_ptr<T>() replaces the deprecated data<T>() accessor, matching the
    // style used by furthest_point_sampling_forward in this file.
    const float *points = points_tensor.data_ptr<float>();
    float *temp = temp_tensor.data_ptr<float>();
    int *idx = idx_tensor.data_ptr<int>();
    furthest_point_sampling_with_dist_forward_cuda(b, n, m, points, temp, idx);
#else
    AT_ERROR(
        "furthest_point_sampling_with_dist is not compiled with GPU support");
#endif
  } else {
    AT_ERROR("furthest_point_sampling_with_dist is not implemented on CPU");
  }
}
...@@ -69,6 +69,14 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight, ...@@ -69,6 +69,14 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset); const int mode, const bool aligned, const int offset);
void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor);
void furthest_point_sampling_with_dist_forward(int b, int n, int m,
Tensor points_tensor,
Tensor temp_tensor,
Tensor idx_tensor);
void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx, void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx,
const Tensor mask_w_idx, Tensor col, const Tensor mask_w_idx, Tensor col,
const int kernel_h, const int kernel_w, const int kernel_h, const int kernel_w,
...@@ -315,6 +323,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -315,6 +323,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"), m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"), py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
py::arg("aligned"), py::arg("offset")); py::arg("aligned"), py::arg("offset"));
m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward,
"furthest_point_sampling_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
py::arg("idx_tensor"));
m.def("furthest_point_sampling_with_dist_forward",
&furthest_point_sampling_with_dist_forward,
"furthest_point_sampling_with_dist_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
py::arg("idx_tensor"));
m.def("masked_im2col_forward", &masked_im2col_forward, m.def("masked_im2col_forward", &masked_im2col_forward,
"masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"), "masked_im2col_forward", py::arg("im"), py::arg("mask_h_idx"),
py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"), py::arg("mask_w_idx"), py::arg("col"), py::arg("kernel_h"),
......
import torch
from torch.autograd import Function
from ..utils import ext_loader
# Load the compiled extension exposing the two CUDA FPS entry points.
ext_module = ext_loader.load_ext('_ext', [
    'furthest_point_sampling_forward',
    'furthest_point_sampling_with_dist_forward'
])
class FurthestPointSampling(Function):
    """Uses iterative furthest point sampling to select a set of features
    whose corresponding points have the furthest distance."""

    @staticmethod
    def forward(ctx, points_xyz: torch.Tensor,
                num_points: int) -> torch.Tensor:
        """
        Args:
            points_xyz (Tensor): (B, N, 3) where N > num_points.
                Expected to be a contiguous float32 CUDA tensor (the
                extension reads it as float*).
            num_points (int): Number of points in the sampled set.

        Returns:
            Tensor: (B, num_points) int32 indices of the sampled points.
        """
        assert points_xyz.is_contiguous()

        B, N = points_xyz.size()[:2]
        # Allocate on the input's device (the legacy torch.cuda.IntTensor /
        # FloatTensor constructors always used the *current* device and left
        # the output uninitialized). Zero-init keeps idx deterministic even
        # if the kernel exits early (num_points <= 0).
        output = points_xyz.new_zeros((B, num_points), dtype=torch.int32)
        # Running min squared distance to the selected set; the kernel
        # expects it pre-filled with a large sentinel.
        temp = points_xyz.new_zeros((B, N)).fill_(1e10)

        ext_module.furthest_point_sampling_forward(B, N, num_points,
                                                   points_xyz, temp, output)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        # Index selection is not differentiable.
        return None, None
class FurthestPointSamplingWithDist(Function):
    """Iterative furthest point sampling driven by a precomputed pairwise
    distance matrix instead of raw xyz coordinates."""

    @staticmethod
    def forward(ctx, points_dist: torch.Tensor,
                num_points: int) -> torch.Tensor:
        """
        Args:
            points_dist (Tensor): (B, N, N) Distance between each point pair.
            num_points (int): Number of points in the sampled set.

        Returns:
            Tensor: (B, num_points) indices of the sampled points.
        """
        assert points_dist.is_contiguous()

        B, N, _ = points_dist.size()
        # Scratch buffer of running min distances, pre-filled with a large
        # sentinel as the kernel expects; output holds int32 indices.
        temp = points_dist.new_zeros([B, N]).fill_(1e10)
        output = points_dist.new_zeros([B, num_points], dtype=torch.int32)

        ext_module.furthest_point_sampling_with_dist_forward(
            B, N, num_points, points_dist, temp, output)
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        # No gradient flows through index selection.
        return None, None
# Functional aliases for the autograd Functions above.
furthest_point_sample = FurthestPointSampling.apply
furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
from typing import List
import torch
from torch import nn as nn
from mmcv.runner import force_fp32
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
def calc_square_dist(point_feat_a, point_feat_b, norm=True):
    """Pairwise squared Euclidean distance between two point-feature sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b.

    Args:
        point_feat_a (Tensor): (B, N, C) Feature vector of each point.
        point_feat_b (Tensor): (B, M, C) Feature vector of each point.
        norm (Bool, optional): Whether to normalize the distance
            (sqrt of the squared distance divided by C). Default: True.

    Returns:
        Tensor: (B, N, M) Distance between each pair points.
    """
    num_channel = point_feat_a.shape[-1]
    sq_a = point_feat_a.pow(2).sum(dim=-1, keepdim=True)    # (B, N, 1)
    sq_b = point_feat_b.pow(2).sum(dim=-1).unsqueeze(dim=1)  # (B, 1, M)
    inner = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
    dist = sq_a + sq_b - 2 * inner
    if norm:
        dist = torch.sqrt(dist) / num_channel
    return dist
def get_sampler_cls(sampler_type):
    """Get the type and mode of points sampler.

    Args:
        sampler_type (str): The type of points sampler.
            The valid value are "D-FPS", "F-FPS", or "FS".

    Returns:
        class: Points sampler type.

    Raises:
        KeyError: If ``sampler_type`` is not one of the supported names.
    """
    sampler_mappings = {
        'D-FPS': DFPSSampler,
        'F-FPS': FFPSSampler,
        'FS': FSSampler,
    }
    try:
        return sampler_mappings[sampler_type]
    except KeyError:
        # The original single f-string used a backslash line continuation,
        # which embedded a run of indentation spaces in the error message.
        raise KeyError(
            f'Supported `sampler_type` are {sampler_mappings.keys()}, '
            f'but got {sampler_type}')
class PointsSampler(nn.Module):
    """Points sampling.

    Runs one or more FPS variants over (possibly disjoint) index ranges of
    the input point cloud and concatenates the resulting indices.

    Args:
        num_point (list[int]): Number of sample points.
        fps_mod_list (list[str], optional): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional):
            Range of points to apply FPS. Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1]):
        super().__init__()
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        self.num_point = num_point
        self.fps_sample_range_list = fps_sample_range_list
        # One sampler module per FPS mode, instantiated from the registry.
        self.samplers = nn.ModuleList()
        for fps_mod in fps_mod_list:
            self.samplers.append(get_sampler_cls(fps_mod)())
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, points_xyz, features):
        """
        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) Descriptors of the features.

        Return:
            Tensor: (B, npoint, sample_num) Indices of sampled points.
        """
        indices = []
        last_fps_end_index = 0

        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]

            # A range of -1 means "from the current offset to the end";
            # otherwise the slice ends at the absolute index
            # `fps_sample_range`.
            if fps_sample_range == -1:
                sample_points_xyz = points_xyz[:, last_fps_end_index:]
                if features is not None:
                    sample_features = features[:, :, last_fps_end_index:]
                else:
                    sample_features = None
            else:
                sample_points_xyz = \
                    points_xyz[:, last_fps_end_index:fps_sample_range]
                if features is not None:
                    sample_features = features[:, :, last_fps_end_index:
                                               fps_sample_range]
                else:
                    sample_features = None

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)

            # Sampled indices are local to the slice; shift them back to
            # absolute indices before collecting.
            indices.append(fps_idx + last_fps_end_index)
            # NOTE(review): when fps_sample_range == -1 this *decrements* the
            # offset by 1, which looks wrong if another sampler follows; it
            # also accumulates absolute end indices rather than assigning
            # them. Presumably -1 is only ever used as the last range —
            # confirm against callers.
            last_fps_end_index += fps_sample_range
        indices = torch.cat(indices, dim=1)

        return indices
class DFPSSampler(nn.Module):
    """FPS driven by Euclidean distances between raw point coordinates."""

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sampling points with D-FPS."""
        # `features` is accepted for interface parity but not used.
        return furthest_point_sample(points.contiguous(), npoint)
class FFPSSampler(nn.Module):
    """FPS driven by distances in a combined coordinate + feature space."""

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sampling points with F-FPS."""
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        # Concatenate xyz with per-point features, then sample on the full
        # pairwise squared-distance matrix of that combined representation.
        combined = torch.cat([points, features.transpose(1, 2)], dim=2)
        pairwise_dist = calc_square_dist(combined, combined, norm=False)
        return furthest_point_sample_with_dist(pairwise_dist, npoint)
class FSSampler(nn.Module):
    """Fused sampling: run F-FPS and D-FPS and concatenate their indices."""

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sampling points with FS_Sampling."""
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        idx_ffps = FFPSSampler()(points, features, npoint)
        idx_dfps = DFPSSampler()(points, features, npoint)
        return torch.cat([idx_ffps, idx_dfps], dim=1)
import pytest
import torch
from mmcv.ops import furthest_point_sample, furthest_point_sample_with_dist
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps():
    """furthest_point_sample picks the expected indices on a toy cloud."""
    xyz = torch.tensor([
        [[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
         [-0.8070, 2.4137, -0.5845], [-1.0001, 2.1982, -0.5859],
         [0.3841, 1.8983, -0.7431]],
        [[-1.0696, 3.0758, -0.1899], [-0.2559, 3.5521, -0.1402],
         [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
         [-0.0518, 3.7251, -0.3950]],
    ]).cuda()

    sampled = furthest_point_sample(xyz, 3)

    expected = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()
    assert torch.all(sampled == expected)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_fps_with_dist():
    """furthest_point_sample_with_dist matches plain FPS on a toy cloud and
    reproduces recorded indices on stored feature distances."""
    xyz = torch.tensor([
        [[-0.2748, 1.0020, -1.1674], [0.1015, 1.3952, -1.2681],
         [-0.8070, 2.4137, -0.5845], [-1.0001, 2.1982, -0.5859],
         [0.3841, 1.8983, -0.7431]],
        [[-1.0696, 3.0758, -0.1899], [-0.2559, 3.5521, -0.1402],
         [0.8164, 4.0081, -0.1839], [-1.1000, 3.0213, -0.8205],
         [-0.0518, 3.7251, -0.3950]],
    ]).cuda()
    expected_idx = torch.tensor([[0, 2, 4], [0, 2, 1]]).cuda()

    # Pairwise squared distances derived from the same cloud must yield the
    # same sample order as coordinate-based FPS.
    xyz_square_dist = ((xyz.unsqueeze(dim=1) -
                        xyz.unsqueeze(dim=2))**2).sum(-1)
    sampled = furthest_point_sample_with_dist(xyz_square_dist, 3)
    assert torch.all(sampled == expected_idx)

    # Regression check against precomputed feature distances and indices.
    import numpy as np
    fps_idx = np.load('tests/data/for_3d_ops/fps_idx.npy')
    features_for_fps_distance = np.load(
        'tests/data/for_3d_ops/features_for_fps_distance.npy')
    expected_idx = torch.from_numpy(fps_idx).cuda()
    features_for_fps_distance = torch.from_numpy(
        features_for_fps_distance).cuda()

    sampled = furthest_point_sample_with_dist(features_for_fps_distance, 16)
    assert torch.all(sampled == expected_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment