# Copyright (c) OpenMMLab. All rights reserved.

import warnings
from copy import deepcopy
from itertools import product
from typing import Dict, List, Optional, Tuple, Union

import mmengine
from mmcv import BaseTransform
from mmengine.dataset import Compose

from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MultiScaleFlipAug3D(BaseTransform):
    """Test-time augmentation with multiple scales and flipping.

    For every combination of image scale, point-cloud scale ratio, flip
    flag, point-cloud horizontal/vertical flip flag and flip direction, a
    deep copy of the input ``results`` is annotated with the corresponding
    augmentation keys and passed through ``transforms``. All outputs are
    returned as a list.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
            They are passed to ``mmengine.dataset.Compose``. Note that
            :meth:`transform` assigns a ``scale`` attribute on the first
            composed transform for each image scale (presumably a resize
            step -- confirm against the pipeline config).
        img_scale (tuple | list[tuple]): Image scales for resizing.
        pts_scale_ratio (float | list[float]): Point cloud scale ratios for
            resizing. Integer values, given as a scalar or inside a
            list/tuple, are converted to float.
        flip (bool): Whether to apply flip augmentation. Defaults to False.
        flip_direction (str | list[str]): Flip augmentation directions
            for images, options are "horizontal" and "vertical".
            If flip_direction is a list, multiple flip augmentations will
            be applied. It has no effect when ``flip == False``.
            Defaults to 'horizontal'.
        pcd_horizontal_flip (bool): Whether to apply horizontal flip
            augmentation to point cloud. Only takes effect when ``flip``
            is turned on. Defaults to False.
        pcd_vertical_flip (bool): Whether to apply vertical flip
            augmentation to point cloud. Only takes effect when ``flip``
            is turned on. Defaults to False.
    """

    def __init__(self,
                 transforms: List[dict],
                 img_scale: Optional[Union[Tuple[int], List[Tuple[int]]]],
                 pts_scale_ratio: Union[float, List[float]],
                 flip: bool = False,
                 flip_direction: Union[str, List[str]] = 'horizontal',
                 pcd_horizontal_flip: bool = False,
                 pcd_vertical_flip: bool = False) -> None:
        self.transforms = Compose(transforms)
        # NOTE(review): the annotation allows ``None``, but ``[None]`` is
        # rejected by the assertion below -- confirm the intended contract.
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]

        # Normalise ints to float for list/tuple inputs too, so that
        # ``pts_scale_ratio=[1, 2]`` is accepted just like the scalar ``1``.
        if isinstance(pts_scale_ratio, (list, tuple)):
            self.pts_scale_ratio = [
                float(ratio) if isinstance(ratio, (int, float)) else ratio
                for ratio in pts_scale_ratio
            ]
        else:
            self.pts_scale_ratio = [float(pts_scale_ratio)]

        assert mmengine.is_list_of(self.img_scale, tuple), \
            f'img_scale must be a tuple or a list of tuples, got {img_scale}'
        assert mmengine.is_list_of(self.pts_scale_ratio, float), \
            'pts_scale_ratio must be a number or a list of numbers, ' \
            f'got {pts_scale_ratio}'

        self.flip = flip
        self.pcd_horizontal_flip = pcd_horizontal_flip
        self.pcd_vertical_flip = pcd_vertical_flip

        self.flip_direction = list(flip_direction) if isinstance(
            flip_direction, (list, tuple)) else [flip_direction]
        assert mmengine.is_list_of(self.flip_direction, str), \
            'flip_direction must be a str or a list of str, ' \
            f'got {flip_direction}'

        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if self.flip and not any(
                self._type_name(t) in ('RandomFlip3D', 'RandomFlip')
                for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _type_name(transform) -> Optional[str]:
        """Return the bare type name of a transform config or object.

        Config dicts report their ``type`` entry (a string, possibly with a
        registry scope prefix such as ``'mmdet.'``, or a class); anything
        else is treated as an already-built transform and reports its class
        name. Returns None for a dict without a usable ``type``.
        """
        if not isinstance(transform, dict):
            return type(transform).__name__
        t_type = transform.get('type')
        if t_type is None:
            return None
        if isinstance(t_type, str):
            # Drop an optional scope prefix, e.g. 'mmdet.RandomFlip'.
            return t_type.rsplit('.', 1)[-1]
        return getattr(t_type, '__name__', None)

    def transform(self, results: Dict) -> List[Dict]:
        """Call function to augment common fields in results.

        Args:
            results (dict): Result dict contains the data to augment. It is
                deep-copied for every augmentation and never mutated.

        Returns:
            List[dict]: The list contains the data that is augmented with
            different scales and flips, ordered by image scale, then point
            cloud scale ratio, flip, pcd horizontal flip, pcd vertical flip
            and flip direction (innermost).
        """
        aug_data_list = []

        # modified from `flip_aug = [False, True] if self.flip else [False]`
        # to reduce unnecessary scenes when using double flip augmentation
        # during test time
        flip_aug = [True] if self.flip else [False]
        pcd_horizontal_flip_aug = [False, True] \
            if self.flip and self.pcd_horizontal_flip else [False]
        pcd_vertical_flip_aug = [False, True] \
            if self.flip and self.pcd_vertical_flip else [False]

        for scale in self.img_scale:
            # TODO refactor according to augtest docs
            # Mutates the first composed transform in place; raises
            # IndexError if ``transforms`` was empty.
            self.transforms.transforms[0].scale = scale
            # ``product`` iterates exactly like the nested loops it replaces
            # (last argument varies fastest).
            for (pts_scale_ratio, flip, pcd_horizontal_flip,
                 pcd_vertical_flip, direction) in product(
                     self.pts_scale_ratio, flip_aug, pcd_horizontal_flip_aug,
                     pcd_vertical_flip_aug, self.flip_direction):
                # results.copy() would be a shallow copy and let augmented
                # views share (and corrupt) nested data.
                _results = deepcopy(results)
                _results['scale'] = scale
                _results['flip'] = flip
                _results['pcd_scale_factor'] = pts_scale_ratio
                _results['flip_direction'] = direction
                _results['pcd_horizontal_flip'] = pcd_horizontal_flip
                _results['pcd_vertical_flip'] = pcd_vertical_flip
                aug_data_list.append(self.transforms(_results))

        return aug_data_list

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'pts_scale_ratio={self.pts_scale_ratio}, '
        repr_str += f'flip_direction={self.flip_direction}, '
        repr_str += f'pcd_horizontal_flip={self.pcd_horizontal_flip}, '
        repr_str += f'pcd_vertical_flip={self.pcd_vertical_flip})'
        return repr_str