import torch

from ...ops.iou3d_nms import iou3d_nms_utils


def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
    """Run class-agnostic NMS over a set of scored boxes.

    Args:
        box_scores: per-box scores. It is indexed with a boolean mask and fed
            to ``torch.topk``, so presumably shape (N,) -- TODO confirm.
        box_preds: boxes indexed in lockstep with ``box_scores``; only
            columns ``0:7`` are forwarded to the NMS function.
        nms_config: config object read via attribute access (``NMS_TYPE``,
            ``NMS_THRESH``, ``NMS_PRE_MAXSIZE``, ``NMS_POST_MAXSIZE``) and
            also unpacked as keyword arguments into the NMS function, so it
            must be a mapping with attribute access (presumably an
            EasyDict-like object -- verify against callers). ``NMS_TYPE``
            names a function in ``iou3d_nms_utils``.
        score_thresh: optional minimum score; boxes with a score below it
            are discarded before NMS.

    Returns:
        tuple ``(selected, selected_scores)``: ``selected`` is a LongTensor
        of indices into the ORIGINAL ``box_scores`` / ``box_preds`` (empty
        when nothing survives) and ``selected_scores`` equals
        ``box_scores[selected]`` taken from the unfiltered input.
    """
    src_box_scores = box_scores
    if score_thresh is not None:
        scores_mask = box_scores >= score_thresh
        box_scores = box_scores[scores_mask]
        box_preds = box_preds[scores_mask]

    # Start from an empty index tensor (not a Python list) so the return
    # type is the same whether or not any box survives.
    selected = box_scores.new_zeros((0,), dtype=torch.long)
    if box_scores.shape[0] > 0:
        # Only the NMS_PRE_MAXSIZE highest-scoring boxes go into NMS.
        box_scores_nms, indices = torch.topk(
            box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0])
        )
        boxes_for_nms = box_preds[indices]
        nms_fn = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)
        keep_idx, _ = nms_fn(
            boxes_for_nms[:, 0:7], box_scores_nms, nms_config.NMS_THRESH, **nms_config
        )
        # keep_idx indexes the top-k subset; map it back through `indices`
        # and cap the result at NMS_POST_MAXSIZE entries.
        selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]

    if score_thresh is not None:
        # `selected` indexes the thresholded subset; map it back to positions
        # in the original, unfiltered inputs.
        original_idxs = scores_mask.nonzero().view(-1)
        selected = original_idxs[selected]
    return selected, src_box_scores[selected]