Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
78ee07ea
Commit
78ee07ea
authored
Jul 03, 2020
by
wangtai
Committed by
zhangwenwei
Jul 03, 2020
Browse files
Fix RandomFlip3D in test time augmentation
parent
660f3ccc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
23 deletions
+64
-23
docs/tutorials/data_pipeline.md
docs/tutorials/data_pipeline.md
+18
-7
mmdet3d/datasets/pipelines/test_time_aug.py
mmdet3d/datasets/pipelines/test_time_aug.py
+41
-16
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+5
-0
No files found.
docs/tutorials/data_pipeline.md
View file @
78ee07ea
...
@@ -38,14 +38,24 @@ test_pipeline = [
...
@@ -38,14 +38,24 @@ test_pipeline = [
dict
(
dict
(
type
=
'MultiScaleFlipAug'
,
type
=
'MultiScaleFlipAug'
,
img_scale
=
(
1333
,
800
),
img_scale
=
(
1333
,
800
),
pts_scale_ratio
=
1.0
,
flip
=
False
,
flip
=
False
,
pcd_horizontal_flip
=
False
,
pcd_vertical_flip
=
False
,
transforms
=
[
transforms
=
[
dict
(
type
=
'Resize'
,
keep_ratio
=
True
),
dict
(
dict
(
type
=
'RandomFlip'
),
type
=
'GlobalRotScaleTrans'
,
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
rot_range
=
[
0
,
0
],
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
scale_ratio_range
=
[
1.
,
1.
],
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
translation_std
=
[
0
,
0
,
0
]),
dict
(
type
=
'Collect'
,
keys
=
[
'img'
]),
dict
(
type
=
'RandomFlip3D'
),
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
])
])
]
]
```
```
...
@@ -122,7 +132,8 @@ For each operation, we list the related dict fields that are added/updated/remov
...
@@ -122,7 +132,8 @@ For each operation, we list the related dict fields that are added/updated/remov
### Test time augmentation
### Test time augmentation
`MultiScaleFlipAug`
`MultiScaleFlipAug3D`
-
update: all the dict fields (update values to the collection of augmented data)
## Extend and use custom pipelines
## Extend and use custom pipelines
...
...
mmdet3d/datasets/pipelines/test_time_aug.py
View file @
78ee07ea
...
@@ -17,10 +17,17 @@ class MultiScaleFlipAug3D(object):
...
@@ -17,10 +17,17 @@ class MultiScaleFlipAug3D(object):
pts_scale_ratio (float | list[float]): Points scale ratios for
pts_scale_ratio (float | list[float]): Points scale ratios for
resizing.
resizing.
flip (bool): Whether apply flip augmentation. Default: False.
flip (bool): Whether apply flip augmentation. Default: False.
flip_direction (str | list[str]): Flip augmentation directions,
flip_direction (str | list[str]): Flip augmentation directions
options are "horizontal" and "vertical". If flip_direction is list,
for images, options are "horizontal" and "vertical".
multiple flip augmentations will be applied.
If flip_direction is list, multiple flip augmentations will
It has no effect when flip == False. Default: "horizontal".
be applied. It has no effect when flip == False.
Default: "horizontal".
pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation
to point cloud. Default: True. Note that it works only when
'flip' is turned on.
pcd_vertical_flip (bool): Whether apply vertical flip augmentation
to point cloud. Default: True. Note that it works only when
'flip' is turned on.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -28,7 +35,9 @@ class MultiScaleFlipAug3D(object):
...
@@ -28,7 +35,9 @@ class MultiScaleFlipAug3D(object):
img_scale
,
img_scale
,
pts_scale_ratio
,
pts_scale_ratio
,
flip
=
False
,
flip
=
False
,
flip_direction
=
'horizontal'
):
flip_direction
=
'horizontal'
,
pcd_horizontal_flip
=
True
,
pcd_vertical_flip
=
True
):
self
.
transforms
=
Compose
(
transforms
)
self
.
transforms
=
Compose
(
transforms
)
self
.
img_scale
=
img_scale
if
isinstance
(
img_scale
,
self
.
img_scale
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
list
)
else
[
img_scale
]
...
@@ -39,32 +48,48 @@ class MultiScaleFlipAug3D(object):
...
@@ -39,32 +48,48 @@ class MultiScaleFlipAug3D(object):
assert
mmcv
.
is_list_of
(
self
.
pts_scale_ratio
,
float
)
assert
mmcv
.
is_list_of
(
self
.
pts_scale_ratio
,
float
)
self
.
flip
=
flip
self
.
flip
=
flip
self
.
pcd_horizontal_flip
=
pcd_horizontal_flip
self
.
pcd_vertical_flip
=
pcd_vertical_flip
self
.
flip_direction
=
flip_direction
if
isinstance
(
self
.
flip_direction
=
flip_direction
if
isinstance
(
flip_direction
,
list
)
else
[
flip_direction
]
flip_direction
,
list
)
else
[
flip_direction
]
assert
mmcv
.
is_list_of
(
self
.
flip_direction
,
str
)
assert
mmcv
.
is_list_of
(
self
.
flip_direction
,
str
)
if
not
self
.
flip
and
self
.
flip_direction
!=
[
'horizontal'
]:
if
not
self
.
flip
and
self
.
flip_direction
!=
[
'horizontal'
]:
warnings
.
warn
(
warnings
.
warn
(
'flip_direction has no effect when flip is set to False'
)
'flip_direction has no effect when flip is set to False'
)
if
(
self
.
flip
if
(
self
.
flip
and
not
any
([(
t
[
'type'
]
==
'RandomFlip3D'
and
not
any
([
t
[
'type'
]
==
'RandomFlip'
for
t
in
transforms
])):
or
t
[
'type'
]
==
'RandomFlip'
)
for
t
in
transforms
])):
warnings
.
warn
(
warnings
.
warn
(
'flip has no effect when RandomFlip is not in transforms'
)
'flip has no effect when RandomFlip is not in transforms'
)
def
__call__
(
self
,
results
):
def
__call__
(
self
,
results
):
aug_data
=
[]
aug_data
=
[]
flip_aug
=
[
False
,
True
]
if
self
.
flip
else
[
False
]
flip_aug
=
[
False
,
True
]
if
self
.
flip
else
[
False
]
pcd_horizontal_flip_aug
=
[
False
,
True
]
\
if
self
.
flip
and
self
.
pcd_horizontal_flip
else
[
False
]
pcd_vertical_flip_aug
=
[
False
,
True
]
\
if
self
.
flip
and
self
.
pcd_vertical_flip
else
[
False
]
for
scale
in
self
.
img_scale
:
for
scale
in
self
.
img_scale
:
for
pts_scale_ratio
in
self
.
pts_scale_ratio
:
for
pts_scale_ratio
in
self
.
pts_scale_ratio
:
for
flip
in
flip_aug
:
for
flip
in
flip_aug
:
for
direction
in
self
.
flip_direction
:
for
pcd_horizontal_flip
in
pcd_horizontal_flip_aug
:
# results.copy will cause bug since it is shallow copy
for
pcd_vertical_flip
in
pcd_vertical_flip_aug
:
_results
=
deepcopy
(
results
)
for
direction
in
self
.
flip_direction
:
_results
[
'scale'
]
=
scale
# results.copy will cause bug
_results
[
'flip'
]
=
flip
# since it is shallow copy
_results
[
'pcd_scale_factor'
]
=
pts_scale_ratio
_results
=
deepcopy
(
results
)
_results
[
'flip_direction'
]
=
direction
_results
[
'scale'
]
=
scale
data
=
self
.
transforms
(
_results
)
_results
[
'flip'
]
=
flip
aug_data
.
append
(
data
)
_results
[
'pcd_scale_factor'
]
=
\
pts_scale_ratio
_results
[
'flip_direction'
]
=
direction
_results
[
'pcd_horizontal_flip'
]
=
\
pcd_horizontal_flip
_results
[
'pcd_vertical_flip'
]
=
\
pcd_vertical_flip
data
=
self
.
transforms
(
_results
)
aug_data
.
append
(
data
)
# list of dict to dict of list
# list of dict to dict of list
aug_data_dict
=
{
key
:
[]
for
key
in
aug_data
[
0
]}
aug_data_dict
=
{
key
:
[]
for
key
in
aug_data
[
0
]}
for
data
in
aug_data
:
for
data
in
aug_data
:
...
...
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
78ee07ea
...
@@ -47,6 +47,11 @@ class RandomFlip3D(RandomFlip):
...
@@ -47,6 +47,11 @@ class RandomFlip3D(RandomFlip):
def
random_flip_data_3d
(
self
,
input_dict
,
direction
=
'horizontal'
):
def
random_flip_data_3d
(
self
,
input_dict
,
direction
=
'horizontal'
):
assert
direction
in
[
'horizontal'
,
'vertical'
]
assert
direction
in
[
'horizontal'
,
'vertical'
]
if
len
(
input_dict
[
'bbox3d_fields'
])
==
0
:
# test mode
input_dict
[
'bbox3d_fields'
].
append
(
'empty_box3d'
)
input_dict
[
'empty_box3d'
]
=
input_dict
[
'box_type_3d'
](
np
.
array
([],
dtype
=
np
.
float32
))
assert
len
(
input_dict
[
'bbox3d_fields'
])
==
1
for
key
in
input_dict
[
'bbox3d_fields'
]:
for
key
in
input_dict
[
'bbox3d_fields'
]:
input_dict
[
'points'
]
=
input_dict
[
key
].
flip
(
input_dict
[
'points'
]
=
input_dict
[
key
].
flip
(
direction
,
points
=
input_dict
[
'points'
])
direction
,
points
=
input_dict
[
'points'
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment