Unverified Commit f7caa80f authored by Junjun2016's avatar Junjun2016 Committed by GitHub
Browse files

[Enhancement] Add to_ntuple (#1125)

* add to_ntuple

* add unit test
parent f71e47c2
...@@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction ...@@ -4,7 +4,8 @@ from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning, from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
import_modules_from_strings, is_list_of, is_seq_of, is_str, import_modules_from_strings, is_list_of, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, requires_executable, is_tuple_of, iter_cast, list_cast, requires_executable,
requires_package, slice_list, tuple_cast) requires_package, slice_list, to_1tuple, to_2tuple,
to_3tuple, to_4tuple, to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink) scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress, from .progressbar import (ProgressBar, track_iter_progress,
...@@ -29,17 +30,18 @@ except ImportError: ...@@ -29,17 +30,18 @@ except ImportError:
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings', 'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script' 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple'
] ]
else: else:
from .env import collect_env from .env import collect_env
from .logging import get_logger, print_log from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import ( from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
from .parrots_jit import jit, skip_no_elena
from .registry import Registry, build_from_cfg from .registry import Registry, build_from_cfg
__all__ = [ __all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import collections.abc
import functools import functools
import itertools import itertools
import subprocess import subprocess
...@@ -6,6 +7,25 @@ import warnings ...@@ -6,6 +7,25 @@ import warnings
from collections import abc from collections import abc
from importlib import import_module from importlib import import_module
from inspect import getfullargspec from inspect import getfullargspec
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def is_str(x): def is_str(x):
......
...@@ -4,6 +4,31 @@ import pytest ...@@ -4,6 +4,31 @@ import pytest
import mmcv import mmcv
def test_to_ntuple():
single_number = 2
assert mmcv.utils.to_1tuple(single_number) == (single_number, )
assert mmcv.utils.to_2tuple(single_number) == (single_number,
single_number)
assert mmcv.utils.to_3tuple(single_number) == (single_number,
single_number,
single_number)
assert mmcv.utils.to_4tuple(single_number) == (single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(5)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(6)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number,
single_number)
def test_iter_cast(): def test_iter_cast():
assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3] assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3]
assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0] assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
...@@ -105,6 +130,7 @@ def test_requires_executable(capsys): ...@@ -105,6 +130,7 @@ def test_requires_executable(capsys):
def test_import_modules_from_strings(): def test_import_modules_from_strings():
# multiple imports # multiple imports
import os.path as osp_ import os.path as osp_
import sys as sys_ import sys as sys_
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys']) osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
assert osp == osp_ assert osp == osp_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment