Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
f7caa80f
Unverified
Commit
f7caa80f
authored
Jun 23, 2021
by
Junjun2016
Committed by
GitHub
Jun 23, 2021
Browse files
[Enhancement] Add to_ntuple (#1125)
* add to_ntuple * add unit test
parent
f71e47c2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
3 deletions
+51
-3
mmcv/utils/__init__.py
mmcv/utils/__init__.py
+5
-3
mmcv/utils/misc.py
mmcv/utils/misc.py
+20
-0
tests/test_utils/test_misc.py
tests/test_utils/test_misc.py
+26
-0
No files found.
mmcv/utils/__init__.py
View file @
f7caa80f
...
@@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction
...
@@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction
from
.misc
import
(
check_prerequisites
,
concat_list
,
deprecated_api_warning
,
from
.misc
import
(
check_prerequisites
,
concat_list
,
deprecated_api_warning
,
import_modules_from_strings
,
is_list_of
,
is_seq_of
,
is_str
,
import_modules_from_strings
,
is_list_of
,
is_seq_of
,
is_str
,
is_tuple_of
,
iter_cast
,
list_cast
,
requires_executable
,
is_tuple_of
,
iter_cast
,
list_cast
,
requires_executable
,
requires_package
,
slice_list
,
tuple_cast
)
requires_package
,
slice_list
,
to_1tuple
,
to_2tuple
,
to_3tuple
,
to_4tuple
,
to_ntuple
,
tuple_cast
)
from
.path
import
(
check_file_exist
,
fopen
,
is_filepath
,
mkdir_or_exist
,
from
.path
import
(
check_file_exist
,
fopen
,
is_filepath
,
mkdir_or_exist
,
scandir
,
symlink
)
scandir
,
symlink
)
from
.progressbar
import
(
ProgressBar
,
track_iter_progress
,
from
.progressbar
import
(
ProgressBar
,
track_iter_progress
,
...
@@ -29,17 +30,18 @@ except ImportError:
...
@@ -29,17 +30,18 @@ except ImportError:
'Timer'
,
'TimerError'
,
'check_time'
,
'deprecated_api_warning'
,
'Timer'
,
'TimerError'
,
'check_time'
,
'deprecated_api_warning'
,
'digit_version'
,
'get_git_hash'
,
'import_modules_from_strings'
,
'digit_version'
,
'get_git_hash'
,
'import_modules_from_strings'
,
'assert_dict_contains_subset'
,
'assert_attrs_equal'
,
'assert_dict_contains_subset'
,
'assert_attrs_equal'
,
'assert_dict_has_keys'
,
'assert_keys_equal'
,
'check_python_script'
'assert_dict_has_keys'
,
'assert_keys_equal'
,
'check_python_script'
,
'to_1tuple'
,
'to_2tuple'
,
'to_3tuple'
,
'to_4tuple'
,
'to_ntuple'
]
]
else
:
else
:
from
.env
import
collect_env
from
.env
import
collect_env
from
.logging
import
get_logger
,
print_log
from
.logging
import
get_logger
,
print_log
from
.parrots_jit
import
jit
,
skip_no_elena
from
.parrots_wrapper
import
(
from
.parrots_wrapper
import
(
CUDA_HOME
,
TORCH_VERSION
,
BuildExtension
,
CppExtension
,
CUDAExtension
,
CUDA_HOME
,
TORCH_VERSION
,
BuildExtension
,
CppExtension
,
CUDAExtension
,
DataLoader
,
PoolDataLoader
,
SyncBatchNorm
,
_AdaptiveAvgPoolNd
,
DataLoader
,
PoolDataLoader
,
SyncBatchNorm
,
_AdaptiveAvgPoolNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_BatchNorm
,
_ConvNd
,
_AdaptiveMaxPoolNd
,
_AvgPoolNd
,
_BatchNorm
,
_ConvNd
,
_ConvTransposeMixin
,
_InstanceNorm
,
_MaxPoolNd
,
get_build_config
)
_ConvTransposeMixin
,
_InstanceNorm
,
_MaxPoolNd
,
get_build_config
)
from
.parrots_jit
import
jit
,
skip_no_elena
from
.registry
import
Registry
,
build_from_cfg
from
.registry
import
Registry
,
build_from_cfg
__all__
=
[
__all__
=
[
'Config'
,
'ConfigDict'
,
'DictAction'
,
'collect_env'
,
'get_logger'
,
'Config'
,
'ConfigDict'
,
'DictAction'
,
'collect_env'
,
'get_logger'
,
...
...
mmcv/utils/misc.py
View file @
f7caa80f
# Copyright (c) Open-MMLab. All rights reserved.
# Copyright (c) Open-MMLab. All rights reserved.
import
collections.abc
import
functools
import
functools
import
itertools
import
itertools
import
subprocess
import
subprocess
...
@@ -6,6 +7,25 @@ import warnings
...
@@ -6,6 +7,25 @@ import warnings
from
collections
import
abc
from
collections
import
abc
from
importlib
import
import_module
from
importlib
import
import_module
from
inspect
import
getfullargspec
from
inspect
import
getfullargspec
from
itertools
import
repeat
# From PyTorch internals
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
to_1tuple
=
_ntuple
(
1
)
to_2tuple
=
_ntuple
(
2
)
to_3tuple
=
_ntuple
(
3
)
to_4tuple
=
_ntuple
(
4
)
to_ntuple
=
_ntuple
def
is_str
(
x
):
def
is_str
(
x
):
...
...
tests/test_utils/test_misc.py
View file @
f7caa80f
...
@@ -4,6 +4,31 @@ import pytest
...
@@ -4,6 +4,31 @@ import pytest
import
mmcv
import
mmcv
def
test_to_ntuple
():
single_number
=
2
assert
mmcv
.
utils
.
to_1tuple
(
single_number
)
==
(
single_number
,
)
assert
mmcv
.
utils
.
to_2tuple
(
single_number
)
==
(
single_number
,
single_number
)
assert
mmcv
.
utils
.
to_3tuple
(
single_number
)
==
(
single_number
,
single_number
,
single_number
)
assert
mmcv
.
utils
.
to_4tuple
(
single_number
)
==
(
single_number
,
single_number
,
single_number
,
single_number
)
assert
mmcv
.
utils
.
to_ntuple
(
5
)(
single_number
)
==
(
single_number
,
single_number
,
single_number
,
single_number
,
single_number
)
assert
mmcv
.
utils
.
to_ntuple
(
6
)(
single_number
)
==
(
single_number
,
single_number
,
single_number
,
single_number
,
single_number
,
single_number
)
def
test_iter_cast
():
def
test_iter_cast
():
assert
mmcv
.
list_cast
([
1
,
2
,
3
],
int
)
==
[
1
,
2
,
3
]
assert
mmcv
.
list_cast
([
1
,
2
,
3
],
int
)
==
[
1
,
2
,
3
]
assert
mmcv
.
list_cast
([
'1.1'
,
2
,
'3'
],
float
)
==
[
1.1
,
2.0
,
3.0
]
assert
mmcv
.
list_cast
([
'1.1'
,
2
,
'3'
],
float
)
==
[
1.1
,
2.0
,
3.0
]
...
@@ -105,6 +130,7 @@ def test_requires_executable(capsys):
...
@@ -105,6 +130,7 @@ def test_requires_executable(capsys):
def
test_import_modules_from_strings
():
def
test_import_modules_from_strings
():
# multiple imports
# multiple imports
import
os.path
as
osp_
import
os.path
as
osp_
import
sys
as
sys_
import
sys
as
sys_
osp
,
sys
=
mmcv
.
import_modules_from_strings
([
'os.path'
,
'sys'
])
osp
,
sys
=
mmcv
.
import_modules_from_strings
([
'os.path'
,
'sys'
])
assert
osp
==
osp_
assert
osp
==
osp_
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment