Unverified commit 0e2ad8df, authored by encore-zhou and committed by GitHub
Browse files

[feature]: support dilated ball query (#96)

* add dilated ball query

* fix bugs for group points

* rename max radius
parent e4857216
......@@ -11,12 +11,13 @@ class BallQuery(Function):
"""
@staticmethod
def forward(ctx, radius: float, sample_num: int, xyz: torch.Tensor,
center_xyz: torch.Tensor) -> torch.Tensor:
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
"""forward.
Args:
radius (float): radius of the balls.
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
......@@ -27,13 +28,14 @@ class BallQuery(Function):
"""
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = torch.cuda.IntTensor(B, npoint, sample_num).zero_()
ball_query_ext.ball_query_wrapper(B, N, npoint, radius, sample_num,
center_xyz, xyz, idx)
ball_query_ext.ball_query_wrapper(B, N, npoint, min_radius, max_radius,
sample_num, center_xyz, xyz, idx)
ctx.mark_non_differentiable(idx)
return idx
......
......@@ -19,15 +19,15 @@ extern THCState *state;
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor);
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
const float *xyz, const float *new_xyz,
void ball_query_kernel_launcher(int b, int n, int m, float min_radius, float max_radius,
int nsample, const float *xyz, const float *new_xyz,
int *idx, cudaStream_t stream);
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
int ball_query_wrapper(int b, int n, int m, float min_radius, float max_radius, int nsample,
at::Tensor new_xyz_tensor, at::Tensor xyz_tensor,
at::Tensor idx_tensor) {
CHECK_INPUT(new_xyz_tensor);
......@@ -37,8 +37,8 @@ int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
int *idx = idx_tensor.data_ptr<int>();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
ball_query_kernel_launcher(b, n, m, radius, nsample, new_xyz, xyz, idx,
stream);
ball_query_kernel_launcher(b, n, m, min_radius, max_radius,
nsample, new_xyz, xyz, idx, stream);
return 1;
}
......
......@@ -8,7 +8,9 @@
#define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
__global__ void ball_query_kernel(int b, int n, int m, float radius,
__global__ void ball_query_kernel(int b, int n, int m,
float min_radius,
float max_radius,
int nsample,
const float *__restrict__ new_xyz,
const float *__restrict__ xyz,
......@@ -25,7 +27,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
float radius2 = radius * radius;
float max_radius2 = max_radius * max_radius;
float min_radius2 = min_radius * min_radius;
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
......@@ -37,7 +40,7 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
float z = xyz[k * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < radius2) {
if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
......@@ -50,8 +53,8 @@ __global__ void ball_query_kernel(int b, int n, int m, float radius,
}
}
void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
const float *new_xyz, const float *xyz,
void ball_query_kernel_launcher(int b, int n, int m, float min_radius, float max_radius,
int nsample, const float *new_xyz, const float *xyz,
int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
......@@ -64,8 +67,8 @@ void ball_query_kernel_launcher(int b, int n, int m, float radius, int nsample,
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample,
new_xyz, xyz, idx);
ball_query_kernel<<<blocks, threads, 0, stream>>>(b, n, m, min_radius, max_radius,
nsample, new_xyz, xyz, idx);
// cudaDeviceSynchronize(); // for using printf in kernel function
err = cudaGetLastError();
if (cudaSuccess != err) {
......
......@@ -13,8 +13,9 @@ class QueryAndGroup(nn.Module):
Groups with a ball query of radius
Args:
radius (float): radius of the balls.
max_radius (float): The maximum radius of the balls.
sample_num (int): Maximum number of features to gather in the ball.
min_radius (float): The minimum radius of the balls.
use_xyz (bool): Whether to use xyz.
Default: True.
return_grouped_xyz (bool): Whether to return grouped xyz.
......@@ -29,15 +30,17 @@ class QueryAndGroup(nn.Module):
"""
def __init__(self,
radius,
max_radius,
sample_num,
min_radius=0,
use_xyz=True,
return_grouped_xyz=False,
normalize_xyz=False,
uniform_sample=False,
return_unique_cnt=False):
super(QueryAndGroup, self).__init__()
self.radius = radius
self.max_radius = max_radius
self.min_radius = min_radius
self.sample_num = sample_num
self.use_xyz = use_xyz
self.return_grouped_xyz = return_grouped_xyz
......@@ -58,7 +61,8 @@ class QueryAndGroup(nn.Module):
Return:
Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
"""
idx = ball_query(self.radius, self.sample_num, points_xyz, center_xyz)
idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
points_xyz, center_xyz)
if self.uniform_sample:
unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
......@@ -79,7 +83,7 @@ class QueryAndGroup(nn.Module):
grouped_xyz = grouping_operation(xyz_trans, idx)
grouped_xyz -= center_xyz.transpose(1, 2).unsqueeze(-1)
if self.normalize_xyz:
grouped_xyz /= self.radius
grouped_xyz /= self.max_radius
if features is not None:
grouped_features = grouping_operation(features, idx)
......
......@@ -26,6 +26,8 @@ class PointSAModuleMSG(nn.Module):
FS: using F-FPS and D-FPS simultaneously.
fps_sample_range_list (list[int]): Range of points to apply FPS.
Default: [-1].
dilated_group (bool): Whether to use dilated ball query.
Default: False.
norm_cfg (dict): Type of normalization method.
Default: dict(type='BN2d').
use_xyz (bool): Whether to use xyz.
......@@ -34,6 +36,9 @@ class PointSAModuleMSG(nn.Module):
Default: 'max_pool'.
normalize_xyz (bool): Whether to normalize local XYZ with radius.
Default: False.
bias (bool | str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
False. Default: "auto".
"""
def __init__(self,
......@@ -43,6 +48,7 @@ class PointSAModuleMSG(nn.Module):
mlp_channels: List[List[int]],
fps_mod: List[str] = ['D-FPS'],
fps_sample_range_list: List[int] = [-1],
dilated_group: bool = False,
norm_cfg: dict = dict(type='BN2d'),
use_xyz: bool = True,
pool_mod='max',
......@@ -80,9 +86,14 @@ class PointSAModuleMSG(nn.Module):
radius = radii[i]
sample_num = sample_nums[i]
if num_point is not None:
if dilated_group and i != 0:
min_radius = radii[i - 1]
else:
min_radius = 0
grouper = QueryAndGroup(
radius,
sample_num,
min_radius=min_radius,
use_xyz=use_xyz,
normalize_xyz=normalize_xyz)
else:
......
......@@ -53,14 +53,23 @@ def test_ball_query():
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]]).cuda()
idx = ball_query(0.2, 5, xyz, new_xyz)
idx = ball_query(0, 0.2, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
# test dilated ball query
idx = ball_query(0.2, 0.4, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6],
[2, 3, 2, 2, 2], [0, 5, 7, 0, 0],
[0, 5, 7, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
......@@ -340,5 +349,6 @@ def test_fps_with_dist():
expected_idx = torch.from_numpy(fps_idx).cuda()
features_for_fps_distance = torch.from_numpy(
features_for_fps_distance).cuda()
idx = furthest_point_sample_with_dist(features_for_fps_distance, 16)
assert torch.all(idx == expected_idx)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment