test_time_aug.py 5.5 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
zhangwenwei's avatar
zhangwenwei committed
2
3
import warnings
from copy import deepcopy
jshilong's avatar
jshilong committed
4
from typing import Dict, List, Optional, Tuple, Union
zhangwenwei's avatar
zhangwenwei committed
5

6
import mmcv
jshilong's avatar
jshilong committed
7
8
from mmcv import BaseTransform
from mmengine.dataset import Compose
9

10
from mmdet3d.registry import TRANSFORMS
zhangwenwei's avatar
zhangwenwei committed
11
12


13
@TRANSFORMS.register_module()
class MultiScaleFlipAug3D(BaseTransform):
    """Test-time augmentation with multiple scales and flipping.

    Every combination of image scale, points scale ratio, flip setting,
    point-cloud horizontal/vertical flip setting and flip direction is run
    through ``transforms`` on a deep copy of the input results.

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple]): Images scales for resizing.
        pts_scale_ratio (float | list[float]): Points scale ratios for
            resizing. Each value is converted with ``float()``.
        flip (bool): Whether to apply flip augmentation.
            Defaults to False.
        flip_direction (str | list[str]): Flip augmentation directions for
            images, options are "horizontal" and "vertical".
            If flip_direction is a list, multiple flip augmentations will
            be applied. It has no effect when ``flip == False``.
            Defaults to "horizontal".
        pcd_horizontal_flip (bool): Whether to apply horizontal flip
            augmentation to point cloud. Defaults to False.
            Note that it works only when 'flip' is turned on.
        pcd_vertical_flip (bool): Whether to apply vertical flip
            augmentation to point cloud. Defaults to False.
            Note that it works only when 'flip' is turned on.
    """

    def __init__(self,
                 transforms: List[dict],
                 img_scale: Union[Tuple[int], List[Tuple[int]]],
                 pts_scale_ratio: Union[float, List[float]],
                 flip: bool = False,
                 flip_direction: Union[str, List[str]] = 'horizontal',
                 pcd_horizontal_flip: bool = False,
                 pcd_vertical_flip: bool = False) -> None:
        self.transforms = Compose(transforms)
        self.img_scale = img_scale if isinstance(img_scale,
                                                 list) else [img_scale]
        # Convert list entries with float() exactly like the scalar case,
        # so that e.g. ``[1, 2]`` is accepted just as ``1`` is.
        if isinstance(pts_scale_ratio, list):
            self.pts_scale_ratio = [float(r) for r in pts_scale_ratio]
        else:
            self.pts_scale_ratio = [float(pts_scale_ratio)]

        assert mmcv.is_list_of(self.img_scale, tuple)
        assert mmcv.is_list_of(self.pts_scale_ratio, float)

        self.flip = flip
        self.pcd_horizontal_flip = pcd_horizontal_flip
        self.pcd_vertical_flip = pcd_vertical_flip

        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if self.flip and not any(
                self._is_flip_transform(t) for t in transforms):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_flip_transform(transform) -> bool:
        """Return whether ``transform`` is a RandomFlip3D/RandomFlip entry.

        Accepts either a config dict (whose ``'type'`` may be a string or a
        class) or an already-built transform object, since ``Compose``
        accepts both. Entries that match neither form simply return False
        instead of raising.
        """
        flip_types = ('RandomFlip3D', 'RandomFlip')
        if isinstance(transform, dict):
            t_type = transform.get('type')
            if isinstance(t_type, type):
                t_type = t_type.__name__
            return t_type in flip_types
        return type(transform).__name__ in flip_types

    def transform(self, results: Dict) -> List[Dict]:
        """Call function to augment common fields in results.

        Args:
            results (dict): Result dict contains the data to augment.

        Returns:
            List[dict]: One transformed deep copy of ``results`` per
            combination of image scale, points scale ratio, flip flag,
            point-cloud horizontal flip flag, point-cloud vertical flip flag
            and flip direction (nested in that order, last varies fastest).
        """
        from itertools import product

        # modified from `flip_aug = [False, True] if self.flip else [False]`
        # to reduce unnecessary scenes when using double flip augmentation
        # during test time
        flip_aug = [True] if self.flip else [False]
        pcd_horizontal_flip_aug = [False, True] \
            if self.flip and self.pcd_horizontal_flip else [False]
        pcd_vertical_flip_aug = [False, True] \
            if self.flip and self.pcd_vertical_flip else [False]

        aug_data_list = []
        for scale in self.img_scale:
            # TODO refactor according to augtest docs
            # NOTE(review): this mutates the first wrapped transform in
            # place; presumably it is the resize op that reads ``scale`` —
            # confirm against the pipeline configs. Guarded so an empty
            # pipeline does not raise IndexError.
            if self.transforms.transforms:
                self.transforms.transforms[0].scale = scale
            for (pts_scale_ratio, flip, pcd_horizontal_flip,
                 pcd_vertical_flip, direction) in product(
                     self.pts_scale_ratio, flip_aug, pcd_horizontal_flip_aug,
                     pcd_vertical_flip_aug, self.flip_direction):
                # results.copy() would cause a bug since it is a shallow
                # copy: nested data would be shared between augmentations.
                _results = deepcopy(results)
                _results['scale'] = scale
                _results['flip'] = flip
                _results['pcd_scale_factor'] = pts_scale_ratio
                _results['flip_direction'] = direction
                _results['pcd_horizontal_flip'] = pcd_horizontal_flip
                _results['pcd_vertical_flip'] = pcd_vertical_flip
                aug_data_list.append(self.transforms(_results))

        return aug_data_list

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'pts_scale_ratio={self.pts_scale_ratio}, '
        repr_str += f'flip_direction={self.flip_direction}, '
        repr_str += f'pcd_horizontal_flip={self.pcd_horizontal_flip}, '
        repr_str += f'pcd_vertical_flip={self.pcd_vertical_flip})'
        return repr_str