Unverified Commit 0e84f82d authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

bugfixed: re-sample scene if gt_boxes.shape[0] == 0, check fg_num in PointHeadTemplate (#340)

parent f982b5bf
......@@ -124,9 +124,6 @@ class DatasetTemplate(torch_data.Dataset):
'gt_boxes_mask': gt_boxes_mask
}
)
if len(data_dict['gt_boxes']) == 0:
new_index = np.random.randint(self.__len__())
return self.__getitem__(new_index)
if data_dict.get('gt_boxes', None) is not None:
selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
......@@ -141,6 +138,11 @@ class DatasetTemplate(torch_data.Dataset):
data_dict = self.data_processor.forward(
data_dict=data_dict
)
if len(data_dict['gt_boxes']) == 0:
new_index = np.random.randint(self.__len__())
return self.__getitem__(new_index)
data_dict.pop('gt_names', None)
return data_dict
......
......@@ -102,7 +102,7 @@ class PointHeadTemplate(nn.Module):
point_cls_labels_single[fg_flag] = 1 if self.num_class == 1 else gt_box_of_fg_points[:, -1].long()
point_cls_labels[bs_mask] = point_cls_labels_single
if ret_box_labels:
if ret_box_labels and gt_box_of_fg_points.shape[0] > 0:
point_box_labels_single = point_box_labels.new_zeros((bs_mask.sum(), 8))
fg_point_box_labels = self.box_coder.encode_torch(
gt_boxes=gt_box_of_fg_points[:, :-1], points=points_single[fg_flag],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment