OpenDAS / mmpretrain · Commits

Commit 1baf0566, authored Jun 24, 2025 by limm

    add tests part

Parent: 495d9ed9
Pipeline #2800: canceled with stages
Showing 20 changed files with 4861 additions and 0 deletions (+4861, -0)
tests/test_datasets/test_samplers/test_repeat_aug.py                               +98   -0
tests/test_datasets/test_transforms/test_auto_augment.py                         +1330   -0
tests/test_datasets/test_transforms/test_formatting.py                            +219   -0
tests/test_datasets/test_transforms/test_processing.py                            +959   -0
tests/test_datasets/test_transforms/test_wrappers.py                               +43   -0
tests/test_engine/test_hooks/test_arcface_hooks.py                                +102   -0
tests/test_engine/test_hooks/test_class_num_check_hook.py                          +52   -0
tests/test_engine/test_hooks/test_densecl_hook.py                                 +113   -0
tests/test_engine/test_hooks/test_ema_hook.py                                     +224   -0
tests/test_engine/test_hooks/test_precise_bn_hook.py                              +232   -0
tests/test_engine/test_hooks/test_retrievers_hooks.py                              +34   -0
tests/test_engine/test_hooks/test_simsiam_hook.py                                 +117   -0
tests/test_engine/test_hooks/test_swav_hook.py                                    +127   -0
tests/test_engine/test_hooks/test_switch_recipe_hook.py                           +371   -0
tests/test_engine/test_hooks/test_visualization_hook.py                           +148   -0
tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py   +107   -0
tests/test_evaluation/test_metrics/test_gqa.py                                     +30   -0
tests/test_evaluation/test_metrics/test_metric_utils.py                            +33   -0
tests/test_evaluation/test_metrics/test_multi_label.py                            +388   -0
tests/test_evaluation/test_metrics/test_multi_task_metrics.py                     +134   -0
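
Only the first four of these files are reproduced below. As a quick way to exercise the new modules, here is a minimal runner sketch using standard-library `unittest` discovery; the invocation is illustrative, assumes a repository checkout with mmpretrain's test dependencies installed, and is not itself part of the commit.

```python
# Illustrative only (not part of this commit): discover and run the new
# test modules from the repository root with the standard library.
import unittest

if __name__ == '__main__':
    suite = unittest.defaultTestLoader.discover('tests', pattern='test_*.py')
    unittest.TextTestRunner(verbosity=2).run(suite)
```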
tests/test_datasets/test_samplers/test_repeat_aug.py (new file, mode 100644)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import math
from unittest import TestCase
from unittest.mock import patch

import torch
from mmengine.logging import MMLogger

from mmpretrain.datasets import RepeatAugSampler

file = 'mmpretrain.datasets.samplers.repeat_aug.'


class MockDist:

    def __init__(self, dist_info=(0, 1), seed=7):
        self.dist_info = dist_info
        self.seed = seed

    def get_dist_info(self):
        return self.dist_info

    def sync_random_seed(self):
        return self.seed

    def is_main_process(self):
        return self.dist_info[0] == 0


class TestRepeatAugSampler(TestCase):

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch(file + 'get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)
        self.assertEqual(sampler.total_size, self.data_length * 3)
        self.assertEqual(sampler.num_samples, self.data_length * 3)
        self.assertEqual(sampler.num_selected_samples, self.data_length)
        self.assertEqual(len(sampler), sampler.num_selected_samples)
        indices = [x for x in range(self.data_length) for _ in range(3)]
        self.assertEqual(list(sampler), indices[:self.data_length])

        logger = MMLogger.get_current_instance()
        with self.assertLogs(logger, 'WARN') as log:
            sampler = RepeatAugSampler(self.dataset, shuffle=False)
        self.assertIn('always picks a fixed part', log.output[0])

    @patch(file + 'get_dist_info', return_value=(2, 3))
    @patch(file + 'is_main_process', return_value=False)
    def test_dist(self, mock1, mock2):
        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(sampler.total_size, self.data_length * 3)
        self.assertEqual(sampler.num_selected_samples,
                         math.ceil(self.data_length / 3))
        self.assertEqual(len(sampler), sampler.num_selected_samples)
        indices = [x for x in range(self.data_length) for _ in range(3)]
        self.assertEqual(
            list(sampler), indices[2::3][:sampler.num_selected_samples])

        logger = MMLogger.get_current_instance()
        with patch.object(logger, 'warning') as mock_log:
            sampler = RepeatAugSampler(self.dataset, shuffle=False)
            mock_log.assert_not_called()

    @patch(file + 'get_dist_info', return_value=(0, 1))
    @patch(file + 'sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        # test seed=None
        sampler = RepeatAugSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed
        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(10)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        indices = [x for x in indices
                   for _ in range(3)][:sampler.num_selected_samples]
        self.assertEqual(list(sampler), indices)

        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(42 + 10)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        indices = [x for x in indices
                   for _ in range(3)][:sampler.num_selected_samples]
        self.assertEqual(list(sampler), indices)
```
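
The assertions in this file all pin down one unshuffled indexing scheme: each dataset index is repeated `num_repeats` times, each rank then takes a strided slice, and `ceil(dataset_len / world_size)` entries are kept per rank. The sketch below is a conceptual illustration of that scheme inferred from the assertions above, not mmpretrain's actual `RepeatAugSampler` code.

```python
import math


def repeat_aug_indices(dataset_len, num_repeats, rank, world_size):
    """Conceptual sketch of the index pattern the tests above verify."""
    # Repeat every index num_repeats times: 0, 0, 0, 1, 1, 1, ...
    indices = [x for x in range(dataset_len) for _ in range(num_repeats)]
    # Each rank takes a strided slice of the repeated list.
    per_rank = indices[rank::world_size]
    # Keep ceil(dataset_len / world_size) samples per rank, so one epoch
    # still covers roughly dataset_len distinct samples across all ranks.
    return per_rank[:math.ceil(dataset_len / world_size)]


# Matches the assertions above:
assert repeat_aug_indices(100, 3, rank=0, world_size=1) == \
    [x for x in range(100) for _ in range(3)][:100]
assert repeat_aug_indices(100, 3, rank=2, world_size=3) == \
    [x for x in range(100) for _ in range(3)][2::3][:math.ceil(100 / 3)]
```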
tests/test_datasets/test_transforms/test_auto_augment.py (new file, mode 100644)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from unittest import TestCase
from unittest.mock import ANY, patch

import numpy as np

from mmpretrain.registry import TRANSFORMS


def construct_toy_data():
    img = np.random.randint(0, 256, (100, 200, 3), dtype=np.uint8)
    results = dict()
    # image
    results['ori_img'] = img
    results['img'] = img
    results['img2'] = copy.deepcopy(img)
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    results['img_fields'] = ['img', 'img2']
    return results


def construct_toy_data_photometric():
    img = np.array([[0, 128, 255], [1, 127, 254], [2, 129, 253]],
                   dtype=np.uint8)
    img = np.stack([img, img, img], axis=-1)
    results = dict()
    # image
    results['ori_img'] = img
    results['img'] = img
    results['img2'] = copy.deepcopy(img)
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    results['img_fields'] = ['img', 'img2']
    return results


class TestAutoAugment(TestCase):

    def test_construct(self):
        policies = [[
            dict(type='Posterize', bits=4, prob=0.4),
            dict(type='Rotate', angle=30., prob=0.6)
        ]]
        cfg = dict(type='AutoAugment', policies=policies)
        transform = TRANSFORMS.build(cfg)
        results = construct_toy_data()
        with patch.object(transform.transforms[0], 'transform') as mock:
            transform(results)
            mock.assert_called_once()

        cfg = dict(type='AutoAugment', policies='imagenet')
        transform = TRANSFORMS.build(cfg)
        with patch.object(transform.transforms[5], 'transform') as mock:
            with patch('numpy.random', np.random.RandomState(1)):
                transform(results)
                mock.assert_called()

        # test hparams
        cfg = dict(
            type='AutoAugment',
            policies=policies,
            hparams=dict(pad_val=[255, 255, 255]))
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(transform.policies[0][1]['pad_val'], [255, 255, 255])
        self.assertNotIn('pad_val', transform.policies[0][0])

        with self.assertRaisesRegex(AssertionError, 'choose from .*imagenet'):
            cfg = dict(type='AutoAugment', policies='unknown')
            transform = TRANSFORMS.build(cfg)

    def test_repr(self):
        policies = [[
            dict(type='Posterize', bits=4, prob=0.4),
            dict(type='Rotate', angle=30., prob=0.6)
        ]]
        cfg = dict(type='AutoAugment', policies=policies)
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Posterize,\tRotate', repr(transform))


class TestRandAugment(TestCase):
    DEFAULT_ARGS = dict(
        type='RandAugment',
        magnitude_level=7,
        num_policies=1,
        policies='timm_increasing')

    def test_construct(self):
        policies = [
            dict(type='Posterize', magnitude_range=(4, 0)),
            dict(type='Rotate', magnitude_range=(0, 30))
        ]
        cfg = {**self.DEFAULT_ARGS, 'policies': policies}
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(len(list(transform)), 2)
        results = construct_toy_data()
        with patch.object(transform.transforms[1], 'transform') as mock:
            with patch('numpy.random', np.random.RandomState(1)):
                transform(results)
                mock.assert_called_once()

        cfg = {**self.DEFAULT_ARGS, 'policies': 'timm_increasing'}
        transform = TRANSFORMS.build(cfg)
        with patch.object(transform.transforms[5], 'transform') as mock:
            with patch('numpy.random', np.random.RandomState(1)):
                transform(results)
                mock.assert_called()

        # test hparams
        cfg = {
            **self.DEFAULT_ARGS, 'policies': policies,
            'hparams': dict(pad_val=[255, 255, 255]),
        }
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(transform.policies[1]['pad_val'], [255, 255, 255])
        self.assertNotIn('pad_val', transform.policies[0])

        # test magnitude related parameters
        cfg = {
            **self.DEFAULT_ARGS, 'policies': [
                dict(type='Equalize'),
                dict(type='Rotate', magnitude_range=(0, 30))
            ]
        }
        transform = TRANSFORMS.build(cfg)
        self.assertNotIn('magnitude_range', transform.policies[0])
        self.assertNotIn('magnitude_level', transform.policies[0])
        self.assertNotIn('magnitude_range', transform.policies[0])
        self.assertNotIn('total_level', transform.policies[0])
        self.assertEqual(transform.policies[1]['magnitude_range'], (0, 30))
        self.assertEqual(transform.policies[1]['magnitude_level'], 7)
        self.assertEqual(transform.policies[1]['magnitude_std'], 0.)
        self.assertEqual(transform.policies[1]['total_level'], 10)

        # test invalid policies
        with self.assertRaisesRegex(AssertionError,
                                    'choose from .*timm_increasing'):
            cfg = {**self.DEFAULT_ARGS, 'policies': 'unknown'}
            transform = TRANSFORMS.build(cfg)

        # test invalid magnitude_std
        with self.assertRaisesRegex(AssertionError, 'got "unknown" instead'):
            cfg = {**self.DEFAULT_ARGS, 'magnitude_std': 'unknown'}
            transform = TRANSFORMS.build(cfg)

    def test_repr(self):
        policies = [
            dict(type='Posterize', magnitude_range=(4, 0)),
            dict(type='Equalize')
        ]
        cfg = {**self.DEFAULT_ARGS, 'policies': policies}
        transform = TRANSFORMS.build(cfg)
        self.assertIn(' Posterize (4, 0)\nEqualize\n', repr(transform))


class TestShear(TestCase):
    DEFAULT_ARGS = dict(type='Shear')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'got "unknown" instead'):
            cfg = {**self.DEFAULT_ARGS, 'magnitude': 1, 'direction': 'unknown'}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        # test params inputs
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 1.,
                'direction': 'horizontal',
                'pad_val': 255,
                'interpolation': 'nearest',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                0.2,
                direction='horizontal',
                border_value=255,
                interpolation='nearest')

        # test random_negative_prob
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY, -0.2, direction=ANY, border_value=ANY, interpolation=ANY)

        # test prob
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test sequence pad_val
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 1.,
                'direction': 'horizontal',
                'pad_val': (255, 255, 255),
                'interpolation': 'nearest',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                0.2,
                direction='horizontal',
                border_value=(255, 255, 255),
                interpolation='nearest')

        # test magnitude_range
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.3),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY, 0.18, direction=ANY, border_value=ANY, interpolation=ANY)

        # test magnitude_std is positive
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.3),
                'magnitude_std': 1
            }
            with patch('numpy.random', np.random.RandomState(1)):
                TRANSFORMS.build(cfg)(construct_toy_data())
            self.assertAlmostEqual(mock.call_args[0][1], 0.1811, places=4)

        # test magnitude_std = 'inf'
        with patch('mmcv.imshear') as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.3),
                'magnitude_std': 'inf'
            }
            with patch('numpy.random', np.random.RandomState(9)):
                TRANSFORMS.build(cfg)(construct_toy_data())
            self.assertAlmostEqual(mock.call_args[0][1], 0.0882, places=4)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Shear(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Shear(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestTranslate(TestCase):
    DEFAULT_ARGS = dict(type='Translate')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'got "unknown" instead'):
            cfg = {**self.DEFAULT_ARGS, 'magnitude': 1, 'direction': 'unknown'}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.imtranslate'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 1.,
                'direction': 'horizontal',
                'pad_val': 255,
                'interpolation': 'nearest',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                200 * 0.2,
                direction='horizontal',
                border_value=255,
                interpolation='nearest')

        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 1.,
                'direction': 'vertical',
                'pad_val': 255,
                'interpolation': 'nearest',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                100 * 0.2,
                direction='vertical',
                border_value=255,
                interpolation='nearest')

        # test sequence pad_val
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 1.,
                'direction': 'horizontal',
                'pad_val': [255, 255, 255],
                'interpolation': 'nearest',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                200 * 0.2,
                direction='horizontal',
                border_value=(255, 255, 255),
                interpolation='nearest')

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.2,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                -0.2 * 200,
                direction=ANY,
                border_value=ANY,
                interpolation=ANY)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.3),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                0.18 * 200,
                direction=ANY,
                border_value=ANY,
                interpolation=ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Translate(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Translate(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestRotate(TestCase):
    DEFAULT_ARGS = dict(type='Rotate')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {**self.DEFAULT_ARGS, 'angle': 30, 'magnitude_range': (1, 2)}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.imrotate'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'angle': 30,
                'center': (10, 10),
                'random_negative_prob': 0.,
                'prob': 1.,
                'scale': 1.5,
                'pad_val': 255,
                'interpolation': 'bilinear',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                30,
                center=(10, 10),
                scale=1.5,
                border_value=255,
                interpolation='bilinear')

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'angle': 30,
                'center': (10, 10),
                'random_negative_prob': 0.,
                'prob': 1.,
                'scale': 1.5,
                'pad_val': (255, 255, 255),
                'interpolation': 'bilinear',
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                30,
                center=(10, 10),
                scale=1.5,
                border_value=(255, 255, 255),
                interpolation='bilinear')

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'angle': 30,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'angle': 30,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                -30,
                center=ANY,
                scale=ANY,
                border_value=ANY,
                interpolation=ANY)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 30),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY,
                18,
                center=ANY,
                scale=ANY,
                border_value=ANY,
                interpolation=ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'angle': 30}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Rotate(angle=30', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 30)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Rotate(angle=None', repr(transform))
        self.assertIn('magnitude_range=(0, 30)', repr(transform))


class TestAutoContrast(TestCase):
    DEFAULT_ARGS = dict(type='AutoContrast')

    def test_transform(self):
        transform_func = 'mmcv.auto_contrast'

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

        # test magnitude_range
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 30),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'prob': 0.5}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('AutoContrast(prob=0.5)', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 30)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('AutoContrast(prob=', repr(transform))
        self.assertNotIn('magnitude_range=(0, 30)', repr(transform))


class TestInvert(TestCase):
    DEFAULT_ARGS = dict(type='Invert')

    def test_transform(self):
        transform_func = 'mmcv.iminvert'

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

        # test magnitude_range
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 30),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'prob': 0.5}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Invert(prob=0.5)', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 30)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Invert(prob=', repr(transform))
        self.assertNotIn('magnitude_range=(0, 30)', repr(transform))


class TestEqualize(TestCase):
    DEFAULT_ARGS = dict(type='Equalize')

    def test_transform(self):
        transform_func = 'mmcv.imequalize'

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

        # test magnitude_range
        # No effect
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 30),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'prob': 0.5}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Equalize(prob=0.5)', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 30)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Equalize(prob=', repr(transform))
        self.assertNotIn('magnitude_range=(0, 30)', repr(transform))


class TestSolarize(TestCase):
    DEFAULT_ARGS = dict(type='Solarize')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {**self.DEFAULT_ARGS, 'thr': 1, 'magnitude_range': (1, 2)}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.solarize'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {**self.DEFAULT_ARGS, 'thr': 128, 'prob': 1.}
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, thr=128)

        # test prob
        with patch(transform_func) as mock:
            cfg = {**self.DEFAULT_ARGS, 'thr': 128, 'prob': 0.}
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        # cannot accept `random_negative_prob` argument
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'thr': 128,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            with self.assertRaisesRegex(TypeError, 'multiple values'):
                TRANSFORMS.build(cfg)(construct_toy_data())

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (256, 0),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, thr=256 * 0.4)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'thr': 128}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Solarize(thr=128', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (256, 0)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Solarize(thr=None', repr(transform))
        self.assertIn('magnitude_range=(256, 0)', repr(transform))


class TestSolarizeAdd(TestCase):
    DEFAULT_ARGS = dict(type='SolarizeAdd')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'str'):
            cfg = {**self.DEFAULT_ARGS, 'magnitude': 1, 'thr': 'hi'}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        # test params inputs
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 100, 'thr': 128, 'prob': 1.}
        results = construct_toy_data_photometric()
        expected = np.where(results['img'] < 128,
                            np.minimum(results['img'] + 100, 255),
                            results['img'])
        TRANSFORMS.build(cfg)(results)
        np.testing.assert_allclose(results['img'], expected)

        # test prob
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 100, 'thr': 128, 'prob': 0.}
        results = construct_toy_data_photometric()
        expected = copy.deepcopy(results['img'])
        TRANSFORMS.build(cfg)(results)
        np.testing.assert_allclose(results['img'], expected)

        # test random_negative_prob
        # cannot accept `random_negative_prob` argument
        cfg = {
            **self.DEFAULT_ARGS,
            'magnitude': 100,
            'thr': 128,
            'random_negative_prob': 1.,
            'prob': 1.,
        }
        with self.assertRaisesRegex(TypeError, 'multiple values'):
            TRANSFORMS.build(cfg)(construct_toy_data())

        # test magnitude_range
        cfg = {
            **self.DEFAULT_ARGS,
            'prob': 1.,
            'magnitude_level': 6,
            'magnitude_range': (0, 110),
        }
        results = construct_toy_data_photometric()
        expected = np.where(results['img'] < 128,
                            np.minimum(results['img'] + 110 * 0.6, 255),
                            results['img'])
        TRANSFORMS.build(cfg)(results)
        np.testing.assert_allclose(results['img'], expected)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 100}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('SolarizeAdd(magnitude=100', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 110)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('SolarizeAdd(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 110)', repr(transform))


class TestPosterize(TestCase):
    DEFAULT_ARGS = dict(type='Posterize')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {**self.DEFAULT_ARGS, 'bits': 1, 'magnitude_range': (1, 2)}
            TRANSFORMS.build(cfg)
        with self.assertRaisesRegex(AssertionError, 'got 100 instead'):
            cfg = {**self.DEFAULT_ARGS, 'bits': 100}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.posterize'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {**self.DEFAULT_ARGS, 'bits': 4, 'prob': 1.}
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, bits=4)

        # test prob
        with patch(transform_func) as mock:
            cfg = {**self.DEFAULT_ARGS, 'bits': 4, 'prob': 0.}
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        # cannot accept `random_negative_prob` argument
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'bits': 4,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            with self.assertRaisesRegex(TypeError, 'multiple values'):
                TRANSFORMS.build(cfg)(construct_toy_data())

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (4, 0),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, bits=math.ceil(4 * 0.4))

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'bits': 4}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Posterize(bits=4', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (4, 0)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Posterize(bits=None', repr(transform))
        self.assertIn('magnitude_range=(4, 0)', repr(transform))


class TestContrast(TestCase):
    DEFAULT_ARGS = dict(type='Contrast')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.adjust_contrast'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.5)

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 - 0.5)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.5),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.6 * 0.5)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Contrast(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Contrast(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestColorTransform(TestCase):
    DEFAULT_ARGS = dict(type='ColorTransform')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.adjust_color'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, alpha=1 + 0.5)

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, alpha=1 - 0.5)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.5),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, alpha=1 + 0.6 * 0.5)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('ColorTransform(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('ColorTransform(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestBrightness(TestCase):
    DEFAULT_ARGS = dict(type='Brightness')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.adjust_brightness'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.5)

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 - 0.5)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.5),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.6 * 0.5)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Brightness(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Brightness(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestSharpness(TestCase):
    DEFAULT_ARGS = dict(type='Sharpness')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'magnitude': 1,
                'magnitude_range': (1, 2)
            }
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.adjust_sharpness'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.5)

        # test prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 0.,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test random_negative_prob
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude': 0.5,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 - 0.5)

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'random_negative_prob': 0.,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (0, 0.5),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, factor=1 + 0.6 * 0.5)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'magnitude': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Sharpness(magnitude=0.1', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 0.3)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Sharpness(magnitude=None', repr(transform))
        self.assertIn('magnitude_range=(0, 0.3)', repr(transform))


class TestCutout(TestCase):
    DEFAULT_ARGS = dict(type='Cutout')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {
                **self.DEFAULT_ARGS, 'shape': 10,
                'magnitude_range': (10, 20)
            }
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'mmcv.cutout'

        # test params inputs
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'shape': (10, 15),
                'prob': 1.,
                'pad_val': 255,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, (10, 15), pad_val=255)

        # test prob
        with patch(transform_func) as mock:
            cfg = {**self.DEFAULT_ARGS, 'shape': 10, 'prob': 0.}
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test sequence pad_val
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'shape': (10, 15),
                'prob': 1.,
                'pad_val': [255, 255, 255],
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(
                ANY, (10, 15), pad_val=(255, 255, 255))

        # test random_negative_prob
        # cannot accept `random_negative_prob` argument
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'shape': 10,
                'random_negative_prob': 1.,
                'prob': 1.,
            }
            with self.assertRaisesRegex(TypeError, 'multiple values'):
                TRANSFORMS.build(cfg)(construct_toy_data())

        # test magnitude_range
        with patch(transform_func) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'prob': 1.,
                'magnitude_level': 6,
                'magnitude_range': (1, 41),
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(ANY, 40 * 0.6 + 1, pad_val=ANY)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'shape': 15}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Cutout(shape=15', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0, 41)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('Cutout(shape=None', repr(transform))
        self.assertIn('magnitude_range=(0, 41)', repr(transform))


class TestGaussianBlur(TestCase):
    DEFAULT_ARGS = dict(type='GaussianBlur')

    def test_initialize(self):
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            TRANSFORMS.build(self.DEFAULT_ARGS)
        with self.assertRaisesRegex(AssertionError, 'only one of'):
            cfg = {**self.DEFAULT_ARGS, 'radius': 1, 'magnitude_range': (1, 2)}
            TRANSFORMS.build(cfg)

    def test_transform(self):
        transform_func = 'PIL.ImageFilter.GaussianBlur'
        from PIL.ImageFilter import GaussianBlur

        # test params inputs
        with patch(transform_func, wraps=GaussianBlur) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'radius': 0.5,
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_called_once_with(radius=0.5)

        # test prob
        with patch(transform_func, wraps=GaussianBlur) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'radius': 0.5,
                'prob': 0.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            mock.assert_not_called()

        # test magnitude_range
        with patch(transform_func, wraps=GaussianBlur) as mock:
            cfg = {
                **self.DEFAULT_ARGS,
                'magnitude_range': (0.1, 2),
                'magnitude_std': 'inf',
                'prob': 1.,
            }
            TRANSFORMS.build(cfg)(construct_toy_data())
            self.assertTrue(0.1 < mock.call_args[1]['radius'] < 2)

    def test_repr(self):
        cfg = {**self.DEFAULT_ARGS, 'radius': 0.1}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('GaussianBlur(radius=0.1, prob=0.5', repr(transform))
        self.assertNotIn('magnitude_range', repr(transform))

        cfg = {**self.DEFAULT_ARGS, 'magnitude_range': (0.1, 2)}
        transform = TRANSFORMS.build(cfg)
        self.assertIn('GaussianBlur(radius=None, prob=0.5', repr(transform))
        self.assertIn('magnitude_range=(0.1, 2)', repr(transform))
```
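
The expected values in the `magnitude_range` tests above all come from one linear mapping: `magnitude_level` (out of `total_level`, 10 by default in these tests) is scaled into `magnitude_range`, which may also be reversed, as in Posterize's `(4, 0)` and Solarize's `(256, 0)`. Below is a minimal sketch of that mapping, inferred from the assertions rather than copied from the library:

```python
import math


def to_magnitude(magnitude_level, magnitude_range, total_level=10):
    """Map a level in [0, total_level] linearly onto magnitude_range."""
    low, high = magnitude_range
    return magnitude_level / total_level * (high - low) + low


# Values the mocked calls above assert (within float tolerance):
assert math.isclose(to_magnitude(6, (0, 0.3)), 0.18)          # Shear, Translate
assert math.isclose(to_magnitude(6, (0, 30)), 18)             # Rotate
assert math.isclose(to_magnitude(6, (256, 0)), 256 * 0.4)     # Solarize (reversed)
assert math.isclose(to_magnitude(6, (1, 41)), 40 * 0.6 + 1)   # Cutout
```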
tests/test_datasets/test_transforms/test_formatting.py (new file, mode 100644)
```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest

import mmcv
import numpy as np
import torch
from PIL import Image

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample, MultiTaskDataSample


class TestPackInputs(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'sample_idx': 1,
            'img_path': img_path,
            'ori_shape': (300, 400),
            'img_shape': (300, 400),
            'scale_factor': 1.0,
            'flip': False,
            'img': mmcv.imread(img_path),
            'gt_label': 2,
            'custom_key': torch.tensor([1, 2, 3])
        }

        cfg = dict(type='PackInputs', algorithm_keys=['custom_key'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], DataSample)
        self.assertIn('flip', results['data_samples'].metainfo_keys())
        self.assertIsInstance(results['data_samples'].gt_label, torch.Tensor)
        self.assertIsInstance(results['data_samples'].custom_key, torch.Tensor)

        # Test grayscale image
        data['img'] = data['img'].mean(-1)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (1, 300, 400))

        # Test video input
        data['img'] = np.random.randint(
            0, 256, (10, 3, 1, 224, 224), dtype=np.uint8)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (10, 3, 1, 224, 224))

        # Test Pillow input
        data['img'] = Image.open(img_path)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (3, 300, 400))

        # Test without `img` and `gt_label`
        del data['img']
        del data['gt_label']
        results = transform(copy.deepcopy(data))
        self.assertNotIn('gt_label', results['data_samples'])

    def test_repr(self):
        cfg = dict(type='PackInputs', meta_keys=['flip', 'img_shape'])
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform),
            "PackInputs(input_key='img', algorithm_keys=(), "
            "meta_keys=['flip', 'img_shape'])")


class TestTranspose(unittest.TestCase):

    def test_transform(self):
        cfg = dict(type='Transpose', keys=['img'], order=[2, 0, 1])
        transform = TRANSFORMS.build(cfg)

        data = {
            'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')
        }

        results = transform(copy.deepcopy(data))
        self.assertEqual(results['img'].shape, (3, 224, 224))

    def test_repr(self):
        cfg = dict(type='Transpose', keys=['img'], order=(2, 0, 1))
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), "Transpose(keys=['img'], order=(2, 0, 1))")


class TestToPIL(unittest.TestCase):

    def test_transform(self):
        cfg = dict(type='ToPIL')
        transform = TRANSFORMS.build(cfg)

        data = {
            'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')
        }
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)

        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)

        data = {
            'img': np.random.randint(0, 256, (224, 224, 3), dtype='uint8')
        }
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], Image.Image)
        np.equal(np.array(results['img']), data['img'][:, :, ::-1])

    def test_repr(self):
        cfg = dict(type='ToPIL', to_rgb=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), 'NumpyToPIL(to_rgb=True)')


class TestToNumpy(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'img': Image.open(img_path),
        }

        cfg = dict(type='ToNumpy')
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')

        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIsInstance(results['img'], np.ndarray)
        self.assertEqual(results['img'].dtype, 'uint8')
        np.equal(results['img'], np.array(data['img'])[:, :, ::-1])

    def test_repr(self):
        cfg = dict(type='ToNumpy', to_bgr=True)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'PILToNumpy(to_bgr=True, dtype=None)')


class TestCollect(unittest.TestCase):

    def test_transform(self):
        data = {'img': [1, 2, 3], 'gt_label': 1}

        cfg = dict(type='Collect', keys=['img'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('img', results)
        self.assertNotIn('gt_label', results)

    def test_repr(self):
        cfg = dict(type='Collect', keys=['img'])
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(repr(transform), "Collect(keys=['img'])")


class TestPackMultiTaskInputs(unittest.TestCase):

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {
            'sample_idx': 1,
            'img_path': img_path,
            'ori_shape': (300, 400),
            'img_shape': (300, 400),
            'scale_factor': 1.0,
            'flip': False,
            'img': mmcv.imread(img_path),
            'gt_label': {
                'task1': 1,
                'task3': 3
            },
        }

        cfg = dict(type='PackMultiTaskInputs', multi_task_fields=['gt_label'])
        transform = TRANSFORMS.build(cfg)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], MultiTaskDataSample)
        self.assertIn('flip', results['data_samples'].task1.metainfo_keys())
        self.assertIsInstance(results['data_samples'].task1.gt_label,
                              torch.Tensor)

        # Test grayscale image
        data['img'] = data['img'].mean(-1)
        results = transform(copy.deepcopy(data))
        self.assertIn('inputs', results)
        self.assertIsInstance(results['inputs'], torch.Tensor)
        self.assertEqual(results['inputs'].shape, (1, 300, 400))

        # Test without `img` and `gt_label`
        del data['img']
        del data['gt_label']
        results = transform(copy.deepcopy(data))
        self.assertNotIn('gt_label', results['data_samples'])

    def test_repr(self):
        cfg = dict(
            type='PackMultiTaskInputs',
            multi_task_fields=['gt_label'],
            task_handlers=dict(task1=dict(type='PackInputs')),
        )
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform),
            "PackMultiTaskInputs(multi_task_fields=['gt_label'], "
            "input_key='img', task_handlers={'task1': PackInputs})")
```
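
These packing transforms sit at the tail end of a data pipeline: upstream transforms fill a `results` dict with an `img` array plus metadata, and `PackInputs` turns it into the `inputs` tensor and `DataSample` the model consumes. The following usage sketch is built only from configs already exercised above; it assumes an mmpretrain installation, just as the tests do.

```python
import numpy as np
from mmcv.transforms import Compose

from mmpretrain.registry import TRANSFORMS

# Pack a raw HWC uint8 image and an integer label, keeping one meta key.
pipeline = Compose(
    [TRANSFORMS.build(dict(type='PackInputs', meta_keys=['img_shape']))])

data = {
    'img': np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8),
    'img_shape': (224, 224),
    'gt_label': 1,
}
packed = pipeline(data)
print(packed['inputs'].shape)           # CHW tensor, torch.Size([3, 224, 224])
print(packed['data_samples'].gt_label)  # the label converted to a tensor
```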
tests/test_datasets/test_transforms/test_processing.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
math
import
os.path
as
osp
import
random
from
unittest
import
TestCase
from
unittest.mock
import
ANY
,
call
,
patch
import
mmengine
import
numpy
as
np
import
pytest
import
torch
import
torchvision
from
mmcv.transforms
import
Compose
from
mmengine.utils
import
digit_version
from
PIL
import
Image
from
torchvision
import
transforms
from
mmpretrain.datasets.transforms.processing
import
VISION_TRANSFORMS
from
mmpretrain.registry
import
TRANSFORMS
try
:
import
albumentations
except
ImportError
:
albumentations
=
None
def
construct_toy_data
():
img
=
np
.
array
([[
1
,
2
,
3
,
4
],
[
5
,
6
,
7
,
8
],
[
9
,
10
,
11
,
12
]],
dtype
=
np
.
uint8
)
img
=
np
.
stack
([
img
,
img
,
img
],
axis
=-
1
)
results
=
dict
()
# image
results
[
'ori_img'
]
=
img
results
[
'img'
]
=
copy
.
deepcopy
(
img
)
results
[
'ori_shape'
]
=
img
.
shape
results
[
'img_shape'
]
=
img
.
shape
return
results
class
TestRandomCrop
(
TestCase
):
def
test_assertion
(
self
):
with
self
.
assertRaises
(
AssertionError
):
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=-
1
)
TRANSFORMS
.
build
(
cfg
)
with
self
.
assertRaises
(
AssertionError
):
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
(
1
,
2
,
3
))
TRANSFORMS
.
build
(
cfg
)
with
self
.
assertRaises
(
AssertionError
):
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
(
1
,
-
2
))
TRANSFORMS
.
build
(
cfg
)
with
self
.
assertRaises
(
AssertionError
):
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
224
,
padding_mode
=
'co'
)
TRANSFORMS
.
build
(
cfg
)
def
test_transform
(
self
):
results
=
dict
(
img
=
np
.
random
.
randint
(
0
,
256
,
(
256
,
256
,
3
),
np
.
uint8
))
# test random crop by default.
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
224
)
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
224
,
224
,
3
))
# test int padding and int pad_val.
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
(
224
,
224
),
padding
=
2
,
pad_val
=
1
)
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
224
,
224
,
3
))
# test int padding and sequence pad_val.
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
224
,
padding
=
2
,
pad_val
=
(
0
,
50
,
0
))
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
224
,
224
,
3
))
# test sequence padding.
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
224
,
padding
=
(
2
,
3
,
4
,
5
))
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
224
,
224
,
3
))
# test pad_if_needed.
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
300
,
pad_if_needed
=
True
,
padding_mode
=
'edge'
)
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
300
,
300
,
3
))
# test large crop size.
results
=
dict
(
img
=
np
.
random
.
randint
(
0
,
256
,
(
256
,
256
,
3
),
np
.
uint8
))
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
300
)
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
256
,
256
,
3
))
# test equal size.
results
=
dict
(
img
=
np
.
random
.
randint
(
0
,
256
,
(
256
,
256
,
3
),
np
.
uint8
))
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
256
)
transform
=
TRANSFORMS
.
build
(
cfg
)
results
=
transform
(
results
)
self
.
assertTupleEqual
(
results
[
'img'
].
shape
,
(
256
,
256
,
3
))
def
test_repr
(
self
):
cfg
=
dict
(
type
=
'RandomCrop'
,
crop_size
=
224
)
transform
=
TRANSFORMS
.
build
(
cfg
)
self
.
assertEqual
(
repr
(
transform
),
'RandomCrop(crop_size=(224, 224), padding=None, '
'pad_if_needed=False, pad_val=0, padding_mode=constant)'
)
class TestRandomResizedCrop(TestCase):

    def test_assertion(self):
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomResizedCrop', scale=-1)
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomResizedCrop', scale=(1, 2, 3))
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomResizedCrop', scale=(1, -2))
            TRANSFORMS.build(cfg)

        with self.assertRaises(ValueError):
            cfg = dict(
                type='RandomResizedCrop',
                scale=224,
                crop_ratio_range=(1, 0.1))
            TRANSFORMS.build(cfg)

        with self.assertRaises(ValueError):
            cfg = dict(
                type='RandomResizedCrop',
                scale=224,
                aspect_ratio_range=(1, 0.1))
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=-1)
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomResizedCrop', scale=224, interpolation='ne')
            TRANSFORMS.build(cfg)

    def test_transform(self):
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))

        # test random crop by default.
        cfg = dict(type='RandomResizedCrop', scale=224)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test crop_ratio_range.
        cfg = dict(
            type='RandomResizedCrop',
            scale=(224, 224),
            crop_ratio_range=(0.5, 0.8))
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test aspect_ratio_range.
        cfg = dict(
            type='RandomResizedCrop',
            scale=224,
            aspect_ratio_range=(0.5, 0.8))
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test max_attempts.
        cfg = dict(type='RandomResizedCrop', scale=224, max_attempts=0)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test fall back with extreme low in_ratio
        results = dict(img=np.random.randint(0, 256, (10, 256, 3), np.uint8))
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))
        # test fall back with extreme high in_ratio
        results = dict(img=np.random.randint(0, 256, (256, 10, 3), np.uint8))
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test large crop size.
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
        cfg = dict(type='RandomResizedCrop', scale=300)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (300, 300, 3))

    def test_repr(self):
        cfg = dict(type='RandomResizedCrop', scale=224)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'RandomResizedCrop(scale=(224, 224), '
            'crop_ratio_range=(0.08, 1.0), aspect_ratio_range=(0.75, 1.3333), '
            'max_attempts=10, interpolation=bilinear, backend=cv2)')
class TestEfficientNetRandomCrop(TestCase):

    def test_assertion(self):
        with self.assertRaises(AssertionError):
            cfg = dict(type='EfficientNetRandomCrop', scale=(1, 1))
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(
                type='EfficientNetRandomCrop', scale=224, min_covered=-1)
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(
                type='EfficientNetRandomCrop', scale=224, crop_padding=-1)
            TRANSFORMS.build(cfg)

    def test_transform(self):
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))

        # test random crop by default.
        cfg = dict(type='EfficientNetRandomCrop', scale=224)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test crop_ratio_range.
        cfg = dict(
            type='EfficientNetRandomCrop',
            scale=224,
            crop_ratio_range=(0.5, 0.8))
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test aspect_ratio_range.
        cfg = dict(
            type='EfficientNetRandomCrop',
            scale=224,
            aspect_ratio_range=(0.5, 0.8))
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test max_attempts.
        cfg = dict(type='EfficientNetRandomCrop', scale=224, max_attempts=0)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test min_covered.
        cfg = dict(type='EfficientNetRandomCrop', scale=224, min_covered=.9)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test crop_padding.
        cfg = dict(
            type='EfficientNetRandomCrop',
            scale=224,
            min_covered=0.9,
            crop_padding=10)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test large crop size.
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
        cfg = dict(type='EfficientNetRandomCrop', scale=300)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (300, 300, 3))

    def test_repr(self):
        cfg = dict(type='EfficientNetRandomCrop', scale=224)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'EfficientNetRandomCrop(scale=(224, 224), '
            'crop_ratio_range=(0.08, 1.0), aspect_ratio_range=(0.75, 1.3333), '
            'max_attempts=10, interpolation=bicubic, backend=cv2, '
            'min_covered=0.1, crop_padding=32)')
class TestResizeEdge(TestCase):

    def test_transform(self):
        results = dict(img=np.random.randint(0, 256, (128, 256, 3), np.uint8))

        # test resize short edge by default.
        cfg = dict(type='ResizeEdge', scale=224)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 448, 3))
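        # the short edge (128) of the 128x256 input is scaled to 224 and the
        # 1:2 aspect ratio is kept, hence 224x448.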
        # test resize long edge.
        cfg = dict(type='ResizeEdge', scale=224, edge='long')
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (112, 224, 3))

        # test resize width.
        cfg = dict(type='ResizeEdge', scale=224, edge='width')
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (112, 224, 3))

        # test resize height.
        cfg = dict(type='ResizeEdge', scale=224, edge='height')
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 448, 3))

        # test invalid edge
        with self.assertRaisesRegex(AssertionError, 'Invalid edge "hi"'):
            cfg = dict(type='ResizeEdge', scale=224, edge='hi')
            TRANSFORMS.build(cfg)

    def test_repr(self):
        cfg = dict(type='ResizeEdge', scale=224, edge='height')
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, '
            'interpolation=bilinear)')
class TestEfficientNetCenterCrop(TestCase):

    def test_assertion(self):
        with self.assertRaises(AssertionError):
            cfg = dict(type='EfficientNetCenterCrop', crop_size=(1, 1))
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(type='EfficientNetCenterCrop', crop_size=-1)
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = dict(
                type='EfficientNetCenterCrop', crop_size=224, crop_padding=-1)
            TRANSFORMS.build(cfg)

    def test_transform(self):
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
        # test center crop by default.
        cfg = dict(type='EfficientNetCenterCrop', crop_size=224)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test crop_padding.
        cfg = dict(
            type='EfficientNetCenterCrop', crop_size=224, crop_padding=10)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (224, 224, 3))

        # test large crop size.
        results = dict(img=np.random.randint(0, 256, (256, 256, 3), np.uint8))
        cfg = dict(type='EfficientNetCenterCrop', crop_size=300)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertTupleEqual(results['img'].shape, (300, 300, 3))

    def test_repr(self):
        cfg = dict(type='EfficientNetCenterCrop', crop_size=224)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'EfficientNetCenterCrop(crop_size=224, '
            'crop_padding=32, interpolation=bicubic, backend=cv2)')
class TestRandomErasing(TestCase):

    def test_initialize(self):
        # test erase_prob assertion
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', erase_prob=-1.)
            TRANSFORMS.build(cfg)
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', erase_prob=1)
            TRANSFORMS.build(cfg)

        # test area_ratio assertion
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', min_area_ratio=-1.)
            TRANSFORMS.build(cfg)
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', max_area_ratio=1)
            TRANSFORMS.build(cfg)
        with self.assertRaises(AssertionError):
            # min_area_ratio should be smaller than max_area_ratio
            cfg = dict(
                type='RandomErasing', min_area_ratio=0.6, max_area_ratio=0.4)
            TRANSFORMS.build(cfg)

        # test aspect_range assertion
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', aspect_range='str')
            TRANSFORMS.build(cfg)
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', aspect_range=-1)
            TRANSFORMS.build(cfg)
        with self.assertRaises(AssertionError):
            # In aspect_range (min, max), min should be smaller than max.
            cfg = dict(type='RandomErasing', aspect_range=[1.6, 0.6])
            TRANSFORMS.build(cfg)

        # test mode assertion
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', mode='unknown')
            TRANSFORMS.build(cfg)

        # test fill_std assertion
        with self.assertRaises(AssertionError):
            cfg = dict(type='RandomErasing', fill_std='unknown')
            TRANSFORMS.build(cfg)

        # test implicit conversion of aspect_range
        cfg = dict(type='RandomErasing', aspect_range=0.5)
        random_erasing = TRANSFORMS.build(cfg)
        assert random_erasing.aspect_range == (0.5, 2.)
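        # a scalar aspect_range v appears to expand to (min(v, 1/v),
        # max(v, 1/v)), which is why 0.5 and 2. both yield (0.5, 2.) here.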
        cfg = dict(type='RandomErasing', aspect_range=2.)
        random_erasing = TRANSFORMS.build(cfg)
        assert random_erasing.aspect_range == (0.5, 2.)

        # test implicit conversion of fill_color
        cfg = dict(type='RandomErasing', fill_color=15)
        random_erasing = TRANSFORMS.build(cfg)
        assert random_erasing.fill_color == [15, 15, 15]

        # test implicit conversion of fill_std
        cfg = dict(type='RandomErasing', fill_std=0.5)
        random_erasing = TRANSFORMS.build(cfg)
        assert random_erasing.fill_std == [0.5, 0.5, 0.5]

    def test_transform(self):
        # test when erase_prob=0.
        results = construct_toy_data()
        cfg = dict(
            type='RandomErasing',
            erase_prob=0.,
            mode='const',
            fill_color=(255, 255, 255))
        random_erasing = TRANSFORMS.build(cfg)
        results = random_erasing(results)
        np.testing.assert_array_equal(results['img'], results['ori_img'])

        # test mode 'const'
        results = construct_toy_data()
        cfg = dict(
            type='RandomErasing',
            erase_prob=1.,
            mode='const',
            fill_color=(255, 255, 255))
        with patch('numpy.random', np.random.RandomState(0)):
            random_erasing = TRANSFORMS.build(cfg)
            results = random_erasing(results)
            expect_out = np.array(
                [[1, 255, 3, 4], [5, 255, 7, 8], [9, 10, 11, 12]],
                dtype=np.uint8)
            expect_out = np.stack([expect_out] * 3, axis=-1)
            np.testing.assert_array_equal(results['img'], expect_out)

        # test mode 'rand' with normal distribution
        results = construct_toy_data()
        cfg = dict(type='RandomErasing', erase_prob=1., mode='rand')
        with patch('numpy.random', np.random.RandomState(0)):
            random_erasing = TRANSFORMS.build(cfg)
            results = random_erasing(results)
            expect_out = results['ori_img']
            expect_out[:2, 1] = [[159, 98, 76], [14, 69, 122]]
            np.testing.assert_array_equal(results['img'], expect_out)

        # test mode 'rand' with uniform distribution
        results = construct_toy_data()
        cfg = dict(
            type='RandomErasing',
            erase_prob=1.,
            mode='rand',
            fill_std=(10, 255, 0))
        with patch('numpy.random', np.random.RandomState(0)):
            random_erasing = TRANSFORMS.build(cfg)
            results = random_erasing(results)
            expect_out = results['ori_img']
            expect_out[:2, 1] = [[113, 255, 128], [126, 83, 128]]
            np.testing.assert_array_equal(results['img'], expect_out)

    def test_repr(self):
        cfg = dict(
            type='RandomErasing',
            erase_prob=0.5,
            mode='const',
            aspect_range=(0.3, 1.3),
            fill_color=(255, 255, 255))
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform),
            'RandomErasing(erase_prob=0.5, min_area_ratio=0.02, '
            'max_area_ratio=0.4, aspect_range=(0.3, 1.3), mode=const, '
            'fill_color=(255, 255, 255), fill_std=None)')
class TestColorJitter(TestCase):
    DEFAULT_ARGS = dict(
        type='ColorJitter',
        brightness=0.5,
        contrast=0.5,
        saturation=0.5,
        hue=0.2)

    def test_initialize(self):
        cfg = dict(
            type='ColorJitter',
            brightness=(0.8, 1.2),
            contrast=[0.5, 1.5],
            saturation=0.,
            hue=0.2)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(transform.brightness, (0.8, 1.2))
        self.assertEqual(transform.contrast, (0.5, 1.5))
        self.assertIsNone(transform.saturation)
        self.assertEqual(transform.hue, (-0.2, 0.2))

        with self.assertRaisesRegex(ValueError, 'If hue is a single number'):
            cfg = {**self.DEFAULT_ARGS, 'hue': -0.2}
            TRANSFORMS.build(cfg)

        with self.assertRaisesRegex(TypeError, 'hue should be a single'):
            cfg = {**self.DEFAULT_ARGS, 'hue': [0.5, 0.4, 0.2]}
            TRANSFORMS.build(cfg)

        logger = mmengine.MMLogger.get_current_instance()
        with self.assertLogs(logger, 'WARN') as log:
            cfg = {**self.DEFAULT_ARGS, 'hue': [-1, 0.4]}
            transform = TRANSFORMS.build(cfg)
        self.assertIn('ColorJitter hue values', log.output[0])
        self.assertEqual(transform.hue, (-0.5, 0.4))
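        # out-of-range hue endpoints appear to be clamped to the valid
        # [-0.5, 0.5] interval (with a warning): (-1, 0.4) becomes (-0.5, 0.4).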
    def test_transform(self):
        ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
        results = dict(img=copy.deepcopy(ori_img))

        # test transform
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertEqual(results['img'].dtype, ori_img.dtype)
        assert not np.equal(results['img'], ori_img).all()

        # test call with brightness, contrast and saturation all set to 0
        results = dict(img=copy.deepcopy(ori_img))
        cfg = dict(
            type='ColorJitter', brightness=0., contrast=0., saturation=0.)
        transform = TRANSFORMS.build(cfg)
        results = transform(results)
        self.assertEqual(results['img'].dtype, ori_img.dtype)
        assert np.equal(results['img'], ori_img).all()

        # test the randomized call order with a fixed seed
        cfg = {**self.DEFAULT_ARGS, 'contrast': 0.}
        transform = TRANSFORMS.build(cfg)
        with patch('numpy.random', np.random.RandomState(0)):
            mmcv_module = 'mmpretrain.datasets.transforms.processing.mmcv'
            call_list = [
                call.adjust_color(ANY, alpha=ANY, backend='pillow'),
                call.adjust_hue(ANY, ANY, backend='pillow'),
                call.adjust_brightness(ANY, ANY, backend='pillow'),
            ]
            with patch(mmcv_module, autospec=True) as mock:
                transform(results)
                self.assertEqual(mock.mock_calls, call_list)

    def test_repr(self):
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'ColorJitter(brightness=(0.5, 1.5), '
            'contrast=(0.5, 1.5), saturation=(0.5, 1.5), hue=(-0.2, 0.2))')
class TestLighting(TestCase):

    def setUp(self):
        EIGVAL = [0.2175, 0.0188, 0.0045]
        EIGVEC = [
            [-0.5836, -0.6948, 0.4203],
            [-0.5808, -0.0045, -0.814],
            [-0.5675, 0.7192, 0.4009],
        ]
        self.DEFAULT_ARGS = dict(
            type='Lighting',
            eigval=EIGVAL,
            eigvec=EIGVEC,
            alphastd=25.5,
            to_rgb=False)

    def test_assertion(self):
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['eigval'] = -1
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['eigvec'] = None
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['alphastd'] = 'Lighting'
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['eigvec'] = dict()
            TRANSFORMS.build(cfg)

        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['eigvec'] = [
                [-0.5836, -0.6948, 0.4203],
                [-0.5808, -0.0045, -0.814],
                [-0.5675, 0.7192, 0.4009, 0.10],
            ]
            TRANSFORMS.build(cfg)

    def test_transform(self):
        ori_img = np.ones((256, 256, 3), np.uint8) * 127
        results = dict(img=copy.deepcopy(ori_img))

        # Test transform with non-img-keyword result
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            lightening_module = TRANSFORMS.build(cfg)
            empty_results = dict()
            lightening_module(empty_results)

        # test call
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        lightening_module = TRANSFORMS.build(cfg)
        with patch('numpy.random', np.random.RandomState(0)):
            results = lightening_module(results)
        self.assertEqual(results['img'].dtype, ori_img.dtype)
        assert not np.equal(results['img'], ori_img).all()

        # test call with alphastd == 0
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['alphastd'] = 0.0
        lightening_module = TRANSFORMS.build(cfg)
        results = lightening_module(results)
        self.assertEqual(results['img'].dtype, ori_img.dtype)
        assert np.equal(results['img'], ori_img).all()

    def test_repr(self):
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform),
            'Lighting(eigval=[0.2175, 0.0188, 0.0045], eigvec'
            '=[[-0.5836, -0.6948, 0.4203], [-0.5808, -0.0045, -0.814], ['
            '-0.5675, 0.7192, 0.4009]], alphastd=25.5, to_rgb=False)')
class TestAlbumentations(TestCase):
    DEFAULT_ARGS = dict(
        type='Albumentations', transforms=[dict(type='ChannelShuffle', p=1)])

    @pytest.mark.skipif(
        albumentations is None, reason='No Albumentations module.')
    def test_assertion(self):
        # Test with non-list transforms
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['transforms'] = 1
            TRANSFORMS.build(cfg)

        # Test with dict transforms item without keyword 'type'.
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['transforms'] = [dict(p=1)]
            TRANSFORMS.build(cfg)

        # Test with non-dict transforms item.
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['transforms'] = [[]]
            TRANSFORMS.build(cfg)

        # Test with dict transforms item with wrong type.
        with self.assertRaises(TypeError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['transforms'] = [dict(type=[])]
            TRANSFORMS.build(cfg)

        # Test with non-dict keymap.
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['keymap'] = []
            TRANSFORMS.build(cfg)
    @pytest.mark.skipif(
        albumentations is None, reason='No Albumentations module.')
    def test_transform(self):
        ori_img = np.random.randint(0, 256, (256, 256, 3), np.uint8)
        results = dict(img=copy.deepcopy(ori_img))

        # Test transform with non-img-keyword result
        with self.assertRaises(AssertionError):
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            albu_module = TRANSFORMS.build(cfg)
            empty_results = dict()
            albu_module(empty_results)

        # Test normal case
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        albu_module = TRANSFORMS.build(cfg)
        ablu_result = albu_module(results)

        # Test using 'Albu'
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['type'] = 'Albu'
        albu_module = TRANSFORMS.build(cfg)
        ablu_result = albu_module(results)

        # Test with keymap
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['keymap'] = dict(img='image')
        albu_module = TRANSFORMS.build(cfg)
        ablu_result = albu_module(results)

        # Test with nested transform
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        nested_transform_cfg = [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]
        cfg['transforms'] = nested_transform_cfg
        mmpretrain_module = TRANSFORMS.build(cfg)
        mmpretrain_module(results)

        # test that the results match the third-party albumentations package
        np.random.seed(0)
        random.seed(0)
        import albumentations as A
        ablu_transform_3rd = A.Compose([
            A.RandomCrop(width=256, height=256),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ])
        transformed_image_3rd = ablu_transform_3rd(
            image=copy.deepcopy(ori_img))['image']

        np.random.seed(0)
        random.seed(0)
        results = dict(img=copy.deepcopy(ori_img))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['transforms'] = [
            dict(type='RandomCrop', width=256, height=256),
            dict(type='HorizontalFlip', p=0.5),
            dict(type='RandomBrightnessContrast', p=0.2)
        ]
        mmpretrain_module = TRANSFORMS.build(cfg)
        transformed_image_mmpretrain = mmpretrain_module(results)['img']
        assert np.equal(transformed_image_3rd,
                        transformed_image_mmpretrain).all()
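        # with identical numpy/random seeds, the wrapper is expected to
        # reproduce the native albumentations pipeline bit-for-bit.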
        # Test class obj case
        results = dict(img=np.random.randint(0, 256, (200, 300, 3), np.uint8))
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        cfg['transforms'] = [
            dict(type=albumentations.SmallestMaxSize, max_size=400, p=1)
        ]
        albu_module = TRANSFORMS.build(cfg)
        ablu_result = albu_module(results)
        assert 'img' in ablu_result
        assert min(ablu_result['img'].shape[:2]) == 400
        assert ablu_result['img_shape'] == (400, 600)

    @pytest.mark.skipif(
        albumentations is None, reason='No Albumentations module.')
    def test_repr(self):
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), "Albumentations(transforms=[{'type': "
            "'ChannelShuffle', 'p': 1}])")
class TestSimMIMMaskGenerator(TestCase):
    DEFAULT_ARGS = dict(
        type='SimMIMMaskGenerator',
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6)

    def test_transform(self):
        img = np.random.randint(0, 256, (3, 192, 192), np.uint8)
        results = {'img': img}
        module = TRANSFORMS.build(self.DEFAULT_ARGS)
        results = module(results)

        self.assertTupleEqual(results['img'].shape, (3, 192, 192))
        self.assertTupleEqual(results['mask'].shape, (48, 48))
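        # the mask is defined on the model patch grid:
        # input_size / model_patch_size = 192 / 4 = 48 patches per side.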
    def test_repr(self):
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'SimMIMMaskGenerator(input_size=192, '
            'mask_patch_size=32, model_patch_size=4, mask_ratio=0.6)')
class TestBEiTMaskGenerator(TestCase):
    DEFAULT_ARGS = dict(
        type='BEiTMaskGenerator',
        input_size=(14, 14),
        num_masking_patches=75,
        max_num_patches=None,
        min_num_patches=16)

    def test_transform(self):
        module = TRANSFORMS.build(self.DEFAULT_ARGS)
        results = module({})
        self.assertTupleEqual(results['mask'].shape, (14, 14))
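        # the mask spans the full 14x14 patch grid (196 patches), of which
        # num_masking_patches=75 are masked.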
    def test_repr(self):
        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        transform = TRANSFORMS.build(cfg)
        log_aspect_ratio = (math.log(0.3), math.log(1 / 0.3))
        self.assertEqual(
            repr(transform), 'BEiTMaskGenerator(height=14, width=14, '
            'num_patches=196, num_masking_patches=75, min_num_patches=16, '
            f'max_num_patches=75, log_aspect_ratio={log_aspect_ratio})')
class TestVisionTransformWrapper(TestCase):

    def test_register(self):
        for t in VISION_TRANSFORMS:
            self.assertIn('torchvision/', t)
            self.assertIn(t, TRANSFORMS)

    def test_transform(self):
        img_path = osp.join(osp.dirname(__file__), '../../data/color.jpg')
        data = {'img': Image.open(img_path)}

        # test normal transform
        vision_trans = transforms.RandomResizedCrop(224)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/RandomResizedCrop', size=224))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test dtype conversion
        data = {'img': torch.randn(3, 224, 224)}
        vision_trans = transforms.ConvertImageDtype(torch.float)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/ConvertImageDtype', dtype='float'))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test transform with interpolation
        data = {'img': Image.open(img_path)}
        if digit_version(torchvision.__version__) > digit_version('0.8.0'):
            from torchvision.transforms import InterpolationMode
            interpolation_t = InterpolationMode.NEAREST
        else:
            interpolation_t = Image.NEAREST
        vision_trans = transforms.Resize(224, interpolation_t)
        vision_transformed_img = vision_trans(data['img'])
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/Resize', size=224, interpolation='nearest'))
        mmcls_transformed_img = mmcls_trans(data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

        # test compose transforms
        data = {'img': Image.open(img_path)}
        vision_trans = transforms.Compose([
            transforms.Resize(176),
            transforms.RandomHorizontalFlip(),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        vision_transformed_img = vision_trans(data['img'])

        pipeline_cfg = [
            dict(type='LoadImageFromFile'),
            dict(type='NumpyToPIL', to_rgb=True),
            dict(type='torchvision/Resize', size=176),
            dict(type='torchvision/RandomHorizontalFlip'),
            dict(type='torchvision/PILToTensor'),
            dict(type='torchvision/ConvertImageDtype', dtype='float'),
            dict(
                type='torchvision/Normalize',
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            )
        ]
        pipeline = [TRANSFORMS.build(t) for t in pipeline_cfg]
        mmcls_trans = Compose(transforms=pipeline)
        mmcls_data = {'img_path': img_path}
        mmcls_transformed_img = mmcls_trans(mmcls_data)['img']
        np.equal(
            np.array(vision_transformed_img), np.array(mmcls_transformed_img))

    def test_repr(self):
        vision_trans = transforms.RandomResizedCrop(224)
        mmcls_trans = TRANSFORMS.build(
            dict(type='torchvision/RandomResizedCrop', size=224))
        self.assertEqual(f'TorchVision {repr(vision_trans)}',
                         repr(mmcls_trans))
tests/test_datasets/test_transforms/test_wrappers.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.transforms import Resize

from mmpretrain.datasets import GaussianBlur, MultiView, Solarize


def test_multi_view():
    original_img = np.ones((4, 4, 3), dtype=np.uint8)

    # test 1 pipeline with 2 views
    pipeline1 = [
        Resize(2),
        GaussianBlur(magnitude_range=(0.1, 2), magnitude_std='inf')
    ]
    transform = MultiView([pipeline1], 2)
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 2
    assert results['img'][0].shape == (2, 2, 3)

    transform = MultiView([pipeline1], [2])
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 2
    assert results['img'][0].shape == (2, 2, 3)

    # test 2 pipelines with 3 views
    pipeline2 = [
        Solarize(thr=128),
        GaussianBlur(magnitude_range=(0.1, 2), magnitude_std='inf')
    ]
    transform = MultiView([pipeline1, pipeline2], [1, 2])
    results = dict(img=original_img)
    results = transform(results)
    assert len(results['img']) == 3
    assert results['img'][0].shape == (2, 2, 3)
    assert results['img'][1].shape == (4, 4, 3)

    # test repr
    assert isinstance(str(transform), str)
tests/test_engine/test_hooks/test_arcface_hooks.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase

import numpy as np
import torch
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.metainfo = None

    def __getitem__(self, idx):
        results = dict(imgs=torch.rand((224, 224, 3)).float())
        return results

    def get_gt_labels(self):
        gt_labels = np.array([0, 1, 2, 4, 0, 4, 1, 2, 2, 1])
        return gt_labels

    def __len__(self):
        return 10


class TestSetAdaptiveMarginsHook(TestCase):
    DEFAULT_HOOK_CFG = dict(type='SetAdaptiveMarginsHook')
    DEFAULT_MODEL = dict(
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=34,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(type='ArcFaceClsHead', in_channels=512, num_classes=5))

    def test_before_train(self):
        default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        tmpdir = tempfile.TemporaryDirectory()
        loader = DataLoader(ExampleDataset(), batch_size=2)
        self.runner = Runner(
            model=self.DEFAULT_MODEL,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmpretrain',
            default_hooks=default_hooks,
            experiment_name='test_construct_with_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        default_margins = torch.tensor([0.5] * 5)
        torch.allclose(self.runner.model.head.margins.cpu(), default_margins)
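        # before the hook runs, every class margin holds the default of 0.5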
        self.runner.call_hook('before_train')
        # counts = [2, 3, 3, 0, 2] -> [2, 3, 3, 1, 2] (every class is counted
        # at least once)
        # frequency**-0.25 = [0.84089642, 0.75983569, 0.75983569, 1., 0.84089642]
        # normalized = [0.33752196, 0., 0., 1., 0.33752196]
        # margins = [0.20188488, 0.05, 0.05, 0.5, 0.20188488]
        expert_margins = torch.tensor(
            [0.20188488, 0.05, 0.05, 0.5, 0.20188488])
        torch.allclose(self.runner.model.head.margins.cpu(), expert_margins)

        model_cfg = {**self.DEFAULT_MODEL}
        model_cfg['head'] = dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        )
        self.runner = Runner(
            model=model_cfg,
            work_dir=tmpdir.name,
            train_dataloader=loader,
            train_cfg=dict(by_epoch=True, max_epochs=1),
            log_level='WARNING',
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.1, momentum=0.9)),
            param_scheduler=dict(
                type='MultiStepLR', milestones=[1, 2], gamma=0.1),
            default_scope='mmpretrain',
            default_hooks=default_hooks,
            experiment_name='test_construct_wo_arcface',
            custom_hooks=[self.DEFAULT_HOOK_CFG])

        with self.assertRaises(ValueError):
            self.runner.call_hook('before_train')
tests/test_engine/test_hooks/test_class_num_check_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock, patch

from mmpretrain.engine import ClassNumCheckHook


class TestClassNumCheckHook(TestCase):

    def setUp(self):
        self.runner = MagicMock()
        self.dataset = MagicMock()
        self.hook = ClassNumCheckHook()

    def test_check_head(self):
        # check that CLASSES must be a sequence of strings
        with self.assertRaises(AssertionError):
            self.hook._check_head(self.runner, self.dataset)

        # check no CLASSES
        with patch.object(self.runner.logger, 'warning') as mock:
            self.dataset.CLASSES = None
            self.hook._check_head(self.runner, self.dataset)
            mock.assert_called_once()

        # check no modules
        self.dataset.CLASSES = ['str'] * 10
        self.hook._check_head(self.runner, self.dataset)

        # check that the number of classes does not match
        self.dataset.CLASSES = ['str'] * 10
        module1 = MagicMock(spec_set=True)
        module2 = MagicMock(num_classes=5)
        self.runner.model.named_modules.return_value = iter([(None, module1),
                                                             (None, module2)])
        with self.assertRaises(AssertionError):
            self.hook._check_head(self.runner, self.dataset)

    def test_before_train(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_train(self.runner)
            mock.assert_called_once()

    def test_before_val(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_val(self.runner)
            mock.assert_called_once()

    def test_before_test(self):
        with patch.object(self.hook, '_check_head') as mock:
            self.hook.before_test(self.runner)
            mock.assert_called_once()
tests/test_engine/test_hooks/test_densecl_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmpretrain.engine import DenseCLHook
from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import get_ori_model


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = DataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)


@MODELS.register_module()
class DenseCLDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


class ToyModel(BaseSelfSupervisor):

    def __init__(self):
        super().__init__(backbone=dict(type='DenseCLDummyLayer'))
        self.loss_lambda = 0.5

    def loss(self, inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestDenseCLHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_densecl_hook(self):
        device = get_device()
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        densecl_hook = DenseCLHook(start_iters=1)

        # test DenseCLHook with model wrapper
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[densecl_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_densecl_hook',
            default_scope='mmpretrain')
        runner.train()
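        # DenseCLHook zeroes loss_lambda before `start_iters` and restores
        # the configured value afterwards, which the check below relies on.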
        if runner.iter >= 1:
            assert get_ori_model(runner.model).loss_lambda == 0.5
        else:
            assert get_ori_model(runner.model).loss_lambda == 0.
tests/test_engine/test_hooks/test_ema_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
from collections import OrderedDict
from unittest import TestCase
from unittest.mock import ANY, MagicMock, call

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.evaluator import Evaluator
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.testing import assert_allclose
from torch.utils.data import Dataset

from mmpretrain.engine import EMAHook


class SimpleModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.para = nn.Parameter(torch.zeros(1))

    def forward(self, *args, mode='tensor', **kwargs):
        if mode == 'predict':
            return self.para.clone()
        elif mode == 'loss':
            return {'loss': self.para.mean()}


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(6, 2)
    label = torch.ones(6)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        return dict(inputs=self.data[index], data_sample=self.label[index])


class TestEMAHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()
        state_dict = OrderedDict(
            meta=dict(epoch=1, iter=2),
            # The actual ema para
            state_dict={'para': torch.tensor([1.])},
            # The actual original para
            ema_state_dict={'module.para': torch.tensor([2.])},
        )
        self.ckpt = osp.join(self.temp_dir.name, 'ema.pth')
        torch.save(state_dict, self.ckpt)
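        # The two entries hold different values (1. vs 2.) on purpose, so the
        # tests below can tell EMA parameters from the original ones by value.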
    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_load_state_dict(self):
        device = get_device()
        model = SimpleModel().to(device)
        ema_hook = EMAHook()
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            work_dir=self.temp_dir.name,
            resume=False,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[ema_hook],
            default_scope='mmpretrain',
            experiment_name='load_state_dict')
        runner.train()
        assert_allclose(runner.model.para, torch.tensor([1.], device=device))

    def test_evaluate_on_ema(self):
        device = get_device()
        model = SimpleModel().to(device)

        # Test validate on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmpretrain',
            experiment_name='validate_on_ema')
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test test on ema model
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=False),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook')],
            default_scope='mmpretrain',
            experiment_name='test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([2.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)

        # Test validate on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            val_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            val_evaluator=evaluator,
            val_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmpretrain',
            experiment_name='validate_on_ema_false',
        )
        runner.val()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test test on both models
        evaluator = Evaluator([MagicMock()])
        runner = Runner(
            model=model,
            test_dataloader=dict(
                dataset=DummyDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=3,
                num_workers=0),
            test_evaluator=evaluator,
            test_cfg=dict(),
            work_dir=self.temp_dir.name,
            load_from=self.ckpt,
            default_hooks=dict(logger=None),
            custom_hooks=[dict(type='EMAHook', evaluate_on_origin=True)],
            default_scope='mmpretrain',
            experiment_name='test_on_ema_false',
        )
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([1.]).to(device)]),
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])

        # Test evaluate_on_ema=False
        evaluator = Evaluator([MagicMock()])
        with self.assertWarnsRegex(UserWarning, 'evaluate_on_origin'):
            runner = Runner(
                model=model,
                test_dataloader=dict(
                    dataset=DummyDataset(),
                    sampler=dict(type='DefaultSampler', shuffle=False),
                    batch_size=3,
                    num_workers=0),
                test_evaluator=evaluator,
                test_cfg=dict(),
                work_dir=self.temp_dir.name,
                load_from=self.ckpt,
                default_hooks=dict(logger=None),
                custom_hooks=[dict(type='EMAHook', evaluate_on_ema=False)],
                default_scope='mmpretrain',
                experiment_name='not_test_on_ema')
        runner.test()
        evaluator.metrics[0].process.assert_has_calls([
            call(ANY, [torch.tensor([2.]).to(device)]),
        ])
        self.assertNotIn(
            call(ANY, [torch.tensor([1.]).to(device)]),
            evaluator.metrics[0].process.mock_calls)
tests/test_engine/test_hooks/test_precise_bn_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import tempfile
from unittest import TestCase
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.runner import Runner
from torch.utils.data import DataLoader, Dataset

from mmpretrain.models.utils import ClsDataPreprocessor
from mmpretrain.registry import HOOKS
from mmpretrain.structures import DataSample


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0
        self.metainfo = None

    def __getitem__(self, idx):
        results = dict(imgs=torch.tensor([1.0], dtype=torch.float32))
        return results

    def __len__(self):
        return 10


class MockDataPreprocessor(ClsDataPreprocessor):
    """Mock preprocessor that does nothing."""

    def forward(self, data, training=False):
        return dict(inputs=data['imgs'], data_samples=DataSample())


class ExampleModel(BaseModel):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.data_preprocessor = MockDataPreprocessor()
        self.conv = nn.Linear(1, 1)
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, inputs, data_samples, mode='tensor'):
        inputs = inputs.to(next(self.parameters()).device)
        return self.bn(self.conv(inputs))

    def train_step(self, data, optim_wrapper):
        outputs = {'loss': 0.5, 'num_samples': 1}
        return outputs


class SingleBNModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, inputs, data_samples, mode='tensor'):
        return self.bn(inputs)


class GNExampleModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(1, 1)
        self.test_cfg = None


class NoBNExampleModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        delattr(self, 'bn')
        self.test_cfg = None

    def forward(self, inputs, data_samples, mode='tensor'):
        return self.conv(inputs)


class TestPreciseBNHookHook(TestCase):
    DEFAULT_ARGS = dict(type='PreciseBNHook', num_samples=4, interval=1)
    count = 0

    def setUp(self) -> None:
        # optimizer
        self.optim_wrapper = dict(
            optimizer=dict(
                type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
        # learning policy
        self.epoch_param_scheduler = dict(
            type='MultiStepLR', by_epoch=True, milestones=[1, 2], gamma=0.1)
        self.iter_param_scheduler = dict(
            type='MultiStepLR', by_epoch=False, milestones=[1, 2], gamma=0.1)
        self.default_hooks = dict(
            timer=dict(type='IterTimerHook'),
            logger=None,
            param_scheduler=dict(type='ParamSchedulerHook'),
            checkpoint=dict(type='CheckpointHook', interval=1),
            sampler_seed=dict(type='DistSamplerSeedHook'),
            visualization=dict(type='VisualizationHook', enable=False),
        )
        self.epoch_train_cfg = dict(by_epoch=True, max_epochs=1)
        self.iter_train_cfg = dict(by_epoch=False, max_iters=5)
        self.tmpdir = tempfile.TemporaryDirectory()
        self.preciseBN_cfg = copy.deepcopy(self.DEFAULT_ARGS)
        test_dataset = ExampleDataset()
        self.loader = DataLoader(test_dataset, batch_size=2)
        self.model = ExampleModel()

    def test_construct(self):
        self.runner = Runner(
            model=self.model,
            work_dir=self.tmpdir.name,
            train_dataloader=self.loader,
            train_cfg=self.epoch_train_cfg,
            log_level='WARNING',
            optim_wrapper=self.optim_wrapper,
            param_scheduler=self.epoch_param_scheduler,
            default_scope='mmpretrain',
            default_hooks=self.default_hooks,
            experiment_name='test_construct',
            custom_hooks=None)

        cfg = copy.deepcopy(self.DEFAULT_ARGS)
        precise_bn = HOOKS.build(cfg)
        self.assertEqual(precise_bn.num_samples, 4)
        self.assertEqual(precise_bn.interval, 1)

        with pytest.raises(AssertionError):
            # num_samples must be larger than 0
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['num_samples'] = -1
            HOOKS.build(cfg)

        with pytest.raises(AssertionError):
            # interval must be larger than 0
            cfg = copy.deepcopy(self.DEFAULT_ARGS)
            cfg['interval'] = 0
            HOOKS.build(cfg)

    @patch('mmengine.dist.get_dist_info', MagicMock(return_value=(1, 2)))
    @patch('torch.distributed.all_reduce', MagicMock())
    def test_after_train_epoch_multi_machines(self):
        # Test with normal conv model on multiple (mocked) machines
        self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
        self.runner = Runner(
            model=self.model,
            work_dir=self.tmpdir.name,
            train_dataloader=self.loader,
            train_cfg=self.epoch_train_cfg,
            log_level='WARNING',
            optim_wrapper=self.optim_wrapper,
            param_scheduler=self.epoch_param_scheduler,
            default_scope='mmpretrain',
            default_hooks=self.default_hooks,
            experiment_name='test_after_train_epoch_multi_machines',
            custom_hooks=[self.preciseBN_cfg])
        self.runner.train()

    def test_after_train_epoch(self):
        self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
        self.runner = Runner(
            model=self.model,
            work_dir=self.tmpdir.name,
            train_dataloader=self.loader,
            train_cfg=self.epoch_train_cfg,
            log_level='WARNING',
            optim_wrapper=self.optim_wrapper,
            param_scheduler=self.epoch_param_scheduler,
            default_scope='mmpretrain',
            default_hooks=self.default_hooks,
            experiment_name='test_after_train_epoch',
            custom_hooks=[self.preciseBN_cfg])

        # Test with normal conv model in single machine
        self.runner._train_loop = self.epoch_train_cfg
        self.runner.train()

        # Test with only BN model
        self.runner.model = SingleBNModel()
        self.runner._train_loop = self.epoch_train_cfg
        self.runner.train()

        # Test with GN model
        self.runner.model = GNExampleModel()
        self.runner._train_loop = self.epoch_train_cfg
        self.runner.train()

        # Test with no BN model
        self.runner.model = NoBNExampleModel()
        self.runner._train_loop = self.epoch_train_cfg
        self.runner.train()

    def test_after_train_iter(self):
        # test precise bn hook in an iter-based loop
        self.preciseBN_cfg['priority'] = 'ABOVE_NORMAL'
        test_dataset = ExampleDataset()
        self.loader = DataLoader(test_dataset, batch_size=2)
        self.runner = Runner(
            model=self.model,
            work_dir=self.tmpdir.name,
            train_dataloader=self.loader,
            train_cfg=self.iter_train_cfg,
            log_level='WARNING',
            optim_wrapper=self.optim_wrapper,
            param_scheduler=self.iter_param_scheduler,
            default_scope='mmpretrain',
            default_hooks=self.default_hooks,
            experiment_name='test_after_train_iter',
            custom_hooks=[self.preciseBN_cfg])
        self.runner.train()

    def tearDown(self) -> None:
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory.
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.tmpdir.cleanup()
tests/test_engine/test_hooks/test_retrievers_hooks.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import MagicMock

import torch

from mmpretrain.engine import PrepareProtoBeforeValLoopHook
from mmpretrain.models.retrievers import BaseRetriever


class ToyRetriever(BaseRetriever):

    def forward(self, inputs, data_samples=None, mode: str = 'loss'):
        self.prototype_inited = False

    def prepare_prototype(self):
        """Prepare the prototype before prediction."""
        self.prototype_vecs = torch.tensor([0])
        self.prototype_inited = True


class TestPrepareProtBeforeValLoopHook(TestCase):

    def setUp(self):
        self.hook = PrepareProtoBeforeValLoopHook
        self.runner = MagicMock()
        self.runner.model = ToyRetriever()

    def test_before_val(self):
        self.runner.model.prepare_prototype()
        self.assertTrue(self.runner.model.prototype_inited)

        self.hook.before_val(self, self.runner)
        self.assertIsNotNone(self.runner.model.prototype_vecs)
        self.assertTrue(self.runner.model.prototype_inited)
tests/test_engine/test_hooks/test_simsiam_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmpretrain.engine import SimSiamHook
from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = DataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)


@MODELS.register_module()
class SimSiamDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.predictor = nn.Linear(2, 1)

    def forward(self, x):
        return self.predictor(x)


class ToyModel(BaseSelfSupervisor):

    def __init__(self):
        super().__init__(backbone=dict(type='SimSiamDummyLayer'))

    def extract_feat(self):
        pass

    def loss(self, inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestSimSiamHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_simsiam_hook(self):
        device = get_device()
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        simsiam_hook = SimSiamHook(
            fix_pred_lr=True, lr=0.05, adjust_by_epoch=False)

        # test SimSiamHook
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=dict(
                optimizer=dict(type='SGD', lr=0.05),
                paramwise_cfg=dict(
                    custom_keys={'predictor': dict(fix_lr=True)})),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[simsiam_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_simsiam_hook',
            default_scope='mmpretrain')
        runner.train()
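        # SimSiamHook keeps param groups tagged with fix_lr at the fixed
        # predictor lr (0.05), while the scheduler decays all other groups.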
        for param_group in runner.optim_wrapper.optimizer.param_groups:
            if 'fix_lr' in param_group and param_group['fix_lr']:
                assert param_group['lr'] == 0.05
            else:
                assert param_group['lr'] != 0.05
tests/test_engine/test_hooks/test_swav_hook.py  0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest import TestCase

import torch
import torch.nn as nn
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModule
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
from mmengine.structures import LabelData
from torch.utils.data import Dataset

from mmpretrain.engine import SwAVHook
from mmpretrain.models.heads import SwAVHead
from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import get_ori_model


class DummyDataset(Dataset):
    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        return self.METAINFO

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, index):
        data_sample = DataSample()
        gt_label = LabelData(value=self.label[index])
        setattr(data_sample, 'gt_label', gt_label)
        return dict(inputs=[self.data[index]], data_samples=data_sample)


@MODELS.register_module()
class SwAVDummyLayer(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


class ToyModel(BaseSelfSupervisor):

    def __init__(self):
        super().__init__(backbone=dict(type='SwAVDummyLayer'))
        self.prototypes_test = nn.Linear(1, 1)
        self.head = SwAVHead(
            loss=dict(
                type='SwAVLoss',
                feat_dim=2,
                num_crops=[2, 6],
                num_prototypes=3))

    def loss(self, inputs, data_samples):
        labels = []
        for x in data_samples:
            labels.append(x.gt_label.value)
        labels = torch.stack(labels)
        outputs = self.backbone(inputs[0])
        loss = (labels - outputs).sum()
        outputs = dict(loss=loss)
        return outputs


class TestSwAVHook(TestCase):

    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        # `FileHandler` should be closed in Windows, otherwise we cannot
        # delete the temporary directory
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.temp_dir.cleanup()

    def test_swav_hook(self):
        device = get_device()
        dummy_dataset = DummyDataset()
        toy_model = ToyModel().to(device)
        swav_hook = SwAVHook(
            batch_size=1,
            epoch_queue_starts=15,
            crops_for_assign=[0, 1],
            feat_dim=128,
            queue_length=300,
            frozen_layers_cfg=dict(prototypes=2))

        # test SwAVHook
        runner = Runner(
            model=toy_model,
            work_dir=self.temp_dir.name,
            train_dataloader=dict(
                dataset=dummy_dataset,
                sampler=dict(type='DefaultSampler', shuffle=True),
                collate_fn=dict(type='default_collate'),
                batch_size=1,
                num_workers=0),
            optim_wrapper=OptimWrapper(
                torch.optim.Adam(toy_model.parameters())),
            param_scheduler=dict(type='MultiStepLR', milestones=[1]),
            train_cfg=dict(by_epoch=True, max_epochs=2),
            custom_hooks=[swav_hook],
            default_hooks=dict(logger=None),
            log_processor=dict(window_size=1),
            experiment_name='test_swav_hook',
            default_scope='mmpretrain')
        runner.train()
for
hook
in
runner
.
hooks
:
if
isinstance
(
hook
,
SwAVHook
):
assert
hook
.
queue_length
==
300
assert
get_ori_model
(
runner
.
model
).
head
.
loss_module
.
use_queue
is
False
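The final assertion holds because the SwAV queue only becomes active once the current epoch reaches `epoch_queue_starts`; with `max_epochs=2` and `epoch_queue_starts=15` it never does. A rough sketch of that gating, stated as an assumption inferred from the test configuration:

# Sketch of the queue gating the test configuration implies; argument
# names mirror the SwAVHook arguments above, not its verified internals.
def should_use_queue(epoch: int, epoch_queue_starts: int,
                     queue_length: int) -> bool:
    return queue_length > 0 and epoch >= epoch_queue_starts

assert should_use_queue(epoch=1, epoch_queue_starts=15,
                        queue_length=300) is False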
tests/test_engine/test_hooks/test_switch_recipe_hook.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
from typing import List
from unittest import TestCase
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.dataset import BaseDataset, ConcatDataset, RepeatDataset
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner

from mmpretrain.engine import SwitchRecipeHook
from mmpretrain.models import CrossEntropyLoss
from mmpretrain.models.heads.cls_head import ClsHead
from mmpretrain.models.losses import LabelSmoothLoss
from mmpretrain.models.utils.batch_augments import RandomBatchAugment


class SimpleDataPreprocessor(BaseDataPreprocessor):

    def __init__(self):
        super().__init__()
        self.batch_augments = None

    def forward(self, data, training):
        data = self.cast_data(data)
        assert 'inputs' in data, 'No `input` in data.'
        inputs = data['inputs']
        labels = data['labels']
        if self.batch_augments is not None and training:
            inputs, labels = self.batch_augments(inputs, labels)
        return {'inputs': inputs, 'labels': labels}


class SimpleModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.data_preprocessor = SimpleDataPreprocessor()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(3, 10)
        self.loss_module = CrossEntropyLoss(use_soft=True)

    def forward(self, inputs, labels, mode='tensor'):
        if mode == 'loss':
            score = self.fc(self.gap(inputs).view(inputs.size(0), -1))
            loss = self.loss_module(score, labels)
            return {'loss': loss}
        else:
            return None


class ExampleDataset(BaseDataset):

    def load_data_list(self) -> List[dict]:
        return [{
            'inputs': torch.rand(3, 12, 12),
            'labels': torch.rand(10),
        } for _ in range(10)]


class EmptyTransform:

    def __call__(self, results):
        return {}


class TestSwitchRecipeHook(TestCase):

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()

    def tearDown(self) -> None:
        logging.shutdown()
        MMLogger._instance_dict.clear()
        self.tmpdir.cleanup()

    def test_init(self):
        # test `action_epoch` is set
        with self.assertRaisesRegex(AssertionError, 'Please set'):
            SwitchRecipeHook([dict(batch_augments=None)])

        # test `action_epoch` is not repeated
        with self.assertRaisesRegex(AssertionError, 'is repeated'):
            SwitchRecipeHook([dict(action_epoch=1), dict(action_epoch=1)])

        # test recipe build
        hook = SwitchRecipeHook([
            dict(
                action_epoch=1,
                train_pipeline=[dict(type='LoadImageFromFile')],
                loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1),
                batch_augments=dict(augments=dict(type='Mixup', alpha=0.8)),
            )
        ])
        self.assertIn(1, hook.schedule)
        self.assertIsInstance(hook.schedule[1]['train_pipeline'], Compose)
        self.assertIsInstance(hook.schedule[1]['loss'], LabelSmoothLoss)
        self.assertIsInstance(hook.schedule[1]['batch_augments'],
                              RandomBatchAugment)

        # test recipe build with instance
        hook = SwitchRecipeHook([
            dict(
                action_epoch=1,
                train_pipeline=[MagicMock()],
                loss=MagicMock(),
                batch_augments=MagicMock(),
            )
        ])
        self.assertIn(1, hook.schedule)
        self.assertIsInstance(hook.schedule[1]['train_pipeline'], Compose)
        self.assertIsInstance(hook.schedule[1]['loss'], MagicMock)
        self.assertIsInstance(hook.schedule[1]['batch_augments'], MagicMock)

        # test empty pipeline and train_augments
        hook = SwitchRecipeHook(
            [dict(action_epoch=1, train_pipeline=[], batch_augments=None)])
        self.assertIn(1, hook.schedule)
        self.assertIsInstance(hook.schedule[1]['train_pipeline'], Compose)
        self.assertIsNone(hook.schedule[1]['batch_augments'])

    def test_do_switch(self):
        device = get_device()
        model = SimpleModel().to(device)
        loss = CrossEntropyLoss(use_soft=True)
        loss.forward = MagicMock(
            side_effect=lambda x, y: CrossEntropyLoss.forward(loss, x, y))
        batch_augments = RandomBatchAugment(dict(type='Mixup', alpha=0.5))
        switch_hook = SwitchRecipeHook([
            dict(
                action_epoch=2,
                train_pipeline=[MagicMock(side_effect=lambda x: x)],
                loss=loss,
                batch_augments=MagicMock(
                    side_effect=lambda x, y: batch_augments(x, y)),
            )
        ])

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ExampleDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            custom_hooks=[switch_hook],
            default_scope='mmpretrain',
            experiment_name='test_switch')
        runner.train()

        self.assertEqual(
            switch_hook.schedule[2]['batch_augments'].call_count, 2)
        self.assertEqual(
            switch_hook.schedule[2]['loss'].forward.call_count, 2)
        self.assertEqual(
            switch_hook.schedule[2]['train_pipeline'].transforms[0].call_count,
            10)

        # Due to the unknown error in Windows environment, the unit test for
        # `num_workers>0` is disabled temporarily
        # switch_hook = SwitchRecipeHook(
        #     [dict(
        #         action_epoch=2,
        #         train_pipeline=[EmptyTransform()],
        #     )])
        # runner = Runner(
        #     model=model,
        #     train_dataloader=dict(
        #         dataset=ExampleDataset(),
        #         sampler=dict(type='DefaultSampler', shuffle=True),
        #         batch_size=5,
        #         num_workers=1,
        #         persistent_workers=True,
        #         collate_fn=dict(type='default_collate'),
        #     ),
        #     optim_wrapper=OptimWrapper(
        #         optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
        #     train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
        #     work_dir=self.tmpdir.name,
        #     default_hooks=dict(logger=None),
        #     custom_hooks=[switch_hook],
        #     default_scope='mmpretrain',
        #     experiment_name='test_switch_multi_workers')
        # with self.assertRaisesRegex(AssertionError, 'No `input` in data.'):
        #     # If the pipeline switch works, the data_preprocessor cannot
        #     # receive `inputs`.
        #     runner.train()

    def test_resume(self):
        device = get_device()
        model = SimpleModel().to(device)
        loss = CrossEntropyLoss(use_soft=True)
        loss.forward = MagicMock(
            side_effect=lambda x, y: CrossEntropyLoss.forward(loss, x, y))
        batch_augments = RandomBatchAugment(dict(type='Mixup', alpha=0.5))
        switch_hook = SwitchRecipeHook([
            dict(
                action_epoch=1,
                train_pipeline=[MagicMock(side_effect=lambda x: x)]),
            dict(action_epoch=2, loss=loss),
            dict(
                action_epoch=4,
                batch_augments=MagicMock(
                    side_effect=lambda x, y: batch_augments(x, y)),
            ),
        ])

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ExampleDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            custom_hooks=[switch_hook],
            default_scope='mmpretrain',
            experiment_name='test_resume1')
        runner.train()

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ExampleDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=4, val_interval=10),
            resume=True,
            load_from=osp.join(self.tmpdir.name, 'epoch_2.pth'),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            custom_hooks=[switch_hook],
            default_scope='mmpretrain',
            experiment_name='test_resume2')
        with self.assertLogs(runner.logger, 'INFO') as logs:
            runner.train()

        prefix = 'INFO:mmengine:'
        self.assertIn(
            prefix + 'Switch train pipeline (resume recipe of epoch 1).',
            logs.output)
        self.assertIn(prefix + 'Switch loss (resume recipe of epoch 2).',
                      logs.output)
        self.assertIn(prefix + 'Switch batch augments at epoch 4.',
                      logs.output)

    def test_switch_train_pipeline(self):
        device = get_device()
        model = SimpleModel().to(device)
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ConcatDataset([ExampleDataset(), ExampleDataset()]),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            default_scope='mmpretrain',
            experiment_name='test_concat_dataset')

        pipeline = MagicMock()
        SwitchRecipeHook._switch_train_pipeline(runner, pipeline)
        self.assertIs(runner.train_dataloader.dataset.datasets[0].pipeline,
                      pipeline)
        self.assertIs(runner.train_dataloader.dataset.datasets[1].pipeline,
                      pipeline)

        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=RepeatDataset(ExampleDataset(), 3),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            default_scope='mmpretrain',
            experiment_name='test_repeat_dataset')

        pipeline = MagicMock()
        SwitchRecipeHook._switch_train_pipeline(runner, pipeline)
        self.assertIs(runner.train_dataloader.dataset.dataset.pipeline,
                      pipeline)

    def test_switch_loss(self):
        device = get_device()
        model = SimpleModel().to(device)
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ExampleDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            default_scope='mmpretrain',
            experiment_name='test_model_loss')

        loss = CrossEntropyLoss(use_soft=True)
        SwitchRecipeHook._switch_loss(runner, loss)
        self.assertIs(runner.model.loss_module, loss)

        model.add_module('head', ClsHead())
        del model.loss_module
        runner = Runner(
            model=model,
            train_dataloader=dict(
                dataset=ExampleDataset(),
                sampler=dict(type='DefaultSampler', shuffle=True),
                batch_size=5,
                num_workers=0,
                collate_fn=dict(type='default_collate'),
            ),
            optim_wrapper=OptimWrapper(
                optimizer=torch.optim.Adam(model.parameters(), lr=0.)),
            train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=10),
            work_dir=self.tmpdir.name,
            default_hooks=dict(logger=None),
            default_scope='mmpretrain',
            experiment_name='test_head_loss')

        loss = CrossEntropyLoss(use_soft=True)
        SwitchRecipeHook._switch_loss(runner, loss)
        self.assertIs(runner.model.head.loss_module, loss)
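The resume assertions above imply that the hook keys its schedule by epoch and, on resume, replays every recipe whose `action_epoch` has already passed. A hypothetical sketch of that dispatch; the names `apply_recipe`, `before_train_epoch_sketch`, and `after_resume_sketch` are illustrative assumptions, not the verified SwitchRecipeHook source:

# Hypothetical sketch of the epoch-keyed dispatch consistent with the
# '(resume recipe of epoch N)' log lines asserted in test_resume.
def apply_recipe(runner, recipe):
    """Illustrative stub: would swap pipeline/loss/batch_augments in place."""
    ...

def before_train_epoch_sketch(schedule, runner):
    # fire a recipe exactly when its action_epoch begins
    recipe = schedule.get(runner.epoch + 1)
    if recipe is not None:
        apply_recipe(runner, recipe)

def after_resume_sketch(schedule, runner, resumed_epoch):
    # replay recipes already due, so a resumed run matches one trained
    # from scratch
    for epoch in sorted(e for e in schedule if e <= resumed_epoch):
        apply_recipe(runner, schedule[epoch])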
tests/test_engine/test_hooks/test_visualization_hook.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import ANY, MagicMock, patch

import torch
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop

from mmpretrain.engine import VisualizationHook
from mmpretrain.registry import HOOKS
from mmpretrain.structures import DataSample
from mmpretrain.visualization import UniversalVisualizer


class TestVisualizationHook(TestCase):

    def setUp(self) -> None:
        UniversalVisualizer.get_instance('visualizer')

        data_sample = DataSample().set_gt_label(1).set_pred_label(2)
        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
        self.data_batch = {
            'inputs': torch.randint(0, 256, (10, 3, 224, 224)),
            'data_sample': [data_sample] * 10
        }

        self.outputs = [data_sample] * 10

        self.tmpdir = tempfile.TemporaryDirectory()

    def test_draw_samples(self):
        # test enable=False
        cfg = dict(type='VisualizationHook', enable=False)
        hook: VisualizationHook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook._draw_samples(1, self.data_batch, self.outputs, step=1)
            mock.assert_not_called()

        # test enable=True
        cfg = dict(type='VisualizationHook', enable=True, show=True)
        hook: VisualizationHook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook._draw_samples(0, self.data_batch, self.outputs, step=0)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=self.outputs[0],
                step=0,
                show=True,
                name='color.jpg')

        # test samples without path
        cfg = dict(type='VisualizationHook', enable=True)
        hook: VisualizationHook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            outputs = [DataSample()] * 10
            hook._draw_samples(0, self.data_batch, outputs, step=0)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=outputs[0],
                step=0,
                show=False,
                name='0')

        # test out_dir
        cfg = dict(
            type='VisualizationHook', enable=True, out_dir=self.tmpdir.name)
        hook: VisualizationHook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook._draw_samples(0, self.data_batch, self.outputs, step=0)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=self.outputs[0],
                step=0,
                show=False,
                name='color.jpg',
                out_file=osp.join(self.tmpdir.name, 'color.jpg_0.png'))

        # test sample idx
        cfg = dict(type='VisualizationHook', enable=True, interval=4)
        hook: VisualizationHook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook._draw_samples(1, self.data_batch, self.outputs, step=0)
            mock.assert_called_with(
                image=ANY,
                data_sample=self.outputs[2],
                step=0,
                show=False,
                name='color.jpg',
            )
            mock.assert_called_with(
                image=ANY,
                data_sample=self.outputs[6],
                step=0,
                show=False,
                name='color.jpg',
            )

    def test_after_val_iter(self):
        runner = MagicMock()

        # test epoch-based
        runner.train_loop = MagicMock(spec=EpochBasedTrainLoop)
        runner.epoch = 5
        cfg = dict(type='VisualizationHook', enable=True)
        hook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=self.outputs[0],
                step=5,
                show=False,
                name='color.jpg',
            )

        # test iter-based
        runner.train_loop = MagicMock(spec=IterBasedTrainLoop)
        runner.iter = 300
        cfg = dict(type='VisualizationHook', enable=True)
        hook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook.after_val_iter(runner, 0, self.data_batch, self.outputs)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=self.outputs[0],
                step=300,
                show=False,
                name='color.jpg',
            )

    def test_after_test_iter(self):
        runner = MagicMock()
        cfg = dict(type='VisualizationHook', enable=True)
        hook = HOOKS.build(cfg)
        with patch.object(hook._visualizer, 'visualize_cls') as mock:
            hook.after_test_iter(runner, 0, self.data_batch, self.outputs)
            mock.assert_called_once_with(
                image=ANY,
                data_sample=self.outputs[0],
                step=0,
                show=False,
                name='color.jpg',
            )

    def tearDown(self) -> None:
        self.tmpdir.cleanup()
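A note on the `interval=4` expectations: the drawn samples `outputs[2]` and `outputs[6]` are consistent with sampling by global sample index, i.e. `batch_idx * batch_size + local_idx`, drawing whenever that index is a multiple of the interval. A small check of that arithmetic; the indexing rule is an inference from the assertions, not read from the hook's source:

# Sketch of the sampling rule the interval=4 assertions imply.
def sampled_local_indices(batch_idx, batch_size, interval):
    return [i for i in range(batch_size)
            if (batch_idx * batch_size + i) % interval == 0]

# batch 1 of size 10 covers global indices 10..19; 12 and 16 are
# divisible by 4, mapping back to local indices 2 and 6.
assert sampled_local_indices(batch_idx=1, batch_size=10, interval=4) == [2, 6]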
tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmpretrain.engine import LearningRateDecayOptimWrapperConstructor
from mmpretrain.models import ImageClassifier, VisionTransformer


class ToyViTBackbone(nn.Module):
    get_layer_depth = VisionTransformer.get_layer_depth

    def __init__(self, num_layers=2):
        super().__init__()
        self.cls_token = nn.Parameter(torch.ones(1))
        self.pos_embed = nn.Parameter(torch.ones(1))
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = nn.Conv2d(3, 3, 1)
            self.layers.append(layer)


class ToyViT(nn.Module):
    get_layer_depth = ImageClassifier.get_layer_depth

    def __init__(self):
        super().__init__()
        # add some variables to meet unit test coverage rate
        self.backbone = ToyViTBackbone()
        self.head = nn.Linear(1, 1)


class TestLearningRateDecayOptimWrapperConstructor(TestCase):
    base_lr = 1.0
    base_wd = 0.05

    def test_add_params(self):
        model = ToyViT()
        optim_wrapper_cfg = dict(
            type='OptimWrapper',
            optimizer=dict(
                type='AdamW',
                lr=self.base_lr,
                betas=(0.9, 0.999),
                weight_decay=self.base_wd))

        paramwise_cfg = dict(
            layer_decay_rate=2.0,
            bias_decay_mult=0.,
            custom_keys={
                '.cls_token': dict(decay_mult=0.0),
                '.pos_embed': dict(decay_mult=0.0),
            })

        constructor = LearningRateDecayOptimWrapperConstructor(
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg,
        )
        optimizer_wrapper = constructor(model)

        expected_groups = [{
            'weight_decay': 0.0,
            'lr': 8 * self.base_lr,
            'param_name': 'backbone.cls_token',
        }, {
            'weight_decay': 0.0,
            'lr': 8 * self.base_lr,
            'param_name': 'backbone.pos_embed',
        }, {
            'weight_decay': self.base_wd,
            'lr': 4 * self.base_lr,
            'param_name': 'backbone.layers.0.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 4 * self.base_lr,
            'param_name': 'backbone.layers.0.bias',
        }, {
            'weight_decay': self.base_wd,
            'lr': 2 * self.base_lr,
            'param_name': 'backbone.layers.1.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 2 * self.base_lr,
            'param_name': 'backbone.layers.1.bias',
        }, {
            'weight_decay': self.base_wd,
            'lr': 1 * self.base_lr,
            'param_name': 'head.weight',
        }, {
            'weight_decay': 0.0,
            'lr': 1 * self.base_lr,
            'param_name': 'head.bias',
        }]

        self.assertIsInstance(optimizer_wrapper.optimizer, torch.optim.AdamW)
        self.assertEqual(optimizer_wrapper.optimizer.defaults['lr'],
                         self.base_lr)
        self.assertEqual(optimizer_wrapper.optimizer.defaults['weight_decay'],
                         self.base_wd)

        param_groups = optimizer_wrapper.optimizer.param_groups
        self.assertEqual(len(param_groups), len(expected_groups))
        for expect, param in zip(expected_groups, param_groups):
            self.assertEqual(param['param_name'], expect['param_name'])
            self.assertEqual(param['lr'], expect['lr'])
            self.assertEqual(param['weight_decay'], expect['weight_decay'])
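The expected learning rates follow the usual layer-wise decay rule: each parameter is scaled by `rate ** (max_depth - depth)`, where the toy backbone gives `max_depth = num_layers + 1 = 3`. With `layer_decay_rate=2.0`, cls_token/pos_embed at depth 0 get 2**3 = 8x, layer 0 gets 4x, layer 1 gets 2x, and the head at maximum depth gets 1x, matching `expected_groups`. A worked check of that formula; the depth assignments are inferred from the expected values above, not read from the constructor source:

# Worked check of the layer-wise decay scaling behind expected_groups.
base_lr, rate, max_depth = 1.0, 2.0, 3
depths = {'backbone.cls_token': 0, 'backbone.layers.0.weight': 1,
          'backbone.layers.1.weight': 2, 'head.weight': 3}
scales = {name: rate ** (max_depth - d) for name, d in depths.items()}
assert scales == {'backbone.cls_token': 8.0, 'backbone.layers.0.weight': 4.0,
                  'backbone.layers.1.weight': 2.0, 'head.weight': 1.0}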
tests/test_evaluation/test_metrics/test_gqa.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestScienceQAMetric:

    def test_evaluate(self):
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'dog',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 1.0

        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'cat',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 0.0
tests/test_evaluation/test_metrics/test_metric_utils.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmpretrain.models.losses.utils import convert_to_one_hot


def ori_convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
    assert (torch.max(targets).item() <
            classes), 'Class Index must be less than number of classes'
    one_hot_targets = torch.zeros((targets.shape[0], classes),
                                  dtype=torch.long,
                                  device=targets.device)
    one_hot_targets.scatter_(1, targets.long(), 1)
    return one_hot_targets


def test_convert_to_one_hot():
    # label should be smaller than classes
    targets = torch.tensor([1, 2, 3, 8, 5])
    classes = 5
    with pytest.raises(AssertionError):
        _ = convert_to_one_hot(targets, classes)

    # test with original impl
    classes = 10
    targets = torch.randint(high=classes, size=(10, 1))
    ori_one_hot_targets = torch.zeros((targets.shape[0], classes),
                                      dtype=torch.long,
                                      device=targets.device)
    ori_one_hot_targets.scatter_(1, targets.long(), 1)
    one_hot_targets = convert_to_one_hot(targets, classes)
    assert torch.equal(ori_one_hot_targets, one_hot_targets)
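Concretely, `scatter_(1, targets.long(), 1)` writes a 1 into each row at the column given by the target index, which is why `targets` needs shape `(N, 1)` here. A tiny worked example of that conversion:

# Worked example of the scatter_-based one-hot conversion used above.
import torch

targets = torch.tensor([[2], [0], [1]])  # shape (N, 1)
one_hot = torch.zeros((3, 4), dtype=torch.long)
one_hot.scatter_(1, targets, 1)  # row i gets a 1 at column targets[i][0]
assert one_hot.tolist() == [[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]]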
tests/test_evaluation/test_metrics/test_multi_label.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import sklearn.metrics
import torch
from mmengine.evaluator import Evaluator
from mmengine.registry import init_default_scope

from mmpretrain.evaluation.metrics import AveragePrecision, MultiLabelMetric
from mmpretrain.structures import DataSample

init_default_scope('mmpretrain')


class TestMultiLabel(TestCase):

    def test_calculate(self):
        """Test using the metric from static method."""
        y_true = [[0], [1, 3], [0, 1, 2], [3]]
        y_pred = [[0, 3], [0, 2], [1, 2], [2, 3]]
        y_true_binary = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred_binary = np.array([
            [1, 0, 0, 1],
            [1, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
        ])
        y_pred_score = np.array([
            [0.8, 0, 0, 0.6],
            [0.2, 0, 0.6, 0],
            [0, 0.9, 0.6, 0],
            [0, 0, 0.2, 0.3],
        ])

        # Test with sequence of category indexes
        res = MultiLabelMetric.calculate(
            y_pred,
            y_true,
            pred_indices=True,
            target_indices=True,
            num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, y_pred_binary, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with onehot input
        res = MultiLabelMetric.calculate(y_pred_binary,
                                         torch.from_numpy(y_true_binary))
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with topk argument
        res = MultiLabelMetric.calculate(
            y_pred_score, y_true, target_indices=True, topk=1, num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        top1_y_pred = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with thr argument
        res = MultiLabelMetric.calculate(
            y_pred_score, y_true, target_indices=True, thr=0.25,
            num_classes=4)
        self.assertIsInstance(res, tuple)
        precision, recall, f1_score, support = res
        # Expected values come from sklearn
        thr_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        self.assertTensorEqual(precision, expect_precision)
        self.assertTensorEqual(recall, expect_recall)
        self.assertTensorEqual(f1_score, expect_f1)
        self.assertTensorEqual(support, 7)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, "<class 'str'> is not"):
            MultiLabelMetric.calculate(y_pred, 'hi', num_classes=10)

        # Test with invalid input
        with self.assertRaisesRegex(AssertionError,
                                    'Invalid `average` argument,'):
            MultiLabelMetric.calculate(
                y_pred, y_true, average='m', num_classes=10)

        y_true_binary = np.array([[1, 0, 0, 0], [0, 1, 0, 1]])
        y_pred_binary = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]])
        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError, 'The size of pred'):
            MultiLabelMetric.calculate(y_pred_binary, y_true_binary)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, 'The `pred` and `target` must'):
            MultiLabelMetric.calculate(y_pred_binary, 5)

    def test_evaluate(self):
        y_true = [[0], [1, 3], [0, 1, 2], [3]]
        y_true_binary = torch.tensor([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred_score = torch.tensor([
            [0.8, 0, 0, 0.6],
            [0.2, 0, 0.6, 0],
            [0, 0.9, 0.6, 0],
            [0, 0, 0.2, 0.3],
        ])

        pred = [
            DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
            for i, j in zip(y_pred_score, y_true)
        ]

        # Test with default argument
        evaluator = Evaluator(dict(type='MultiLabelMetric'))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        thr05_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 0],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision'], expect_precision)
        self.assertEqual(res['multi-label/recall'], expect_recall)
        self.assertEqual(res['multi-label/f1-score'], expect_f1)

        # Test with topk argument
        evaluator = Evaluator(dict(type='MultiLabelMetric', topk=1))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        top1_y_pred = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, top1_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision_top1'], expect_precision)
        self.assertEqual(res['multi-label/recall_top1'], expect_recall)
        self.assertEqual(res['multi-label/f1-score_top1'], expect_f1)

        # Test with both argument
        evaluator = Evaluator(dict(type='MultiLabelMetric', thr=0.25, topk=1))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        thr_y_pred = np.array([
            [1, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr_y_pred, average='macro') * 100
        self.assertEqual(res['multi-label/precision_thr-0.25'],
                         expect_precision)
        self.assertEqual(res['multi-label/recall_thr-0.25'], expect_recall)
        self.assertEqual(res['multi-label/f1-score_thr-0.25'], expect_f1)

        # Test with average micro
        evaluator = Evaluator(dict(type='MultiLabelMetric', average='micro'))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average='micro') * 100
        self.assertAlmostEqual(
            res['multi-label/precision_micro'], expect_precision, places=4)
        self.assertAlmostEqual(
            res['multi-label/recall_micro'], expect_recall, places=4)
        self.assertAlmostEqual(
            res['multi-label/f1-score_micro'], expect_f1, places=4)

        # Test with average None
        evaluator = Evaluator(dict(type='MultiLabelMetric', average=None))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        # Expected values come from sklearn
        expect_precision = sklearn.metrics.precision_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        expect_recall = sklearn.metrics.recall_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        expect_f1 = sklearn.metrics.f1_score(
            y_true_binary, thr05_y_pred, average=None) * 100
        np.testing.assert_allclose(res['multi-label/precision_classwise'],
                                   expect_precision)
        np.testing.assert_allclose(res['multi-label/recall_classwise'],
                                   expect_recall)
        np.testing.assert_allclose(res['multi-label/f1-score_classwise'],
                                   expect_f1)

        # Test with gt_score
        pred = [
            DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
            for i, j in zip(y_pred_score, y_true_binary)
        ]

        evaluator = Evaluator(
            dict(type='MultiLabelMetric', items=['support']))
        evaluator.process(pred)
        res = evaluator.evaluate(4)
        self.assertIsInstance(res, dict)
        self.assertEqual(res['multi-label/support'], 7)

    def assertTensorEqual(self,
                          tensor: torch.Tensor,
                          value: float,
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        value = torch.FloatTensor([value])
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e) + str(tensor)))


class TestAveragePrecision(TestCase):

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        y_pred = torch.tensor([
            [0.9, 0.8, 0.3, 0.2],
            [0.1, 0.2, 0.2, 0.1],
            [0.7, 0.5, 0.9, 0.3],
            [0.8, 0.1, 0.1, 0.2],
        ])
        y_true = torch.tensor([
            [1, 1, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [1, 0, 0, 0],
        ])

        pred = [
            DataSample(num_classes=4).set_pred_score(i).set_gt_score(j)
            for i, j in zip(y_pred, y_true)
        ]

        # Test with default macro average
        evaluator = Evaluator(dict(type='AveragePrecision'))
        evaluator.process(pred)
        res = evaluator.evaluate(5)
        self.assertIsInstance(res, dict)
        self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)

        # Test with average mode None
        evaluator = Evaluator(dict(type='AveragePrecision', average=None))
        evaluator.process(pred)
        res = evaluator.evaluate(5)
        self.assertIsInstance(res, dict)
        aps = res['multi-label/AP_classwise']
        self.assertAlmostEqual(aps[0], 100., places=4)
        self.assertAlmostEqual(aps[1], 83.3333, places=4)
        self.assertAlmostEqual(aps[2], 100, places=4)
        self.assertAlmostEqual(aps[3], 0, places=4)

        # Test with gt_label without score
        pred = [
            DataSample(num_classes=4).set_pred_score(i).set_gt_label(j)
            for i, j in zip(y_pred, [[0, 1], [1], [2], [0]])
        ]
        evaluator = Evaluator(dict(type='AveragePrecision'))
        evaluator.process(pred)
        res = evaluator.evaluate(5)
        self.assertAlmostEqual(res['multi-label/mAP'], 70.83333, places=4)

    def test_calculate(self):
        """Test using the metric from static method."""
        y_true = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 1],
            [1, 1, 1, 0],
            [0, 0, 0, 1],
        ])
        y_pred = np.array([
            [0.9, 0.8, 0.3, 0.2],
            [0.1, 0.2, 0.2, 0.1],
            [0.7, 0.5, 0.9, 0.3],
            [0.8, 0.1, 0.1, 0.2],
        ])

        ap_score = AveragePrecision.calculate(y_pred, y_true)
        expect_ap = sklearn.metrics.average_precision_score(y_true,
                                                            y_pred) * 100
        self.assertTensorEqual(ap_score, expect_ap)

        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError,
                                    'Invalid `average` argument,'):
            AveragePrecision.calculate(y_pred, y_true, average='m')

        y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 1]])
        y_pred = np.array([[1, 0, 0, 1], [1, 0, 1, 0], [0, 1, 1, 0]])
        # Test with invalid inputs
        with self.assertRaisesRegex(AssertionError,
                                    'Both `pred` and `target`'):
            AveragePrecision.calculate(y_pred, y_true)

        # Test with invalid inputs
        with self.assertRaisesRegex(TypeError, "<class 'int'> is not an"):
            AveragePrecision.calculate(y_pred, 5)

    def assertTensorEqual(self,
                          tensor: torch.Tensor,
                          value: float,
                          msg=None,
                          **kwarg):
        tensor = tensor.to(torch.float32)
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        value = torch.FloatTensor([value])
        try:
            torch.testing.assert_allclose(tensor, value, **kwarg)
        except AssertionError as e:
            self.fail(self._formatMessage(msg, str(e) + str(tensor)))
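For reference, `thr05_y_pred` in `test_evaluate` is just `y_pred_score` binarized at the metric's default threshold of 0.5. A quick sketch of that derivation; whether the comparison is strict cannot be told from these scores, so `>` is used here as an assumption:

# Sketch: derive thr05_y_pred from y_pred_score at the default 0.5 threshold.
import numpy as np

y_pred_score = np.array([
    [0.8, 0, 0, 0.6],
    [0.2, 0, 0.6, 0],
    [0, 0.9, 0.6, 0],
    [0, 0, 0.2, 0.3],
])
thr05_y_pred = (y_pred_score > 0.5).astype(int)
assert thr05_y_pred.tolist() == [[1, 0, 0, 1], [0, 0, 1, 0],
                                 [0, 1, 1, 0], [0, 0, 0, 0]]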
tests/test_evaluation/test_metrics/test_multi_task_metrics.py
0 → 100644
View file @
1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmpretrain.evaluation.metrics import MultiTasksMetric
from mmpretrain.structures import DataSample


class MultiTaskMetric(TestCase):
    data_pred = [
        {
            'task0': torch.tensor([0.7, 0.0, 0.3]),
            'task1': torch.tensor([0.5, 0.2, 0.3])
        },
        {
            'task0': torch.tensor([0.0, 0.0, 1.0]),
            'task1': torch.tensor([0.0, 0.0, 1.0])
        },
    ]
    data_gt = [{'task0': 0, 'task1': 2}, {'task1': 2}]

    preds = []
    for i, pred in enumerate(data_pred):
        sample = {}
        for task_name in pred:
            task_sample = DataSample().set_pred_score(pred[task_name])
            if task_name in data_gt[i]:
                task_sample.set_gt_label(data_gt[i][task_name])
                task_sample.set_field(
                    True, 'eval_mask', field_type='metainfo')
            else:
                task_sample.set_field(
                    False, 'eval_mask', field_type='metainfo')
            sample[task_name] = task_sample.to_dict()

        preds.append(sample)

    data2 = zip([
        {
            'task0': torch.tensor([0.7, 0.0, 0.3]),
            'task1': {
                'task10': torch.tensor([0.5, 0.2, 0.3]),
                'task11': torch.tensor([0.4, 0.3, 0.3])
            }
        },
        {
            'task0': torch.tensor([0.0, 0.0, 1.0]),
            'task1': {
                'task10': torch.tensor([0.1, 0.6, 0.3]),
                'task11': torch.tensor([0.5, 0.2, 0.3])
            }
        },
    ], [{
        'task0': 0,
        'task1': {
            'task10': 2,
            'task11': 0
        }
    }, {
        'task0': 2,
        'task1': {
            'task10': 1,
            'task11': 0
        }
    }])

    pred2 = []
    for score, label in data2:
        sample = {}
        for task_name in score:
            if type(score[task_name]) != dict:
                task_sample = DataSample().set_pred_score(score[task_name])
                task_sample.set_gt_label(label[task_name])
                sample[task_name] = task_sample.to_dict()
                sample[task_name]['eval_mask'] = True
            else:
                sample[task_name] = {}
                sample[task_name]['eval_mask'] = True
                for task_name2 in score[task_name]:
                    task_sample = DataSample().set_pred_score(
                        score[task_name][task_name2])
                    task_sample.set_gt_label(label[task_name][task_name2])
                    sample[task_name][task_name2] = task_sample.to_dict()
                    sample[task_name][task_name2]['eval_mask'] = True

        pred2.append(sample)

    pred3 = [{'task0': {'eval_mask': False}, 'task1': {'eval_mask': False}}]

    task_metrics = {
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [
            dict(type='Accuracy', topk=(1, 3)),
            dict(type='SingleLabelMetric', items=['precision', 'recall'])
        ]
    }
    task_metrics2 = {
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [
            dict(
                type='MultiTasksMetric',
                task_metrics={
                    'task10': [
                        dict(type='Accuracy', topk=(1, 3)),
                        dict(type='SingleLabelMetric', items=['precision'])
                    ],
                    'task11': [dict(type='Accuracy', topk=(1, ))]
                })
        ]
    }

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        # Test with score (use score instead of label if score exists)
        metric = MultiTasksMetric(self.task_metrics)
        metric.process(None, self.preds)
        results = metric.evaluate(2)
        self.assertIsInstance(results, dict)
        self.assertAlmostEqual(results['task0_accuracy/top1'], 100)
        self.assertGreater(results['task1_single-label/precision'], 0)

        # Test nested
        metric = MultiTasksMetric(self.task_metrics2)
        metric.process(None, self.pred2)
        results = metric.evaluate(2)
        self.assertIsInstance(results, dict)
        self.assertGreater(results['task1_task10_single-label/precision'], 0)
        self.assertGreater(results['task1_task11_accuracy/top1'], 0)

        # Test without any ground truth value
        metric = MultiTasksMetric(self.task_metrics)
        metric.process(None, self.pred3)
        results = metric.evaluate(2)
        self.assertIsInstance(results, dict)
        self.assertEqual(results['task0_Accuracy'], 0)
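The nested result keys suggest that MultiTasksMetric prefixes each sub-metric's result name with its task path, e.g. `task1_task10_single-label/precision`. A sketch of that key construction, inferred from the asserted keys rather than read from the metric source:

# Sketch of result-key prefixing inferred from the assertions above.
def prefix_results(task_name: str, results: dict) -> dict:
    return {f'{task_name}_{key}': value for key, value in results.items()}

inner = prefix_results('task10', {'single-label/precision': 50.0})
outer = prefix_results('task1', inner)
assert 'task1_task10_single-label/precision' in outer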