merge_augs.py 3.49 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
3
import torch

zhangshilong's avatar
zhangshilong committed
4
5
from mmdet3d.structures import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr
from ..layers import nms_bev, nms_normal_bev
zhangwenwei's avatar
zhangwenwei committed
6
7


8
def merge_aug_bboxes_3d(aug_results, aug_batch_input_metas, test_cfg):
    """Merge augmented detection 3D bboxes and scores.

    The boxes of every augmented view are passed through
    ``bbox3d_mapping_back`` with the scale factor and flip flags recorded in
    the matching meta dict, concatenated, filtered with per-class BEV NMS and
    finally truncated to the highest-scoring ``max_num`` boxes.

    Args:
        aug_results (list[dict]): The dict of detection results.
            The dict contains the following keys

            - bbox_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
            - scores_3d (torch.Tensor): Detection scores.
            - labels_3d (torch.Tensor): Predicted box labels.
        aug_batch_input_metas (list[dict]): Meta information of each sample,
            one entry per item of ``aug_results``. Each dict must provide
            the keys ``pcd_scale_factor``, ``pcd_horizontal_flip`` and
            ``pcd_vertical_flip``.
        test_cfg (dict): Test config. The following fields are read

            - use_rotate_nms (optional, defaults to False): If truthy,
              ``nms_bev`` is used, otherwise ``nms_normal_bev``.
            - nms_thr: Threshold forwarded to the NMS function. Accessed as
              an attribute, so ``test_cfg`` has to support attribute access.
            - max_num (optional, defaults to 500): Maximum number of boxes
              kept after merging.

    Returns:
        dict: Bounding boxes results in cpu mode, containing merged results.

            - bbox_3d (:obj:`BaseInstance3DBoxes`): Merged detection bbox.
            - scores_3d (torch.Tensor): Merged detection scores.
            - labels_3d (torch.Tensor): Merged predicted box labels.

    Raises:
        AssertionError: If ``aug_results`` and ``aug_batch_input_metas``
            differ in length.
    """
    assert len(aug_results) == len(aug_batch_input_metas), \
        '"aug_results" should have the same length as ' \
        '"aug_batch_input_metas", got len(' \
        f'aug_results)={len(aug_results)} and ' \
        f'len(aug_batch_input_metas)={len(aug_batch_input_metas)}'

    recovered_bboxes = []
    recovered_scores = []
    recovered_labels = []

    for result, input_meta in zip(aug_results, aug_batch_input_metas):
        scale_factor = input_meta['pcd_scale_factor']
        pcd_horizontal_flip = input_meta['pcd_horizontal_flip']
        pcd_vertical_flip = input_meta['pcd_vertical_flip']
        # Scores and labels are taken as-is; only the boxes are mapped back.
        recovered_scores.append(result['scores_3d'])
        recovered_labels.append(result['labels_3d'])
        recovered_bboxes.append(
            bbox3d_mapping_back(result['bbox_3d'], scale_factor,
                                pcd_horizontal_flip, pcd_vertical_flip))

    # ``cat`` is called through the first box object so that the result keeps
    # the box type of the inputs.
    aug_bboxes = recovered_bboxes[0].cat(recovered_bboxes)
    aug_bboxes_for_nms = xywhr2xyxyr(aug_bboxes.bev)
    aug_scores = torch.cat(recovered_scores, dim=0)
    aug_labels = torch.cat(recovered_labels, dim=0)

    # Nothing was detected in any view: return the (empty) inputs directly,
    # since ``torch.max`` below cannot be evaluated on an empty tensor.
    if len(aug_labels) == 0:
        return bbox3d2result(aug_bboxes, aug_scores, aug_labels)

    # TODO: use a more elegant way to deal with nms
    if test_cfg.get('use_rotate_nms', False):
        nms_func = nms_bev
    else:
        nms_func = nms_normal_bev
    nms_thr = test_cfg.nms_thr

    merged_bboxes = []
    merged_scores = []
    merged_labels = []

    # Apply multi-class nms when merging bboxes: NMS is run separately on the
    # boxes of each label value in [0, max label].
    for class_id in range(torch.max(aug_labels).item() + 1):
        class_inds = (aug_labels == class_id)
        bboxes_nms_i = aug_bboxes_for_nms[class_inds, :]
        if len(bboxes_nms_i) == 0:
            # No box carries this label.
            continue
        bboxes_i = aug_bboxes[class_inds]
        scores_i = aug_scores[class_inds]
        labels_i = aug_labels[class_inds]
        selected = nms_func(bboxes_nms_i, scores_i, nms_thr)

        merged_bboxes.append(bboxes_i[selected, :])
        merged_scores.append(scores_i[selected])
        merged_labels.append(labels_i[selected])

    merged_bboxes = merged_bboxes[0].cat(merged_bboxes)
    merged_scores = torch.cat(merged_scores, dim=0)
    merged_labels = torch.cat(merged_labels, dim=0)

    # Keep at most ``max_num`` boxes, ordered by descending score.
    _, order = merged_scores.sort(0, descending=True)
    num = min(test_cfg.get('max_num', 500), len(order))
    order = order[:num]

    merged_bboxes = merged_bboxes[order]
    merged_scores = merged_scores[order]
    merged_labels = merged_labels[order]

    return bbox3d2result(merged_bboxes, merged_scores, merged_labels)