merge_augs.py 3.68 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import List

zhangwenwei's avatar
zhangwenwei committed
4
5
import torch

zhangshilong's avatar
zhangshilong committed
6
from mmdet3d.structures import bbox3d2result, bbox3d_mapping_back, xywhr2xyxyr
7
from mmdet3d.utils import ConfigType
zhangshilong's avatar
zhangshilong committed
8
from ..layers import nms_bev, nms_normal_bev
zhangwenwei's avatar
zhangwenwei committed
9
10


11
12
13
def merge_aug_bboxes_3d(aug_results: List[dict],
                        aug_batch_input_metas: List[dict],
                        test_cfg: ConfigType) -> dict:
    """Merge augmented detection 3D bboxes and scores.

    The boxes of every augmented view are first mapped back with
    ``bbox3d_mapping_back`` using the scale/flip values stored in the
    matching meta dict, then concatenated, filtered by a per-class BEV NMS
    and finally truncated to the highest-scoring ``max_num`` detections.

    Args:
        aug_results (List[dict]): The dict of detection results.
            The dict contains the following keys

            - bbox_3d (:obj:`BaseInstance3DBoxes`): Detection bbox.
            - scores_3d (Tensor): Detection scores.
            - labels_3d (Tensor): Predicted box labels.
        aug_batch_input_metas (List[dict]): Meta information of each sample.
            Each dict must provide the keys ``pcd_scale_factor``,
            ``pcd_horizontal_flip`` and ``pcd_vertical_flip``.
        test_cfg (dict or :obj:`ConfigDict`): Test config. The following
            keys are read

            - use_rotate_nms (optional): Select ``nms_bev`` when truthy,
              otherwise ``nms_normal_bev``. Defaults to False.
            - nms_thr: Threshold forwarded to the NMS function.
            - max_num (optional): Maximum number of boxes kept after
              merging. Defaults to 500.

    Returns:
        dict: The merged boxes, scores and labels packed by
        ``bbox3d2result``. They are sorted by descending score, unless no
        detection is present, in which case the (empty) concatenated inputs
        are packed as they are.

    Raises:
        AssertionError: If ``aug_results`` and ``aug_batch_input_metas``
            have different lengths.
    """
    assert len(aug_results) == len(aug_batch_input_metas), \
        '"aug_results" should have the same length as ' \
        f'"aug_batch_input_metas", got len(aug_results)={len(aug_results)} ' \
        f'and len(aug_batch_input_metas)={len(aug_batch_input_metas)}'

    recovered_bboxes = []
    recovered_scores = []
    recovered_labels = []

    # Undo the scaling / flipping of every augmented view so that all boxes
    # can be merged together. Scores and labels are taken over unchanged.
    for result, input_meta in zip(aug_results, aug_batch_input_metas):
        scale_factor = input_meta['pcd_scale_factor']
        pcd_horizontal_flip = input_meta['pcd_horizontal_flip']
        pcd_vertical_flip = input_meta['pcd_vertical_flip']
        recovered_scores.append(result['scores_3d'])
        recovered_labels.append(result['labels_3d'])
        recovered_bboxes.append(
            bbox3d_mapping_back(result['bbox_3d'], scale_factor,
                                pcd_horizontal_flip, pcd_vertical_flip))

    aug_bboxes = recovered_bboxes[0].cat(recovered_bboxes)
    aug_scores = torch.cat(recovered_scores, dim=0)
    aug_labels = torch.cat(recovered_labels, dim=0)

    # Nothing was detected in any view: there is nothing to suppress.
    if len(aug_labels) == 0:
        return bbox3d2result(aug_bboxes, aug_scores, aug_labels)

    aug_bboxes_for_nms = xywhr2xyxyr(aug_bboxes.bev)

    # TODO: use a more elegant way to deal with nms
    if test_cfg.get('use_rotate_nms', False):
        nms_func = nms_bev
    else:
        nms_func = nms_normal_bev

    # Item access (instead of attribute access) so that a plain ``dict`` is
    # supported as well as a ``ConfigDict``, like the ``.get`` calls above.
    nms_thr = test_cfg['nms_thr']

    merged_bboxes = []
    merged_scores = []
    merged_labels = []

    # Apply multi-class nms when merge bboxes
    for class_id in range(torch.max(aug_labels).item() + 1):
        class_inds = (aug_labels == class_id)
        bboxes_nms_i = aug_bboxes_for_nms[class_inds, :]
        # Class ids without any box are skipped.
        if len(bboxes_nms_i) == 0:
            continue
        bboxes_i = aug_bboxes[class_inds]
        scores_i = aug_scores[class_inds]
        labels_i = aug_labels[class_inds]
        selected = nms_func(bboxes_nms_i, scores_i, nms_thr)

        merged_bboxes.append(bboxes_i[selected, :])
        merged_scores.append(scores_i[selected])
        merged_labels.append(labels_i[selected])

    merged_bboxes = merged_bboxes[0].cat(merged_bboxes)
    merged_scores = torch.cat(merged_scores, dim=0)
    merged_labels = torch.cat(merged_labels, dim=0)

    # Keep at most ``max_num`` boxes, highest score first.
    _, order = merged_scores.sort(0, descending=True)
    num = min(test_cfg.get('max_num', 500), len(merged_scores))
    order = order[:num]

    merged_bboxes = merged_bboxes[order]
    merged_scores = merged_scores[order]
    merged_labels = merged_labels[order]

    return bbox3d2result(merged_bboxes, merged_scores, merged_labels)