Commit 183d353a authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: small bug in SPC sampling

parent fa39f1c5
......@@ -219,15 +219,15 @@ class VoxelSetAbstraction(nn.Module):
Returns:
sampled_points: (N_out, 3)
"""
sampled_points, _ = sample_points_with_roi(
rois=roi_boxes, points=points,
sample_radius_with_roi=self.model_cfg.SPC.SAMPLE_RADIUS_WITH_ROI,
num_max_points_of_part=self.model_cfg.SPC.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
)
sampled_points = sector_fps(
points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
num_sectors=self.model_cfg.SPC.NUM_SECTORS
num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
)
return sampled_points
......@@ -271,7 +271,7 @@ class VoxelSetAbstraction(nn.Module):
elif self.model_cfg.SAMPLE_METHOD == 'SPC':
cur_keypoints = self.sectorized_proposal_centric_sampling(
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
)
bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
......@@ -290,7 +290,7 @@ class VoxelSetAbstraction(nn.Module):
@staticmethod
def aggregate_keypoint_features_from_one_source(
batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=None, rois=None
filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
):
"""
......@@ -311,7 +311,7 @@ class VoxelSetAbstraction(nn.Module):
"""
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
if filter_neighbors_with_roi:
point_features = torch.cat((xyz, xyz_features), dim=0) if xyz_features is not None else xyz
point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
point_features_list = []
for bs_idx in range(batch_size):
bs_mask = (xyz_bs_idxs == bs_idx)
......@@ -334,7 +334,7 @@ class VoxelSetAbstraction(nn.Module):
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=xyz_features,
features=xyz_features.contiguous(),
)
return pooled_features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment