Unverified Commit 0e2ad8df authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[feature]: support dilated ball query (#96)

* add dilated ball query

* fix bugs for group points

* rename max radius
parent e4857216
...@@ -11,12 +11,13 @@ class BallQuery(Function): ...@@ -11,12 +11,13 @@ class BallQuery(Function):
""" """
@staticmethod @staticmethod
def forward(ctx, radius: float, sample_num: int, xyz: torch.Tensor, def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
center_xyz: torch.Tensor) -> torch.Tensor: xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
"""forward. """forward.
Args: Args:
radius (float): radius of the balls. min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls. sample_num (int): maximum number of features in the balls.
xyz (Tensor): (B, N, 3) xyz coordinates of the features. xyz (Tensor): (B, N, 3) xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) centers of the ball query. center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
...@@ -27,13 +28,14 @@ class BallQuery(Function): ...@@ -27,13 +28,14 @@ class BallQuery(Function):
""" """
assert center_xyz.is_contiguous() assert center_xyz.is_contiguous()
assert xyz.is_contiguous() assert xyz.is_contiguous()
assert min_radius < max_radius
B, N, _ = xyz.size() B, N, _ = xyz.size()
npoint = center_xyz.size(1) npoint = center_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, sample_num).zero_() idx = torch.cuda.IntTensor(B, npoint, sample_num).zero_()
ball_query_ext.ball_query_wrapper(B, N, npoint, radius, sample_num, ball_query_ext.ball_query_wrapper(B, N, npoint, min_radius, max_radius,
center_xyz, xyz, idx) sample_num, center_xyz, xyz, idx)
ctx.mark_non_differentiable(idx) ctx.mark_non_differentiable(idx)
return idx return idx
......
...@@ -19,15 +19,15 @@ extern THCState *state; ...@@ -19,15 +19,15 @@ extern THCState *state;
CHECK_CUDA(x); \ CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x) CHECK_CONTIGUOUS(x)
int ball_query_wrapper(int b, int n, int m, float radius, int nsample, int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor); at::Tensor idx_tensor);
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample, void ball_query_kernel_launcher(int b, int n, int m, float min_radius, float max_radius,
const float *xyz, const float *new_xyz, int nsample, const float *xyz, const float *new_xyz,
int *idx, cudaStream_t stream); int *idx, cudaStream_t stream);
int ball_query_wrapper(int b, int n, int m, float radius, int nsample, int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor) { at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor); CHECK_INPUT(new_xyz_tensor);
...@@ -37,8 +37,8 @@ int ball_query_wrapper(int b, int n, int m, float radius, int nsample, ...@@ -37,8 +37,8 @@ int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
int *idx = idx_tensor.data_ptr<int>(); int *idx = idx_tensor.data_ptr<int>();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx, ball_query_kernel_launcher(b, n, m, min_radius, max_radius,
stream); nsample, new_xyz, xyz, idx, stream);
return 1; return 1;
} }
......
...@@ -8,7 +8,9 @@ ...@@ -8,7 +8,9 @@
#define THREADS_PER_BLOCK 256 #define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__ void ball_query_kernel(int b, int n, int m, float radius, __global__ void ball_query_kernel(int b, int n, int m,
float min_radius,
float max_radius,
int nsample, int nsample,
const float *__restrict__ new_xyz, const float *__restrict__ new_xyz,
const float *__restrict__ xyz, const float *__restrict__ xyz,
...@@ -25,7 +27,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius, ...@@ -25,7 +27,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
xyz += bs_idx * n * 3; xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample; idx += bs_idx * m * nsample + pt_idx * nsample;
float radius2 = radius * radius; float max_radius2 = max_radius * max_radius;
float min_radius2 = min_radius * min_radius;
float new_x = new_xyz[0]; float new_x = new_xyz[0];
float new_y = new_xyz[1]; float new_y = new_xyz[1];
float new_z = new_xyz[2]; float new_z = new_xyz[2];
...@@ -37,7 +40,7 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius, ...@@ -37,7 +40,7 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
float z = xyz[k * 3 + 2]; float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z); (new_z - z) * (new_z - z);
if (d2 < radius2) { if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
if (cnt == 0) { if (cnt == 0) {
for (int l = 0; l < nsample; ++l) { for (int l = 0; l < nsample; ++l) {
idx[l] = k; idx[l] = k;
...@@ -50,8 +53,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius, ...@@ -50,8 +53,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
} }
} }
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample, void ball_query_kernel_launcher(int b, int n, int m, float min_radius, float max_radius,
const float *new_xyz, const float *xyz, int nsample, const float *new_xyz, const float *xyz,
int *idx, cudaStream_t stream) { int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3) // new_xyz: (B, M, 3)
// xyz: (B, N, 3) // xyz: (B, N, 3)
...@@ -64,8 +67,8 @@ void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample, ...@@ -64,8 +67,8 @@ void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
b); // blockIdx.x(col), blockIdx.y(row) b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK); dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, min_radius, max_radius,
new_xyz, xyz, idx); nsample, new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function // cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError(); err = cudaGetLastError();
if (cudaSuccess != err) { if (cudaSuccess != err) {
......
...@@ -13,8 +13,9 @@ class QueryAndGroup(nn.Module): ...@@ -13,8 +13,9 @@ class QueryAndGroup(nn.Module):
Groups with a ball query of radius Groups with a ball query of radius
Args: Args:
radius (float): radius of the balls. max_radius (float): The maximum radius of the balls.
sample_num (int): Maximum number of features to gather in the ball. sample_num (int): Maximum number of features to gather in the ball.
min_radius (float): The minimum radius of the balls.
use_xyz (bool): Whether to use xyz. use_xyz (bool): Whether to use xyz.
Default: True. Default: True.
return_grouped_xyz (bool): Whether to return grouped xyz. return_grouped_xyz (bool): Whether to return grouped xyz.
...@@ -29,15 +30,17 @@ class QueryAndGroup(nn.Module): ...@@ -29,15 +30,17 @@ class QueryAndGroup(nn.Module):
""" """
def __init__(self, def __init__(self,
radius, max_radius,
sample_num, sample_num,
min_radius=0,
use_xyz=True, use_xyz=True,
return_grouped_xyz=False, return_grouped_xyz=False,
normalize_xyz=False, normalize_xyz=False,
uniform_sample=False, uniform_sample=False,
return_unique_cnt=False): return_unique_cnt=False):
super(QueryAndGroup, self).__init__() super(QueryAndGroup, self).__init__()
self.radius = radius self.max_radius = max_radius
self.min_radius = min_radius
self.sample_num = sample_num self.sample_num = sample_num
self.use_xyz = use_xyz self.use_xyz = use_xyz
self.return_grouped_xyz = return_grouped_xyz self.return_grouped_xyz = return_grouped_xyz
...@@ -58,7 +61,8 @@ class QueryAndGroup(nn.Module): ...@@ -58,7 +61,8 @@ class QueryAndGroup(nn.Module):
Return: Return:
Tensor: (B, 3 + C, npoint, sample_num) Grouped feature. Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
""" """
idx = ball_query(self.radius, self.sample_num, points_xyz, center_xyz) idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
points_xyz, center_xyz)
if self.uniform_sample: if self.uniform_sample:
unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
...@@ -79,7 +83,7 @@ class QueryAndGroup(nn.Module): ...@@ -79,7 +83,7 @@ class QueryAndGroup(nn.Module):
grouped_xyz = grouping_operation(xyz_trans, idx) grouped_xyz = grouping_operation(xyz_trans, idx)
grouped_xyz -= center_xyz.transpose(1, 2).unsqueeze(-1) grouped_xyz -= center_xyz.transpose(1, 2).unsqueeze(-1)
if self.normalize_xyz: if self.normalize_xyz:
grouped_xyz /= self.radius grouped_xyz /= self.max_radius
if features is not None: if features is not None:
grouped_features = grouping_operation(features, idx) grouped_features = grouping_operation(features, idx)
......
...@@ -26,6 +26,8 @@ class PointSAModuleMSG(nn.Module): ...@@ -26,6 +26,8 @@ class PointSAModuleMSG(nn.Module):
FS: using F-FPS and D-FPS simultaneously. FS: using F-FPS and D-FPS simultaneously.
fps_sample_range_list (list[int]): Range of points to apply FPS. fps_sample_range_list (list[int]): Range of points to apply FPS.
Default: [-1]. Default: [-1].
dilated_group (bool): Whether to use dilated ball query.
Default: False.
norm_cfg (dict): Type of normalization method. norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d'). Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz. use_xyz (bool): Whether to use xyz.
...@@ -34,6 +36,9 @@ class PointSAModuleMSG(nn.Module): ...@@ -34,6 +36,9 @@ class PointSAModuleMSG(nn.Module):
Default: 'max_pool'. Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius. normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False. Default: False.
bias (bool | str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
False. Default: "auto".
""" """
def __init__(self, def __init__(self,
...@@ -43,6 +48,7 @@ class PointSAModuleMSG(nn.Module): ...@@ -43,6 +48,7 @@ class PointSAModuleMSG(nn.Module):
mlp_channels: List[List[int]], mlp_channels: List[List[int]],
fps_mod: List[str] = ['D-FPS'], fps_mod: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1], fps_sample_range_list: List[int] = [-1],
dilated_group: bool = False,
norm_cfg: dict = dict(type='BN2d'), norm_cfg: dict = dict(type='BN2d'),
use_xyz: bool = True, use_xyz: bool = True,
pool_mod='max', pool_mod='max',
...@@ -80,9 +86,14 @@ class PointSAModuleMSG(nn.Module): ...@@ -80,9 +86,14 @@ class PointSAModuleMSG(nn.Module):
radius = radii[i] radius = radii[i]
sample_num = sample_nums[i] sample_num = sample_nums[i]
if num_point is not None: if num_point is not None:
if dilated_group and i != 0:
min_radius = radii[i - 1]
else:
min_radius = 0
grouper = QueryAndGroup( grouper = QueryAndGroup(
radius, radius,
sample_num, sample_num,
min_radius=min_radius,
use_xyz=use_xyz, use_xyz=use_xyz,
normalize_xyz=normalize_xyz) normalize_xyz=normalize_xyz)
else: else:
......
...@@ -53,14 +53,23 @@ def test_ball_query(): ...@@ -53,14 +53,23 @@ def test_ball_query():
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]]).cuda() -1.2000]]]).cuda()
idx = ball_query(0.2, 5, xyz, new_xyz) idx = ball_query(0, 0.2, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]], [0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda() [0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
# test dilated ball query
idx = ball_query(0.2, 0.4, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6],
[2, 3, 2, 2, 2], [0, 5, 7, 0, 0],
[0, 5, 7, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx) assert torch.all(idx == expected_idx)
...@@ -340,5 +349,6 @@ def test_fps_with_dist(): ...@@ -340,5 +349,6 @@ def test_fps_with_dist():
expected_idx = torch.from_numpy(fps_idx).cuda() expected_idx = torch.from_numpy(fps_idx).cuda()
features_for_fps_distance = torch.from_numpy( features_for_fps_distance = torch.from_numpy(
features_for_fps_distance).cuda() features_for_fps_distance).cuda()
idx = furthest_point_sample_with_dist(features_for_fps_distance, 16) idx = furthest_point_sample_with_dist(features_for_fps_distance, 16)
assert torch.all(idx == expected_idx) assert torch.all(idx == expected_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment