test_time_aug.py 5.32 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
import mmcv
zhangwenwei's avatar
zhangwenwei committed
3
4
5
6
7
8
9
10
11
import warnings
from copy import deepcopy

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose


@PIPELINES.register_module()
class MultiScaleFlipAug3D(object):
    """Test-time augmentation with multiple scales and flipping.

    Every combination of image scale, point-cloud scale ratio, image flip
    flag, point-cloud horizontal/vertical flip flag and flip direction is
    passed through ``transforms``. The per-combination outputs are then
    regrouped from a list of dicts into a dict of lists.

    Note that when ``flip`` is True only flipped image samples are produced
    (there is no additional unflipped copy), to reduce the number of scenes
    generated by double flip augmentation at test time.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple]): Image scales for resizing.
        pts_scale_ratio (float | list[float]): Points scale ratios for
            resizing. Each value is converted to ``float``.
        flip (bool): Whether apply flip augmentation. Defaults to False.
        flip_direction (str | list[str]): Flip augmentation directions
            for images, options are "horizontal" and "vertical".
            If flip_direction is list, multiple flip augmentations will
            be applied. It has no effect when ``flip == False``.
            Defaults to "horizontal".
        pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation
            to point cloud. Defaults to False. Note that it works only when
            'flip' is turned on.
        pcd_vertical_flip (bool): Whether apply vertical flip augmentation
            to point cloud. Defaults to False. Note that it works only when
            'flip' is turned on.
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 pts_scale_ratio,
                 flip=False,
                 flip_direction='horizontal',
                 pcd_horizontal_flip=False,
                 pcd_vertical_flip=False):
        self.transforms = Compose(transforms)
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]
        # Convert every ratio to float so that e.g. ``[1, 2]`` is accepted,
        # matching how a scalar ``1`` has always been handled.
        if isinstance(pts_scale_ratio, list):
            self.pts_scale_ratio = [float(r) for r in pts_scale_ratio]
        else:
            self.pts_scale_ratio = [float(pts_scale_ratio)]

        assert mmcv.is_list_of(self.img_scale, tuple)
        assert mmcv.is_list_of(self.pts_scale_ratio, float)
        # An empty list would yield zero augmentations, and ``__call__``
        # would then fail on ``aug_data[0]``; reject it up front instead.
        assert len(self.img_scale) > 0, 'img_scale must not be empty'
        assert len(self.pts_scale_ratio) > 0, \
            'pts_scale_ratio must not be empty'

        self.flip = flip
        self.pcd_horizontal_flip = pcd_horizontal_flip
        self.pcd_vertical_flip = pcd_vertical_flip

        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        assert len(self.flip_direction) > 0, \
            'flip_direction must not be empty'
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if self.flip and not any(
                self._is_flip_transform(t) for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_flip_transform(transform):
        """Return True if ``transform`` is a RandomFlip/RandomFlip3D entry.

        Config dicts are identified by their ``'type'`` key. Any other
        object (e.g. an already-built transform) is identified by its class
        name, so that non-dict entries no longer raise ``TypeError`` here.
        """
        if isinstance(transform, dict):
            name = transform.get('type')
        else:
            name = type(transform).__name__
        return name in ('RandomFlip3D', 'RandomFlip')

    def __call__(self, results):
        """Call function to augment common fields in results.

        Args:
            results (dict): Result dict contains the data to augment.

        Returns:
            dict: The result dict contains the data that is augmented with \
                different scales and flips. Each key maps to a list holding \
                one value per augmentation combination.
        """
        from itertools import product

        # modified from `flip_aug = [False, True] if self.flip else [False]`
        # to reduce unnecessary scenes when using double flip augmentation
        # during test time
        flip_aug = [True] if self.flip else [False]
        pcd_horizontal_flip_aug = [False, True] \
            if self.flip and self.pcd_horizontal_flip else [False]
        pcd_vertical_flip_aug = [False, True] \
            if self.flip and self.pcd_vertical_flip else [False]

        aug_data = []
        # ``product`` yields combinations in the same order as the
        # equivalent nested loops (the last iterable varies fastest).
        for (scale, pts_scale_ratio, flip, pcd_horizontal_flip,
             pcd_vertical_flip, direction) in product(
                 self.img_scale, self.pts_scale_ratio, flip_aug,
                 pcd_horizontal_flip_aug, pcd_vertical_flip_aug,
                 self.flip_direction):
            # results.copy() would cause bugs since it is a shallow copy:
            # nested values would be shared between augmentations.
            _results = deepcopy(results)
            _results['scale'] = scale
            _results['flip'] = flip
            _results['pcd_scale_factor'] = pts_scale_ratio
            _results['flip_direction'] = direction
            _results['pcd_horizontal_flip'] = pcd_horizontal_flip
            _results['pcd_vertical_flip'] = pcd_vertical_flip
            aug_data.append(self.transforms(_results))

        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'pts_scale_ratio={self.pts_scale_ratio}, '
        repr_str += f'flip_direction={self.flip_direction}, '
        repr_str += f'pcd_horizontal_flip={self.pcd_horizontal_flip}, '
        repr_str += f'pcd_vertical_flip={self.pcd_vertical_flip})'
        return repr_str