Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
a7106c6b
Commit
a7106c6b
authored
May 21, 2022
by
yangwendi.vendor
Committed by
zhouzaida
Jul 19, 2022
Browse files
[fix]:fix type hint in transforms
parent
59eaefeb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
30 deletions
+67
-30
.gitlab-ci.yml
.gitlab-ci.yml
+37
-0
mmcv/transforms/base.py
mmcv/transforms/base.py
+5
-3
mmcv/transforms/processing.py
mmcv/transforms/processing.py
+13
-12
mmcv/transforms/wrappers.py
mmcv/transforms/wrappers.py
+12
-15
No files found.
.gitlab-ci.yml
0 → 100644
View file @
a7106c6b
variables
:
PYTORCH_IMAGE
:
registry.sensetime.com/openmmlab/pytorch18-cu102-mmengine:dev2
stages
:
-
linting
-
test
-
deploy
before_script
:
-
. /root/scripts/set_envs.sh
-
echo $PATH
-
gcc --version
-
nvcc --version
-
ruby --version
-
python --version
-
pip --version
-
python -c "import torch; print(torch.__version__)"
linting
:
image
:
$PYTORCH_IMAGE
stage
:
linting
script
:
-
pre-commit run --all-files
.test_template
:
&test_template_def
stage
:
test
script
:
-
echo "Start building..."
-
MMCV_WITH_OPS=1 pip install -e .[all] -i https://pypi.tuna.tsinghua.edu.cn/simple/
-
python -c "import mmcv; print(mmcv.__version__)"
-
echo "Start testing..."
-
coverage run --branch --source mmcv -m pytest tests/
-
coverage report -m
test:pytorch1.8-cuda10:
image
:
$PYTORCH_IMAGE
<<
:
*test_template_def
mmcv/transforms/base.py
View file @
a7106c6b
# Copyright (c) OpenMMLab. All rights reserved.
from
abc
import
ABCMeta
,
abstractmethod
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
class
BaseTransform
(
metaclass
=
ABCMeta
):
def
__call__
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
def
__call__
(
self
,
results
:
Dict
)
->
Optional
[
Union
[
Dict
,
Tuple
[
List
,
List
]]]:
return
self
.
transform
(
results
)
@
abstractmethod
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Union
[
Dict
,
Tuple
[
List
,
List
]]]:
"""The transform function. All subclass of BaseTransform should
override this method.
...
...
mmcv/transforms/processing.py
View file @
a7106c6b
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
warnings
from
typing
import
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -800,10 +800,12 @@ class MultiScaleFlipAug(BaseTransform):
else
:
# if ``scales`` and ``scale_factor`` both be ``None``
if
scale_factor
is
None
:
self
.
scales
=
[
1.
]
self
.
scales
=
[
1.
]
# type: ignore
elif
isinstance
(
scale_factor
,
list
):
self
.
scales
=
scale_factor
# type: ignore
else
:
self
.
scales
=
scale_factor
if
isinstance
(
scale_factor
,
list
)
else
[
scale_factor
]
self
.
scales
=
[
scale_factor
]
# type: ignore
self
.
scale_key
=
'scale_factor'
self
.
allow_flip
=
allow_flip
...
...
@@ -816,7 +818,7 @@ class MultiScaleFlipAug(BaseTransform):
self
.
resize_cfg
=
resize_cfg
.
copy
()
self
.
flip_cfg
=
flip_cfg
def
transform
(
self
,
results
:
dict
)
->
Tuple
[
List
,
List
]
:
def
transform
(
self
,
results
:
dict
)
->
Dict
:
"""Apply test time augment transforms on results.
Args:
...
...
@@ -848,12 +850,12 @@ class MultiScaleFlipAug(BaseTransform):
results
[
'flip_direction'
]
=
None
resize_flip
=
Compose
(
_resize_flip
)
_results
=
results
.
copy
()
_results
=
resize_flip
(
_results
)
packed_results
=
self
.
transforms
(
_results
)
_results
=
resize_flip
(
results
.
copy
())
packed_results
=
self
.
transforms
(
_results
)
# type: ignore
inputs
.
append
(
packed_results
[
'inputs'
])
data_samples
.
append
(
packed_results
[
'data_sample'
])
inputs
.
append
(
packed_results
[
'inputs'
])
# type: ignore
data_samples
.
append
(
packed_results
[
'data_sample'
])
# type: ignore
return
dict
(
inputs
=
inputs
,
data_sample
=
data_samples
)
def
__repr__
(
self
)
->
str
:
...
...
@@ -1312,8 +1314,7 @@ class RandomResize(BaseTransform):
if
isinstance
(
self
.
scale
,
tuple
):
assert
self
.
ratio_range
is
not
None
and
len
(
self
.
ratio_range
)
==
2
scale
:
Tuple
[
int
,
int
]
=
self
.
_random_sample_ratio
(
self
.
scale
,
self
.
ratio_range
)
scale
=
self
.
_random_sample_ratio
(
self
.
scale
,
self
.
ratio_range
)
elif
mmcv
.
is_list_of
(
self
.
scale
,
tuple
):
scale
=
self
.
_random_sample
(
self
.
scale
)
else
:
...
...
mmcv/transforms/wrappers.py
View file @
a7106c6b
# Copyright (c) OpenMMLab. All rights reserved.
from
collections.abc
import
Sequence
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Union
import
numpy
as
np
...
...
@@ -25,7 +24,7 @@ IgnoreKey = object()
# Import nullcontext if python>=3.7, otherwise use a simple alternative
# implementation.
try
:
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
# type: ignore
except
ImportError
:
from
contextlib
import
contextmanager
...
...
@@ -55,10 +54,10 @@ class Compose(BaseTransform):
>>> ]
"""
def
__init__
(
self
,
transforms
:
Union
[
Transform
,
List
[
Transform
]]):
def
__init__
(
self
,
transforms
:
Union
[
Transform
,
Sequence
[
Transform
]]):
super
().
__init__
()
if
not
isinstance
(
transforms
,
list
):
if
not
isinstance
(
transforms
,
Sequence
):
transforms
=
[
transforms
]
self
.
transforms
:
List
=
[]
for
transform
in
transforms
:
...
...
@@ -85,7 +84,7 @@ class Compose(BaseTransform):
dict or None: Transformed results.
"""
for
t
in
self
.
transforms
:
results
=
t
(
results
)
results
=
t
(
results
)
# type: ignore
if
results
is
None
:
return
None
return
results
...
...
@@ -331,7 +330,7 @@ class KeyMapper(BaseTransform):
# Apply remapping
outputs
=
self
.
_map_output
(
outputs
,
self
.
remapping
)
results
.
update
(
outputs
)
results
.
update
(
outputs
)
# type: ignore
return
results
...
...
@@ -445,8 +444,7 @@ class TransformBroadcaster(KeyMapper):
def
scatter_sequence
(
self
,
data
:
Dict
)
->
List
[
Dict
]:
"""Scatter the broadcasting targets to a list of inputs of the wrapped
transforms.
"""
transforms."""
# infer split number from input
seq_len
=
0
...
...
@@ -458,7 +456,6 @@ class TransformBroadcaster(KeyMapper):
keys
=
data
.
keys
()
for
key
in
keys
:
assert
isinstance
(
data
[
key
],
Sequence
)
if
seq_len
:
if
len
(
data
[
key
])
!=
seq_len
:
...
...
@@ -472,7 +469,7 @@ class TransformBroadcaster(KeyMapper):
assert
seq_len
>
0
,
'Fail to get the number of broadcasting targets'
scatters
=
[]
for
i
in
range
(
seq_len
):
for
i
in
range
(
seq_len
):
# type: ignore
scatter
=
data
.
copy
()
for
key
in
keys
:
scatter
[
key
]
=
data
[
key
][
i
]
...
...
@@ -494,7 +491,7 @@ class TransformBroadcaster(KeyMapper):
# cacheable method of the transforms cache their outputs. Thus
# the random parameters will only generated once and shared
# by all data items.
ctx
=
cache_random_params
ctx
=
cache_random_params
# type: ignore
else
:
ctx
=
nullcontext
# type: ignore
...
...
@@ -602,13 +599,13 @@ class RandomApply(BaseTransform):
@
cache_randomness
def
random_apply
(
self
)
->
bool
:
"""Return a random bool value indicating whether apply the
transform.
"""
"""Return a random bool value indicating whether apply the
transform.
"""
return
np
.
random
.
rand
()
<
self
.
prob
def
transform
(
self
,
results
:
Dict
)
->
Optional
[
Dict
]:
"""Randomly apply the transform."""
if
self
.
random_apply
():
return
self
.
transforms
(
results
)
return
self
.
transforms
(
results
)
# type: ignore
else
:
return
results
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment