Commit a236428b authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

support multi-classes nms for multi-head, not checked

parent 6901df66
...@@ -229,19 +229,11 @@ class AnchorHeadMulti(AnchorHeadTemplate): ...@@ -229,19 +229,11 @@ class AnchorHeadMulti(AnchorHeadTemplate):
) )
if isinstance(batch_cls_preds, list): if isinstance(batch_cls_preds, list):
all_pred_labels = [] multihead_label_mapping = []
all_cls_preds = [] for idx in range(len(batch_cls_preds)):
for idx, cls_pred in enumerate(batch_cls_preds): multihead_label_mapping.append(self.rpn_heads[idx].head_label_indices)
pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
pred_label = self.rpn_heads[idx].head_label_indices[pred_head_label] data_dict['multihead_label_mapping'] = multihead_label_mapping
all_pred_labels.append(pred_label)
all_cls_preds.append(pred_score[:, :, None])
batch_cls_preds = torch.cat(all_cls_preds, dim=1)
batch_pred_labels = torch.cat(all_pred_labels, dim=1)
data_dict['batch_pred_labels'] = batch_pred_labels
data_dict['has_class_labels'] = True
data_dict['batch_cls_preds'] = batch_cls_preds data_dict['batch_cls_preds'] = batch_cls_preds
data_dict['batch_box_preds'] = batch_box_preds data_dict['batch_box_preds'] = batch_box_preds
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .. import backbones_3d, backbones_2d, dense_heads, roi_heads from .. import backbones_3d, backbones_2d, dense_heads, roi_heads
from ..backbones_3d import vfe, pfe from ..backbones_3d import vfe, pfe
from ..backbones_2d import map_to_bev from ..backbones_2d import map_to_bev
from ..model_utils.model_nms_utils import class_agnostic_nms from ..model_utils import model_nms_utils
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
...@@ -169,6 +169,8 @@ class Detector3DTemplate(nn.Module): ...@@ -169,6 +169,8 @@ class Detector3DTemplate(nn.Module):
batch_dict: batch_dict:
batch_size: batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1) batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
multihead_label_mapping: [(num_class1), (num_class2), ...]
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C) batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...) batch_index: optional (N1+N2+...)
...@@ -184,32 +186,62 @@ class Detector3DTemplate(nn.Module): ...@@ -184,32 +186,62 @@ class Detector3DTemplate(nn.Module):
pred_dicts = [] pred_dicts = []
for index in range(batch_size): for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None: if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_cls_preds'].shape.__len__() == 2 assert batch_dict['batch_box_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index) batch_mask = (batch_dict['batch_index'] == index)
else: else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3 assert batch_dict['batch_box_preds'].shape.__len__() == 3
batch_mask = index batch_mask = index
box_preds = batch_dict['batch_box_preds'][batch_mask] box_preds = batch_dict['batch_box_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask]
src_cls_preds = cls_preds
src_box_preds = box_preds src_box_preds = box_preds
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']: if not isinstance(batch_dict['batch_cls_preds'], list):
cls_preds = torch.sigmoid(cls_preds) cls_preds = batch_dict['batch_cls_preds'][batch_mask]
src_cls_preds = cls_preds
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
else:
cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']]
src_cls_preds = cls_preds
if not batch_dict['cls_preds_normalized']:
cls_preds = [torch.sigmoid(x) for x in cls_preds]
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS: if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError if not isinstance(cls_preds, list):
cls_preds = [cls_preds]
multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)]
else:
multihead_label_mapping = batch_dict['multihead_label_mapping']
cur_start_idx = 0
pred_scores, pred_labels, pred_boxes = [], [], []
for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping):
assert cur_cls_preds.shape[1] == len(cur_label_mapping)
cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]]
cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms(
cls_scores=cur_cls_preds, box_preds=cur_box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)
cur_pred_labels = cur_label_mapping[cur_pred_labels]
pred_scores.append(cur_pred_scores)
pred_labels.append(cur_pred_labels)
pred_boxes.append(cur_pred_boxes)
final_scores = torch.cat(pred_scores, dim=0)
final_labels = torch.cat(pred_labels, dim=0)
final_boxes = torch.cat(pred_boxes, dim=0)
else: else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1) cls_preds, label_preds = torch.max(cls_preds, dim=-1)
if batch_dict.get('has_class_labels', False): if batch_dict.get('has_class_labels', False):
label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels' label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
label_preds = batch_dict[label_key][index] label_preds = batch_dict[label_key][index]
else: else:
label_preds + 1 label_preds = label_preds + 1
selected, selected_scores = model_nms_utils.class_agnostic_nms(
selected, selected_scores = class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds, box_scores=cls_preds, box_preds=box_preds,
nms_config=post_process_cfg.NMS_CONFIG, nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH score_thresh=post_process_cfg.SCORE_THRESH
......
...@@ -22,3 +22,46 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None): ...@@ -22,3 +22,46 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
original_idxs = scores_mask.nonzero().view(-1) original_idxs = scores_mask.nonzero().view(-1)
selected = original_idxs[selected] selected = original_idxs[selected]
return selected, src_box_scores[selected] return selected, src_box_scores[selected]
def multi_classes_nms(cls_scores, box_preds, nms_config, score_thresh=None):
"""
Args:
cls_scores: (N, num_class)
box_preds: (N, 7 + C)
nms_config:
score_thresh:
Returns:
"""
pred_scores, pred_labels, pred_boxes = [], [], []
for k in range(cls_scores.shape[0]):
if score_thresh is not None:
scores_mask = (cls_scores[:, k] >= score_thresh)
box_scores = cls_scores[scores_mask, k]
box_preds = box_preds[scores_mask]
else:
box_scores = cls_scores[:, k]
selected = []
if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
)
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
if score_thresh is not None:
selected = scores_mask.nonzero().view(-1)
pred_scores.append(box_scores[selected])
pred_labels.append(box_scores.new_ones(selected.shape[0]) * k)
pred_boxes.append(box_preds[selected])
pred_scores = torch.cat(pred_scores, dim=0)
pred_labels = torch.cat(pred_labels, dim=0)
pred_boxes = torch.cat(pred_boxes, dim=0)
return pred_scores, pred_labels, pred_boxes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment