Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
2844a116
Commit
2844a116
authored
Mar 23, 2022
by
Yifei Yang
Committed by
zhouzaida
Jul 19, 2022
Browse files
[Fix] Fix MultiScaleFlipAug (#1801)
* Fix MultiScaleFlipAug * fix as comment
parent
169f098d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
29 deletions
+97
-29
mmcv/transforms/processing.py
mmcv/transforms/processing.py
+25
-15
tests/test_transforms/test_transforms_processing.py
tests/test_transforms/test_transforms_processing.py
+72
-14
No files found.
mmcv/transforms/processing.py
View file @
2844a116
...
@@ -762,6 +762,8 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -762,6 +762,8 @@ class MultiScaleFlipAug(BaseTransform):
transforms (list[dict]): Transforms to be applied to each resized
transforms (list[dict]): Transforms to be applied to each resized
and flipped data.
and flipped data.
img_scale (tuple | list[tuple] | None): Images scales for resizing.
img_scale (tuple | list[tuple] | None): Images scales for resizing.
scale_factor (float or tuple[float]): Scale factors for resizing.
Defaults to None.
flip (bool): Whether apply flip augmentation. Defaults to False.
flip (bool): Whether apply flip augmentation. Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions,
flip_direction (str | list[str]): Flip augmentation directions,
options are "horizontal", "vertical" and "diagonal". If
options are "horizontal", "vertical" and "diagonal". If
...
@@ -778,6 +780,7 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -778,6 +780,7 @@ class MultiScaleFlipAug(BaseTransform):
self
,
self
,
transforms
:
List
[
dict
],
transforms
:
List
[
dict
],
img_scale
:
Optional
[
Union
[
Tuple
,
List
[
Tuple
]]]
=
None
,
img_scale
:
Optional
[
Union
[
Tuple
,
List
[
Tuple
]]]
=
None
,
scale_factor
:
Optional
[
Union
[
float
,
List
[
float
]]]
=
None
,
flip
:
bool
=
False
,
flip
:
bool
=
False
,
flip_direction
:
Union
[
str
,
List
[
str
]]
=
'horizontal'
,
flip_direction
:
Union
[
str
,
List
[
str
]]
=
'horizontal'
,
resize_cfg
:
dict
=
dict
(
type
=
'Resize'
,
keep_ratio
=
True
),
resize_cfg
:
dict
=
dict
(
type
=
'Resize'
,
keep_ratio
=
True
),
...
@@ -785,11 +788,20 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -785,11 +788,20 @@ class MultiScaleFlipAug(BaseTransform):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
transforms
=
Compose
(
transforms
)
# type: ignore
self
.
transforms
=
Compose
(
transforms
)
# type: ignore
assert
img_scale
is
not
None
if
img_scale
is
not
None
:
self
.
img_scale
=
img_scale
if
isinstance
(
img_scale
,
self
.
img_scale
=
img_scale
if
isinstance
(
img_scale
,
list
)
else
[
img_scale
]
list
)
else
[
img_scale
]
self
.
scale_key
=
'scale'
self
.
scale_key
=
'scale'
assert
mmcv
.
is_list_of
(
self
.
img_scale
,
tuple
)
assert
mmcv
.
is_list_of
(
self
.
img_scale
,
tuple
)
else
:
# if ``img_scale`` and ``scale_factor`` both be ``None``
if
scale_factor
is
None
:
self
.
img_scale
=
[
1.
]
else
:
self
.
img_scale
=
scale_factor
if
isinstance
(
scale_factor
,
list
)
else
[
scale_factor
]
self
.
scale_key
=
'scale_factor'
self
.
flip
=
flip
self
.
flip
=
flip
self
.
flip_direction
=
flip_direction
if
isinstance
(
self
.
flip_direction
=
flip_direction
if
isinstance
(
...
@@ -801,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -801,7 +813,7 @@ class MultiScaleFlipAug(BaseTransform):
self
.
resize_cfg
=
resize_cfg
self
.
resize_cfg
=
resize_cfg
self
.
flip_cfg
=
flip_cfg
self
.
flip_cfg
=
flip_cfg
def
transform
(
self
,
results
:
dict
)
->
dict
:
def
transform
(
self
,
results
:
dict
)
->
Tuple
[
List
,
List
]
:
"""Apply test time augment transforms on results.
"""Apply test time augment transforms on results.
Args:
Args:
...
@@ -813,6 +825,7 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -813,6 +825,7 @@ class MultiScaleFlipAug(BaseTransform):
"""
"""
aug_data
=
[]
aug_data
=
[]
input_data
=
[]
flip_args
=
[(
False
,
''
)]
flip_args
=
[(
False
,
''
)]
if
self
.
flip
:
if
self
.
flip
:
flip_args
+=
[(
True
,
direction
)
flip_args
+=
[(
True
,
direction
)
...
@@ -820,7 +833,7 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -820,7 +833,7 @@ class MultiScaleFlipAug(BaseTransform):
for
scale
in
self
.
img_scale
:
for
scale
in
self
.
img_scale
:
for
flip
,
direction
in
flip_args
:
for
flip
,
direction
in
flip_args
:
_resize_cfg
=
self
.
resize_cfg
.
copy
()
_resize_cfg
=
self
.
resize_cfg
.
copy
()
_resize_cfg
.
update
(
scale
=
scale
)
_resize_cfg
.
update
(
{
self
.
scale_key
:
scale
}
)
_resize_flip
=
[
_resize_cfg
]
_resize_flip
=
[
_resize_cfg
]
if
flip
:
if
flip
:
...
@@ -834,14 +847,11 @@ class MultiScaleFlipAug(BaseTransform):
...
@@ -834,14 +847,11 @@ class MultiScaleFlipAug(BaseTransform):
resize_flip
=
Compose
(
_resize_flip
)
resize_flip
=
Compose
(
_resize_flip
)
_results
=
results
.
copy
()
_results
=
results
.
copy
()
_results
=
resize_flip
(
_results
)
_results
=
resize_flip
(
_results
)
data
=
self
.
transforms
(
_results
)
input_image
,
data_sample
=
self
.
transforms
(
_results
)
aug_data
.
append
(
data
)
# list of dict to dict of list
input_data
.
append
(
input_image
)
aug_data_dict
=
{
key
:
[]
for
key
in
aug_data
[
0
]}
aug_data
.
append
(
data_sample
)
for
data
in
aug_data
:
return
input_data
,
aug_data
for
key
,
val
in
data
.
items
():
aug_data_dict
[
key
].
append
(
val
)
return
aug_data_dict
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
...
...
tests/test_transforms/test_transforms_processing.py
View file @
2844a116
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
copy
import
os.path
as
osp
import
os.path
as
osp
from
unittest.mock
import
Mock
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -8,6 +9,7 @@ import pytest
...
@@ -8,6 +9,7 @@ import pytest
import
mmcv
import
mmcv
from
mmcv.transforms
import
(
TRANSFORMS
,
Normalize
,
Pad
,
RandomFlip
,
from
mmcv.transforms
import
(
TRANSFORMS
,
Normalize
,
Pad
,
RandomFlip
,
RandomResize
,
Resize
)
RandomResize
,
Resize
)
from
mmcv.transforms.base
import
BaseTransform
try
:
try
:
import
torch
import
torch
...
@@ -538,6 +540,17 @@ class TestRandomGrayscale:
...
@@ -538,6 +540,17 @@ class TestRandomGrayscale:
assert
img
.
shape
==
(
10
,
10
,
1
)
assert
img
.
shape
==
(
10
,
10
,
1
)
@
TRANSFORMS
.
register_module
()
class
MockFormatBundle
(
BaseTransform
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
def
transform
(
self
,
results
):
data_sample
=
Mock
()
return
results
[
'img'
],
data_sample
class
TestMultiScaleFlipAug
:
class
TestMultiScaleFlipAug
:
@
classmethod
@
classmethod
...
@@ -547,12 +560,6 @@ class TestMultiScaleFlipAug:
...
@@ -547,12 +560,6 @@ class TestMultiScaleFlipAug:
cls
.
original_img
=
copy
.
deepcopy
(
cls
.
img
)
cls
.
original_img
=
copy
.
deepcopy
(
cls
.
img
)
def
test_error
(
self
):
def
test_error
(
self
):
# test assertion if img_scale is None
with
pytest
.
raises
(
AssertionError
):
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
img_scale
=
None
,
transforms
=
[])
TRANSFORMS
.
build
(
transform
)
# test assertion if img_scale is not tuple or list of tuple
# test assertion if img_scale is not tuple or list of tuple
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
transform
=
dict
(
transform
=
dict
(
...
@@ -574,28 +581,30 @@ class TestMultiScaleFlipAug:
...
@@ -574,28 +581,30 @@ class TestMultiScaleFlipAug:
# test with empty transforms
# test with empty transforms
transform
=
dict
(
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
type
=
'MultiScaleFlipAug'
,
transforms
=
[],
transforms
=
[
dict
(
type
=
'MockFormatBundle'
)
],
img_scale
=
[(
1333
,
800
),
(
800
,
600
),
(
640
,
480
)],
img_scale
=
[(
1333
,
800
),
(
800
,
600
),
(
640
,
480
)],
flip
=
True
,
flip
=
True
,
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
results
=
dict
()
results
=
dict
()
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
=
multi_scale_flip_aug_module
(
results
)
input
,
data_sample
=
multi_scale_flip_aug_module
(
results
)
assert
len
(
results
[
'img'
])
==
12
assert
len
(
input
)
==
12
assert
len
(
data_sample
)
==
12
# test with flip=False
# test with flip=False
transform
=
dict
(
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
type
=
'MultiScaleFlipAug'
,
transforms
=
[],
transforms
=
[
dict
(
type
=
'MockFormatBundle'
)
],
img_scale
=
[(
1333
,
800
),
(
800
,
600
),
(
640
,
480
)],
img_scale
=
[(
1333
,
800
),
(
800
,
600
),
(
640
,
480
)],
flip
=
False
,
flip
=
False
,
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
results
=
dict
()
results
=
dict
()
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
=
multi_scale_flip_aug_module
(
results
)
input
,
data_sample
=
multi_scale_flip_aug_module
(
results
)
assert
len
(
results
[
'img'
])
==
3
assert
len
(
input
)
==
3
assert
len
(
data_sample
)
==
3
# test with transforms
# test with transforms
img_norm_cfg
=
dict
(
img_norm_cfg
=
dict
(
...
@@ -606,6 +615,7 @@ class TestMultiScaleFlipAug:
...
@@ -606,6 +615,7 @@ class TestMultiScaleFlipAug:
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
dict
(
type
=
'MockFormatBundle'
)
]
]
transform
=
dict
(
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
type
=
'MultiScaleFlipAug'
,
...
@@ -616,8 +626,56 @@ class TestMultiScaleFlipAug:
...
@@ -616,8 +626,56 @@ class TestMultiScaleFlipAug:
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
results
=
dict
()
results
=
dict
()
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
results
=
multi_scale_flip_aug_module
(
results
)
input
,
data_sample
=
multi_scale_flip_aug_module
(
results
)
assert
len
(
results
[
'img'
])
==
12
assert
len
(
input
)
==
12
assert
len
(
data_sample
)
==
12
# test with scale_factor
img_norm_cfg
=
dict
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
transforms_cfg
=
[
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
dict
(
type
=
'MockFormatBundle'
)
]
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
transforms
=
transforms_cfg
,
scale_factor
=
[
0.5
,
1.
,
2.
],
flip
=
True
,
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
results
=
dict
()
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
input
,
data_sample
=
multi_scale_flip_aug_module
(
results
)
assert
len
(
input
)
==
12
assert
len
(
data_sample
)
==
12
# test no resize
img_norm_cfg
=
dict
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
transforms_cfg
=
[
dict
(
type
=
'Normalize'
,
**
img_norm_cfg
),
dict
(
type
=
'Pad'
,
size_divisor
=
32
),
dict
(
type
=
'ImageToTensor'
,
keys
=
[
'img'
]),
dict
(
type
=
'MockFormatBundle'
)
]
transform
=
dict
(
type
=
'MultiScaleFlipAug'
,
transforms
=
transforms_cfg
,
flip
=
True
,
flip_direction
=
[
'horizontal'
,
'vertical'
,
'diagonal'
])
multi_scale_flip_aug_module
=
TRANSFORMS
.
build
(
transform
)
results
=
dict
()
results
[
'img'
]
=
copy
.
deepcopy
(
self
.
original_img
)
input
,
data_sample
=
multi_scale_flip_aug_module
(
results
)
assert
len
(
input
)
==
4
assert
len
(
data_sample
)
==
4
class
TestRandomMultiscaleResize
:
class
TestRandomMultiscaleResize
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment