Commit df299e7c authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

fix typos: FPS should be farthest point sampling instead of furthest point sampling

parent ec982888
......@@ -136,7 +136,7 @@ class VoxelSetAbstraction(nn.Module):
bs_mask = (batch_indices == bs_idx)
sampled_points = src_points[bs_mask].unsqueeze(dim=0) # (1, N, 3)
if self.model_cfg.SAMPLE_METHOD == 'FPS':
cur_pt_idxs = pointnet2_stack_utils.furthest_point_sample(
cur_pt_idxs = pointnet2_stack_utils.farthest_point_sample(
sampled_points[:, :, 0:3].contiguous(), self.model_cfg.NUM_KEYPOINTS
).long()
......
......@@ -174,7 +174,7 @@ class PointNet2Backbone(nn.Module):
else:
last_num_points = self.num_points_each_layer[i - 1]
cur_xyz = l_xyz[-1][k * last_num_points: (k + 1) * last_num_points]
cur_pt_idxs = pointnet2_utils_stack.furthest_point_sample(
cur_pt_idxs = pointnet2_utils_stack.farthest_point_sample(
cur_xyz[None, :, :].contiguous(), self.num_points_each_layer[i]
).long()[0]
if cur_xyz.shape[0] < self.num_points_each_layer[i]:
......
......@@ -31,7 +31,7 @@ class _PointnetSAModuleBase(nn.Module):
if new_xyz is None:
new_xyz = pointnet2_utils.gather_operation(
xyz_flipped,
pointnet2_utils.furthest_point_sample(xyz, self.npoint)
pointnet2_utils.farthest_point_sample(xyz, self.npoint)
).transpose(1, 2).contiguous() if self.npoint is not None else None
for i in range(len(self.groupers)):
......
......@@ -7,11 +7,11 @@ from torch.autograd import Function, Variable
from . import pointnet2_batch_cuda as pointnet2
class FurthestPointSampling(Function):
class FarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
"""
Uses iterative furthest point sampling to select a set of npoint features that have the largest
Uses iterative farthest point sampling to select a set of npoint features that have the largest
minimum distance
:param ctx:
:param xyz: (B, N, 3) where N > npoint
......@@ -25,7 +25,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
......@@ -33,7 +33,7 @@ class FurthestPointSampling(Function):
return None, None
furthest_point_sample = FurthestPointSampling.apply
farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class GatherOperation(Function):
......
......@@ -16,7 +16,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
......
......@@ -38,13 +38,13 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
}
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
const float *points = points_tensor.data<float>();
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
......@@ -98,7 +98,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
}
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
__global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -215,7 +215,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -227,29 +227,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
......
......@@ -20,10 +20,10 @@ void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints,
const float *grad_out, const int *idx, float *grad_points);
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
#endif
......@@ -155,7 +155,7 @@ class QueryAndGroup(nn.Module):
return new_features, idx
class FurthestPointSampling(Function):
class FarthestPointSampling(Function):
@staticmethod
def forward(ctx, xyz: torch.Tensor, npoint: int):
"""
......@@ -173,7 +173,7 @@ class FurthestPointSampling(Function):
output = torch.cuda.IntTensor(B, npoint)
temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
pointnet2.farthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
return output
@staticmethod
......@@ -181,7 +181,7 @@ class FurthestPointSampling(Function):
return None, None
furthest_point_sample = FurthestPointSampling.apply
farthest_point_sample = furthest_point_sample = FarthestPointSampling.apply
class ThreeNN(Function):
......
......@@ -12,7 +12,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ball_query_wrapper", &ball_query_wrapper_stack, "ball_query_wrapper_stack");
m.def("voxel_query_wrapper", &voxel_query_wrapper_stack, "voxel_query_wrapper_stack");
m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
m.def("farthest_point_sampling_wrapper", &farthest_point_sampling_wrapper, "farthest_point_sampling_wrapper");
m.def("group_points_wrapper", &group_points_wrapper_stack, "group_points_wrapper_stack");
m.def("group_points_grad_wrapper", &group_points_grad_wrapper_stack, "group_points_grad_wrapper_stack");
......
......@@ -21,7 +21,7 @@ extern THCState *state;
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
CHECK_INPUT(points_tensor);
......@@ -32,6 +32,6 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
farthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx);
return 1;
}
......@@ -22,7 +22,7 @@ __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, i
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m,
__global__ void farthest_point_sampling_kernel(int b, int n, int m,
const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -139,7 +139,7 @@ __global__ void furthest_point_sampling_kernel(int b, int n, int m,
}
}
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
......@@ -151,29 +151,29 @@ void furthest_point_sampling_kernel_launcher(int b, int n, int m,
switch (n_threads) {
case 1024:
furthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1024><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 512:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 256:
furthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<256><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 128:
furthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<128><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 64:
furthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<64><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 32:
furthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<32><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 16:
furthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<16><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 8:
furthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<8><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 4:
furthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<4><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 2:
furthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<2><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
case 1:
furthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
farthest_point_sampling_kernel<1><<<b, n_threads>>>(b, n, m, dataset, temp, idxs); break;
default:
furthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
farthest_point_sampling_kernel<512><<<b, n_threads>>>(b, n, m, dataset, temp, idxs);
}
err = cudaGetLastError();
......
......@@ -6,10 +6,10 @@
#include<vector>
int furthest_point_sampling_wrapper(int b, int n, int m,
int farthest_point_sampling_wrapper(int b, int n, int m,
at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
void furthest_point_sampling_kernel_launcher(int b, int n, int m,
void farthest_point_sampling_kernel_launcher(int b, int n, int m,
const float *dataset, float *temp, int *idxs);
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment