Commit 78ee07ea authored by wangtai's avatar wangtai Committed by zhangwenwei
Browse files

Fix RandomFlip3D in test time augmentation

parent 660f3ccc
...@@ -38,14 +38,24 @@ test_pipeline = [ ...@@ -38,14 +38,24 @@ test_pipeline = [
dict( dict(
type='MultiScaleFlipAug', type='MultiScaleFlipAug',
img_scale=(1333, 800), img_scale=(1333, 800),
pts_scale_ratio=1.0,
flip=False, flip=False,
pcd_horizontal_flip=False,
pcd_vertical_flip=False,
transforms=[ transforms=[
dict(type='Resize', keep_ratio=True), dict(
dict(type='RandomFlip'), type='GlobalRotScaleTrans',
dict(type='Normalize', **img_norm_cfg), rot_range=[0, 0],
dict(type='Pad', size_divisor=32), scale_ratio_range=[1., 1.],
dict(type='ImageToTensor', keys=['img']), translation_std=[0, 0, 0]),
dict(type='Collect', keys=['img']), dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
]) ])
] ]
``` ```
...@@ -122,7 +132,8 @@ For each operation, we list the related dict fields that are added/updated/remov ...@@ -122,7 +132,8 @@ For each operation, we list the related dict fields that are added/updated/remov
### Test time augmentation ### Test time augmentation
`MultiScaleFlipAug` `MultiScaleFlipAug3D`
- update: all the dict fields (update values to the collection of augmented data)
## Extend and use custom pipelines ## Extend and use custom pipelines
......
...@@ -17,10 +17,17 @@ class MultiScaleFlipAug3D(object): ...@@ -17,10 +17,17 @@ class MultiScaleFlipAug3D(object):
pts_scale_ratio (float | list[float]): Points scale ratios for pts_scale_ratio (float | list[float]): Points scale ratios for
resizing. resizing.
flip (bool): Whether apply flip augmentation. Default: False. flip (bool): Whether apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions, flip_direction (str | list[str]): Flip augmentation directions
options are "horizontal" and "vertical". If flip_direction is list, for images, options are "horizontal" and "vertical".
multiple flip augmentations will be applied. If flip_direction is list, multiple flip augmentations will
It has no effect when flip == False. Default: "horizontal". be applied. It has no effect when flip == False.
Default: "horizontal".
pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation
to point cloud. Default: True. Note that it works only when
'flip' is turned on.
pcd_vertical_flip (bool): Whether apply vertical flip augmentation
to point cloud. Default: True. Note that it works only when
'flip' is turned on.
""" """
def __init__(self, def __init__(self,
...@@ -28,7 +35,9 @@ class MultiScaleFlipAug3D(object): ...@@ -28,7 +35,9 @@ class MultiScaleFlipAug3D(object):
img_scale, img_scale,
pts_scale_ratio, pts_scale_ratio,
flip=False, flip=False,
flip_direction='horizontal'): flip_direction='horizontal',
pcd_horizontal_flip=True,
pcd_vertical_flip=True):
self.transforms = Compose(transforms) self.transforms = Compose(transforms)
self.img_scale = img_scale if isinstance(img_scale, self.img_scale = img_scale if isinstance(img_scale,
list) else [img_scale] list) else [img_scale]
...@@ -39,32 +48,48 @@ class MultiScaleFlipAug3D(object): ...@@ -39,32 +48,48 @@ class MultiScaleFlipAug3D(object):
assert mmcv.is_list_of(self.pts_scale_ratio, float) assert mmcv.is_list_of(self.pts_scale_ratio, float)
self.flip = flip self.flip = flip
self.pcd_horizontal_flip = pcd_horizontal_flip
self.pcd_vertical_flip = pcd_vertical_flip
self.flip_direction = flip_direction if isinstance( self.flip_direction = flip_direction if isinstance(
flip_direction, list) else [flip_direction] flip_direction, list) else [flip_direction]
assert mmcv.is_list_of(self.flip_direction, str) assert mmcv.is_list_of(self.flip_direction, str)
if not self.flip and self.flip_direction != ['horizontal']: if not self.flip and self.flip_direction != ['horizontal']:
warnings.warn( warnings.warn(
'flip_direction has no effect when flip is set to False') 'flip_direction has no effect when flip is set to False')
if (self.flip if (self.flip and not any([(t['type'] == 'RandomFlip3D'
and not any([t['type'] == 'RandomFlip' for t in transforms])): or t['type'] == 'RandomFlip')
for t in transforms])):
warnings.warn( warnings.warn(
'flip has no effect when RandomFlip is not in transforms') 'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results): def __call__(self, results):
aug_data = [] aug_data = []
flip_aug = [False, True] if self.flip else [False] flip_aug = [False, True] if self.flip else [False]
pcd_horizontal_flip_aug = [False, True] \
if self.flip and self.pcd_horizontal_flip else [False]
pcd_vertical_flip_aug = [False, True] \
if self.flip and self.pcd_vertical_flip else [False]
for scale in self.img_scale: for scale in self.img_scale:
for pts_scale_ratio in self.pts_scale_ratio: for pts_scale_ratio in self.pts_scale_ratio:
for flip in flip_aug: for flip in flip_aug:
for direction in self.flip_direction: for pcd_horizontal_flip in pcd_horizontal_flip_aug:
# results.copy will cause bug since it is shallow copy for pcd_vertical_flip in pcd_vertical_flip_aug:
_results = deepcopy(results) for direction in self.flip_direction:
_results['scale'] = scale # results.copy will cause bug
_results['flip'] = flip # since it is shallow copy
_results['pcd_scale_factor'] = pts_scale_ratio _results = deepcopy(results)
_results['flip_direction'] = direction _results['scale'] = scale
data = self.transforms(_results) _results['flip'] = flip
aug_data.append(data) _results['pcd_scale_factor'] = \
pts_scale_ratio
_results['flip_direction'] = direction
_results['pcd_horizontal_flip'] = \
pcd_horizontal_flip
_results['pcd_vertical_flip'] = \
pcd_vertical_flip
data = self.transforms(_results)
aug_data.append(data)
# list of dict to dict of list # list of dict to dict of list
aug_data_dict = {key: [] for key in aug_data[0]} aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data: for data in aug_data:
......
...@@ -47,6 +47,11 @@ class RandomFlip3D(RandomFlip): ...@@ -47,6 +47,11 @@ class RandomFlip3D(RandomFlip):
def random_flip_data_3d(self, input_dict, direction='horizontal'): def random_flip_data_3d(self, input_dict, direction='horizontal'):
assert direction in ['horizontal', 'vertical'] assert direction in ['horizontal', 'vertical']
if len(input_dict['bbox3d_fields']) == 0: # test mode
input_dict['bbox3d_fields'].append('empty_box3d')
input_dict['empty_box3d'] = input_dict['box_type_3d'](
np.array([], dtype=np.float32))
assert len(input_dict['bbox3d_fields']) == 1
for key in input_dict['bbox3d_fields']: for key in input_dict['bbox3d_fields']:
input_dict['points'] = input_dict[key].flip( input_dict['points'] = input_dict[key].flip(
direction, points=input_dict['points']) direction, points=input_dict['points'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment