Commit 183d353a authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: small bug in SPC sampling

parent fa39f1c5
...@@ -219,15 +219,15 @@ class VoxelSetAbstraction(nn.Module): ...@@ -219,15 +219,15 @@ class VoxelSetAbstraction(nn.Module):
Returns: Returns:
sampled_points: (N_out, 3) sampled_points: (N_out, 3)
""" """
sampled_points, _ = sample_points_with_roi( sampled_points, _ = sample_points_with_roi(
rois=roi_boxes, points=points, rois=roi_boxes, points=points,
sample_radius_with_roi=self.model_cfg.SPC.SAMPLE_RADIUS_WITH_ROI, sample_radius_with_roi=self.model_cfg.SPC_SAMPLING.SAMPLE_RADIUS_WITH_ROI,
num_max_points_of_part=self.model_cfg.SPC.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000) num_max_points_of_part=self.model_cfg.SPC_SAMPLING.get('NUM_POINTS_OF_EACH_SAMPLE_PART', 200000)
) )
sampled_points = sector_fps( sampled_points = sector_fps(
points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS, points=sampled_points, num_sampled_points=self.model_cfg.NUM_KEYPOINTS,
num_sectors=self.model_cfg.SPC.NUM_SECTORS num_sectors=self.model_cfg.SPC_SAMPLING.NUM_SECTORS
) )
return sampled_points return sampled_points
...@@ -271,7 +271,7 @@ class VoxelSetAbstraction(nn.Module): ...@@ -271,7 +271,7 @@ class VoxelSetAbstraction(nn.Module):
elif self.model_cfg.SAMPLE_METHOD == 'SPC': elif self.model_cfg.SAMPLE_METHOD == 'SPC':
cur_keypoints = self.sectorized_proposal_centric_sampling( cur_keypoints = self.sectorized_proposal_centric_sampling(
roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points roi_boxes=batch_dict['rois'][bs_idx], points=sampled_points[0]
) )
bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx bs_idxs = cur_keypoints.new_ones(cur_keypoints.shape[0]) * bs_idx
keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1) keypoints = torch.cat((bs_idxs[:, None], cur_keypoints), dim=1)
...@@ -290,7 +290,7 @@ class VoxelSetAbstraction(nn.Module): ...@@ -290,7 +290,7 @@ class VoxelSetAbstraction(nn.Module):
@staticmethod @staticmethod
def aggregate_keypoint_features_from_one_source( def aggregate_keypoint_features_from_one_source(
batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt, batch_size, aggregate_func, xyz, xyz_features, xyz_bs_idxs, new_xyz, new_xyz_batch_cnt,
filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=None, rois=None filter_neighbors_with_roi=False, radius_of_neighbor=None, num_max_points_of_part=200000, rois=None
): ):
""" """
...@@ -311,7 +311,7 @@ class VoxelSetAbstraction(nn.Module): ...@@ -311,7 +311,7 @@ class VoxelSetAbstraction(nn.Module):
""" """
xyz_batch_cnt = xyz.new_zeros(batch_size).int() xyz_batch_cnt = xyz.new_zeros(batch_size).int()
if filter_neighbors_with_roi: if filter_neighbors_with_roi:
point_features = torch.cat((xyz, xyz_features), dim=0) if xyz_features is not None else xyz point_features = torch.cat((xyz, xyz_features), dim=-1) if xyz_features is not None else xyz
point_features_list = [] point_features_list = []
for bs_idx in range(batch_size): for bs_idx in range(batch_size):
bs_mask = (xyz_bs_idxs == bs_idx) bs_mask = (xyz_bs_idxs == bs_idx)
...@@ -334,7 +334,7 @@ class VoxelSetAbstraction(nn.Module): ...@@ -334,7 +334,7 @@ class VoxelSetAbstraction(nn.Module):
xyz_batch_cnt=xyz_batch_cnt, xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz, new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt, new_xyz_batch_cnt=new_xyz_batch_cnt,
features=xyz_features, features=xyz_features.contiguous(),
) )
return pooled_features return pooled_features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment