Commit d3179c0f authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support CenterHead for RCNN training/testing

parent 2d269538
...@@ -302,6 +302,24 @@ class CenterHead(nn.Module): ...@@ -302,6 +302,24 @@ class CenterHead(nn.Module):
return ret_dict return ret_dict
@staticmethod
def reorder_rois_for_refining(batch_size, pred_dicts):
num_max_rois = max([len(cur_dict['pred_boxes']) for cur_dict in pred_dicts])
num_max_rois = max(1, num_max_rois) # at least one faked rois to avoid error
pred_boxes = pred_dicts[0]['pred_boxes']
rois = pred_boxes.new_zeros((batch_size, num_max_rois, pred_boxes.shape[-1]))
roi_scores = pred_boxes.new_zeros((batch_size, num_max_rois))
roi_labels = pred_boxes.new_zeros((batch_size, num_max_rois)).long()
for bs_idx in range(batch_size):
num_boxes = len(pred_dicts[bs_idx]['pred_boxes'])
rois[bs_idx, :num_boxes, :] = pred_dicts[bs_idx]['pred_boxes']
roi_scores[bs_idx, :num_boxes] = pred_dicts[bs_idx]['pred_scores']
roi_labels[bs_idx, :num_boxes] = pred_dicts[bs_idx]['pred_labels']
return rois, roi_scores, roi_labels
def forward(self, data_dict): def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d'] spatial_features_2d = data_dict['spatial_features_2d']
x = self.shared_conv(spatial_features_2d) x = self.shared_conv(spatial_features_2d)
...@@ -320,9 +338,17 @@ class CenterHead(nn.Module): ...@@ -320,9 +338,17 @@ class CenterHead(nn.Module):
self.forward_ret_dict['pred_dicts'] = pred_dicts self.forward_ret_dict['pred_dicts'] = pred_dicts
if not self.training or self.predict_boxes_when_training: if not self.training or self.predict_boxes_when_training:
final_box_dicts = self.generate_predicted_boxes( pred_dicts = self.generate_predicted_boxes(
data_dict['batch_size'], pred_dicts data_dict['batch_size'], pred_dicts
) )
data_dict['final_box_dicts'] = final_box_dicts
if self.predict_boxes_when_training:
rois, roi_scores, roi_labels = self.reorder_rois_for_refining(data_dict['batch_size'], pred_dicts)
data_dict['rois'] = rois
data_dict['roi_scores'] = roi_scores
data_dict['roi_labels'] = roi_labels
data_dict['has_class_labels'] = True
else:
data_dict['final_box_dicts'] = pred_dicts
return data_dict return data_dict
...@@ -61,6 +61,9 @@ class RoIHeadTemplate(nn.Module): ...@@ -61,6 +61,9 @@ class RoIHeadTemplate(nn.Module):
roi_labels: (B, num_rois) roi_labels: (B, num_rois)
""" """
if batch_dict.get('rois', None) is not None:
return batch_dict
batch_size = batch_dict['batch_size'] batch_size = batch_dict['batch_size']
batch_box_preds = batch_dict['batch_box_preds'] batch_box_preds = batch_dict['batch_box_preds']
batch_cls_preds = batch_dict['batch_cls_preds'] batch_cls_preds = batch_dict['batch_cls_preds']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment