Unverified Commit 5947178e authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Remove many functions in utils and migrate them to mmengine (#2217)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs

* migrate many functions to mmengine

* remove sync_bn.py
parent 9185eee8
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import mmcv
from mmcv import deprecated_api_warning
from mmcv.utils.misc import has_method
def test_to_ntuple():
single_number = 2
assert mmcv.utils.to_1tuple(single_number) == (single_number, )
assert mmcv.utils.to_2tuple(single_number) == (single_number,
single_number)
assert mmcv.utils.to_3tuple(single_number) == (single_number,
single_number,
single_number)
assert mmcv.utils.to_4tuple(single_number) == (single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(5)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number)
assert mmcv.utils.to_ntuple(6)(single_number) == (single_number,
single_number,
single_number,
single_number,
single_number,
single_number)
def test_iter_cast():
assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3]
assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0]
assert mmcv.list_cast([1, 2, 3], str) == ['1', '2', '3']
assert mmcv.tuple_cast((1, 2, 3), str) == ('1', '2', '3')
assert next(mmcv.iter_cast([1, 2, 3], str)) == '1'
with pytest.raises(TypeError):
mmcv.iter_cast([1, 2, 3], '')
with pytest.raises(TypeError):
mmcv.iter_cast(1, str)
def test_is_seq_of():
assert mmcv.is_seq_of([1.0, 2.0, 3.0], float)
assert mmcv.is_seq_of([(1, ), (2, ), (3, )], tuple)
assert mmcv.is_seq_of((1.0, 2.0, 3.0), float)
assert mmcv.is_list_of([1.0, 2.0, 3.0], float)
assert not mmcv.is_seq_of((1.0, 2.0, 3.0), float, seq_type=list)
assert not mmcv.is_tuple_of([1.0, 2.0, 3.0], float)
assert not mmcv.is_seq_of([1.0, 2, 3], int)
assert not mmcv.is_seq_of((1.0, 2, 3), int)
def test_slice_list():
in_list = [1, 2, 3, 4, 5, 6]
assert mmcv.slice_list(in_list, [1, 2, 3]) == [[1], [2, 3], [4, 5, 6]]
assert mmcv.slice_list(in_list, [len(in_list)]) == [in_list]
with pytest.raises(TypeError):
mmcv.slice_list(in_list, 2.0)
with pytest.raises(ValueError):
mmcv.slice_list(in_list, [1, 2])
def test_concat_list():
assert mmcv.concat_list([[1, 2]]) == [1, 2]
assert mmcv.concat_list([[1, 2], [3, 4, 5], [6]]) == [1, 2, 3, 4, 5, 6]
def test_requires_package(capsys):
@mmcv.requires_package('nnn')
def func_a():
pass
@mmcv.requires_package(['numpy', 'n1', 'n2'])
def func_b():
pass
@mmcv.requires_package('numpy')
def func_c():
return 1
with pytest.raises(RuntimeError):
func_a()
out, _ = capsys.readouterr()
assert out == ('Prerequisites "nnn" are required in method "func_a" but '
'not found, please install them first.\n')
with pytest.raises(RuntimeError):
func_b()
out, _ = capsys.readouterr()
assert out == (
'Prerequisites "n1, n2" are required in method "func_b" but not found,'
' please install them first.\n')
assert func_c() == 1
def test_requires_executable(capsys):
@mmcv.requires_executable('nnn')
def func_a():
pass
@mmcv.requires_executable(['ls', 'n1', 'n2'])
def func_b():
pass
@mmcv.requires_executable('mv')
def func_c():
return 1
with pytest.raises(RuntimeError):
func_a()
out, _ = capsys.readouterr()
assert out == ('Prerequisites "nnn" are required in method "func_a" but '
'not found, please install them first.\n')
with pytest.raises(RuntimeError):
func_b()
out, _ = capsys.readouterr()
assert out == (
'Prerequisites "n1, n2" are required in method "func_b" but not found,'
' please install them first.\n')
assert func_c() == 1
def test_import_modules_from_strings():
# multiple imports
import os.path as osp_
import sys as sys_
osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys'])
assert osp == osp_
assert sys == sys_
# single imports
osp = mmcv.import_modules_from_strings('os.path')
assert osp == osp_
# No imports
assert mmcv.import_modules_from_strings(None) is None
assert mmcv.import_modules_from_strings([]) is None
assert mmcv.import_modules_from_strings('') is None
# Unsupported types
with pytest.raises(TypeError):
mmcv.import_modules_from_strings(1)
with pytest.raises(TypeError):
mmcv.import_modules_from_strings([1])
# Failed imports
with pytest.raises(ImportError):
mmcv.import_modules_from_strings('_not_implemented_module')
with pytest.warns(UserWarning):
imported = mmcv.import_modules_from_strings(
'_not_implemented_module', allow_failed_imports=True)
assert imported is None
with pytest.warns(UserWarning):
imported = mmcv.import_modules_from_strings(
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None
def test_is_method_overridden():
class Base:
def foo1():
pass
def foo2():
pass
class Sub(Base):
def foo1():
pass
# test passing sub class directly
assert mmcv.is_method_overridden('foo1', Base, Sub)
assert not mmcv.is_method_overridden('foo2', Base, Sub)
# test passing instance of sub class
sub_instance = Sub()
assert mmcv.is_method_overridden('foo1', Base, sub_instance)
assert not mmcv.is_method_overridden('foo2', Base, sub_instance)
# base_class should be a class, not instance
base_instance = Base()
with pytest.raises(AssertionError):
mmcv.is_method_overridden('foo1', base_instance, sub_instance)
def test_has_method():
class Foo:
def __init__(self, name):
self.name = name
def print_name(self):
print(self.name)
foo = Foo('foo')
assert not has_method(foo, 'name')
assert has_method(foo, 'print_name')
def test_deprecated_api_warning():
@deprecated_api_warning(name_dict=dict(old_key='new_key'))
def dummy_func(new_key=1):
return new_key
# replace `old_key` to `new_key`
assert dummy_func(old_key=2) == 2
# The expected behavior is to replace the
# deprecated key `old_key` to `new_key`,
# but got them in the arguments at the same time
with pytest.raises(AssertionError):
dummy_func(old_key=1, new_key=2)
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmengine.utils import TORCH_VERSION
import mmcv
from mmcv.utils import TORCH_VERSION
pytest.skip('this test not ready now', allow_module_level=True)
skip_no_parrots = pytest.mark.skipif(
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from pathlib import Path
import pytest
import mmcv
def test_is_filepath():
assert mmcv.is_filepath(__file__)
assert mmcv.is_filepath('abc')
assert mmcv.is_filepath(Path('/etc'))
assert not mmcv.is_filepath(0)
def test_fopen():
assert hasattr(mmcv.fopen(__file__), 'read')
assert hasattr(mmcv.fopen(Path(__file__)), 'read')
def test_check_file_exist():
mmcv.check_file_exist(__file__)
with pytest.raises(FileNotFoundError):
mmcv.check_file_exist('no_such_file.txt')
def test_scandir():
folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan')
filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT']
assert set(mmcv.scandir(folder)) == set(filenames)
assert set(mmcv.scandir(Path(folder))) == set(filenames)
assert set(mmcv.scandir(folder, '.txt')) == {
filename
for filename in filenames if filename.endswith('.txt')
}
assert set(mmcv.scandir(folder, ('.json', '.txt'))) == {
filename
for filename in filenames if filename.endswith(('.txt', '.json'))
}
assert set(mmcv.scandir(folder, '.png')) == set()
# path of sep is `\\` in windows but `/` in linux, so osp.join should be
# used to join string for compatibility
filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT',
osp.join('sub', '1.json'),
osp.join('sub', '1.txt'), '.file'
]
# .file starts with '.' and is a file so it will not be scanned
assert set(mmcv.scandir(folder, recursive=True)) == {
filename
for filename in filenames_recursive if filename != '.file'
}
assert set(mmcv.scandir(Path(folder), recursive=True)) == {
filename
for filename in filenames_recursive if filename != '.file'
}
assert set(mmcv.scandir(folder, '.txt', recursive=True)) == {
filename
for filename in filenames_recursive if filename.endswith('.txt')
}
assert set(
mmcv.scandir(folder, '.TXT', recursive=True,
case_sensitive=False)) == {
filename
for filename in filenames_recursive
if filename.endswith(('.txt', '.TXT'))
}
assert set(
mmcv.scandir(
folder, ('.TXT', '.JSON'), recursive=True,
case_sensitive=False)) == {
filename
for filename in filenames_recursive
if filename.endswith(('.txt', '.json', '.TXT'))
}
with pytest.raises(TypeError):
list(mmcv.scandir(123))
with pytest.raises(TypeError):
list(mmcv.scandir(folder, 111))
# Copyright (c) OpenMMLab. All rights reserved.
import os
import time
from io import StringIO
from unittest.mock import patch
import mmcv
def reset_string_io(io):
io.truncate(0)
io.seek(0)
class TestProgressBar:
def test_start(self):
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
assert out.getvalue() == 'completed: 0, elapsed: 0s'
reset_string_io(out)
prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
prog_bar.start()
assert out.getvalue() == 'completed: 0, elapsed: 0s'
# with total task num
reset_string_io(out)
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
reset_string_io(out)
prog_bar = mmcv.ProgressBar(
10, bar_width=bar_width, start=False, file=out)
assert out.getvalue() == ''
reset_string_io(out)
prog_bar.start()
assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
def test_update(self):
out = StringIO()
bar_width = 20
# without total task num
prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
reset_string_io(out)
# with total task num
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \
'task/s, elapsed: 1s, ETA: 9s'
def test_adaptive_length(self):
with patch.dict('os.environ', {'COLUMNS': '80'}):
out = StringIO()
bar_width = 20
prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
time.sleep(1)
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 66
os.environ['COLUMNS'] = '30'
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 48
os.environ['COLUMNS'] = '60'
reset_string_io(out)
prog_bar.update()
assert len(out.getvalue()) == 60
def sleep_1s(num):
time.sleep(1)
return num
def test_track_progress_list():
out = StringIO()
ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_progress_iterator():
out = StringIO()
ret = mmcv.track_progress(
sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_iter_progress():
out = StringIO()
ret = []
for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
ret.append(sleep_1s(num))
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
def test_track_enum_progress():
out = StringIO()
ret = []
count = []
for i, num in enumerate(
mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
ret.append(sleep_1s(num))
count.append(i)
assert out.getvalue() == (
'[ ] 0/3, elapsed: 0s, ETA:'
'\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s'
'\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s'
'\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n')
assert ret == [1, 2, 3]
assert count == [0, 1, 2]
def test_track_parallel_progress_list():
out = StringIO()
results = mmcv.track_parallel_progress(
sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
# '[ ] 0/4, elapsed: 0s, ETA:'
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
def test_track_parallel_progress_iterator():
out = StringIO()
results = mmcv.track_parallel_progress(
sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
# The following cannot pass CI on Github Action
# assert out.getvalue() == (
# '[ ] 0/4, elapsed: 0s, ETA:'
# '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s'
# '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s'
# '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s'
# '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n')
assert results == [1, 2, 3, 4]
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import mmcv
def test_registry():
CATS = mmcv.Registry('cat')
assert CATS.name == 'cat'
assert CATS.module_dict == {}
assert len(CATS) == 0
@CATS.register_module()
class BritishShorthair:
pass
assert len(CATS) == 1
assert CATS.get('BritishShorthair') is BritishShorthair
class Munchkin:
pass
CATS.register_module(Munchkin)
assert len(CATS) == 2
assert CATS.get('Munchkin') is Munchkin
assert 'Munchkin' in CATS
with pytest.raises(KeyError):
CATS.register_module(Munchkin)
CATS.register_module(Munchkin, force=True)
assert len(CATS) == 2
# force=False
with pytest.raises(KeyError):
@CATS.register_module()
class BritishShorthair:
pass
@CATS.register_module(force=True)
class BritishShorthair:
pass
assert len(CATS) == 2
assert CATS.get('PersianCat') is None
assert 'PersianCat' not in CATS
@CATS.register_module(name=['Siamese', 'Siamese2'])
class SiameseCat:
pass
assert CATS.get('Siamese').__name__ == 'SiameseCat'
assert CATS.get('Siamese2').__name__ == 'SiameseCat'
class SphynxCat:
pass
CATS.register_module(name='Sphynx', module=SphynxCat)
assert CATS.get('Sphynx') is SphynxCat
CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
assert CATS.get('Sphynx2') is SphynxCat
repr_str = 'Registry(name=cat, items={'
repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
"<locals>.BritishShorthair'>, ")
repr_str += ("'Munchkin': <class 'test_registry.test_registry."
"<locals>.Munchkin'>, ")
repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Siamese2': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>")
repr_str += '})'
assert repr(CATS) == repr_str
# name type
with pytest.raises(TypeError):
CATS.register_module(name=7474741, module=SphynxCat)
# the registered module should be a class
with pytest.raises(TypeError):
CATS.register_module(0)
@CATS.register_module()
def muchkin():
pass
assert CATS.get('muchkin') is muchkin
assert 'muchkin' in CATS
# can only decorate a class or a function
with pytest.raises(TypeError):
class Demo:
def some_method(self):
pass
method = Demo().some_method
CATS.register_module(name='some_method', module=method)
# begin: test old APIs
with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
@CATS.register_module
class NewCat:
pass
assert CATS.get('NewCat').__name__ == 'NewCat'
with pytest.warns(DeprecationWarning):
CATS.deprecated_register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module
class CuteCat:
pass
assert CATS.get('CuteCat').__name__ == 'CuteCat'
with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module(force=True)
class NewCat2:
pass
assert CATS.get('NewCat2').__name__ == 'NewCat2'
# end: test old APIs
def test_multi_scope_registry():
DOGS = mmcv.Registry('dogs')
assert DOGS.name == 'dogs'
assert DOGS.scope == 'test_registry'
assert DOGS.module_dict == {}
assert len(DOGS) == 0
@DOGS.register_module()
class GoldenRetriever:
pass
assert len(DOGS) == 1
assert DOGS.get('GoldenRetriever') is GoldenRetriever
HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound')
@HOUNDS.register_module()
class BloodHound:
pass
assert len(HOUNDS) == 1
assert HOUNDS.get('BloodHound') is BloodHound
assert DOGS.get('hound.BloodHound') is BloodHound
assert HOUNDS.get('hound.BloodHound') is BloodHound
LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound')
@LITTLE_HOUNDS.register_module()
class Dachshund:
pass
assert len(LITTLE_HOUNDS) == 1
assert LITTLE_HOUNDS.get('Dachshund') is Dachshund
assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound
assert HOUNDS.get('little_hound.Dachshund') is Dachshund
assert DOGS.get('hound.little_hound.Dachshund') is Dachshund
MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound')
@MID_HOUNDS.register_module()
class Beagle:
pass
assert MID_HOUNDS.get('Beagle') is Beagle
assert HOUNDS.get('mid_hound.Beagle') is Beagle
assert DOGS.get('hound.mid_hound.Beagle') is Beagle
assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle
assert MID_HOUNDS.get('hound.BloodHound') is BloodHound
assert MID_HOUNDS.get('hound.Dachshund') is None
def test_build_from_cfg():
BACKBONES = mmcv.Registry('backbone')
@BACKBONES.register_module()
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module()
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = dict(type=ResNet, depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# type defined using default_args
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(type='ResNet'))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# not a registry
with pytest.raises(TypeError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, 'BACKBONES')
# non-registered class
with pytest.raises(KeyError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, BACKBONES)
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg['type'] should be a str or class
with pytest.raises(TypeError):
cfg = dict(type=1000)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(stages=4))
# incorrect registry type
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, 'BACKBONES')
# incorrect default_args type
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)
# incorrect arguments
with pytest.raises(TypeError):
cfg = dict(type='ResNet', non_existing_arg=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import mmcv
try:
import torch
except ImportError:
torch = None
else:
import torch.nn as nn
def test_assert_dict_contains_subset():
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)}
# case 1
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)}
assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 2
expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 3
expected_subset = {'a': 'test1', 'b': 2, 'c': None}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 4
expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 5
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [1, 2, 3]])
}
expected_subset = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': np.array([[5, 3, 5], [6, 2, 3]])
}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 6
dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])}
assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
if torch is not None:
dict_obj = {
'a': 'test1',
'b': 2,
'c': (4, 6),
'd': torch.tensor([5, 3, 5])
}
# case 7
expected_subset = {'d': torch.tensor([5, 5, 5])}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
# case 8
expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])}
assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset)
def test_assert_attrs_equal():
class TestExample:
a, b, c = 1, ('wvi', 3), [4.5, 3.14]
def test_func(self):
return self.b
# case 1
assert mmcv.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14]
})
# case 2
assert not mmcv.assert_attrs_equal(TestExample, {
'a': 1,
'b': ('wvi', 3),
'c': [4.5, 3.14, 2]
})
# case 3
assert not mmcv.assert_attrs_equal(TestExample, {
'bc': 54,
'c': [4.5, 3.14]
})
# case 4
assert mmcv.assert_attrs_equal(TestExample, {
'b': ('wvi', 3),
'test_func': TestExample.test_func
})
if torch is not None:
class TestExample:
a, b = torch.tensor([1]), torch.tensor([4, 5])
# case 5
assert mmcv.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 5])
})
# case 6
assert not mmcv.assert_attrs_equal(TestExample, {
'a': torch.tensor([1]),
'b': torch.tensor([4, 6])
})
assert_dict_has_keys_data_1 = [({
'res_layer': 1,
'norm_layer': 2,
'dense_layer': 3
})]
assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True),
(['res_layer', 'conv_layer'], False)]
@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1)
@pytest.mark.parametrize('expected_keys, ret_value',
assert_dict_has_keys_data_2)
def test_assert_dict_has_keys(obj, expected_keys, ret_value):
assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value
assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])]
assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True),
(['res_layer', 'dense_layer', 'norm_layer'], True),
(['res_layer', 'norm_layer'], False),
(['res_layer', 'conv_layer', 'norm_layer'], False)]
@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1)
@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2)
def test_assert_keys_equal(result_keys, target_keys, ret_value):
assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_is_norm_layer():
# case 1
assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3))
# case 2
assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128))
# case 3
assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64))
# case 4
assert not mmcv.assert_is_norm_layer(nn.Sigmoid())
@pytest.mark.skipif(torch is None, reason='requires torch library')
def test_assert_params_all_zeros():
demo_module = nn.Conv2d(3, 64, 3)
nn.init.constant_(demo_module.weight, 0)
nn.init.constant_(demo_module.bias, 0)
assert mmcv.assert_params_all_zeros(demo_module)
nn.init.xavier_normal_(demo_module.weight)
nn.init.constant_(demo_module.bias, 0)
assert not mmcv.assert_params_all_zeros(demo_module)
demo_module = nn.Linear(2048, 400, bias=False)
nn.init.constant_(demo_module.weight, 0)
assert mmcv.assert_params_all_zeros(demo_module)
nn.init.normal_(demo_module.weight, mean=0, std=0.01)
assert not mmcv.assert_params_all_zeros(demo_module)
def test_check_python_script(capsys):
mmcv.utils.check_python_script('./tests/data/scripts/hello.py zz')
captured = capsys.readouterr().out
assert captured == 'hello zz!\n'
mmcv.utils.check_python_script('./tests/data/scripts/hello.py agent')
captured = capsys.readouterr().out
assert captured == 'hello agent!\n'
# Make sure that wrong cmd raises an error
with pytest.raises(SystemExit):
mmcv.utils.check_python_script('./tests/data/scripts/hello.py li zz')
# Copyright (c) OpenMMLab. All rights reserved.
import time
import pytest
import mmcv
def test_timer_init():
timer = mmcv.Timer(start=False)
assert not timer.is_running
timer.start()
assert timer.is_running
timer = mmcv.Timer()
assert timer.is_running
def test_timer_run():
timer = mmcv.Timer()
time.sleep(1)
assert abs(timer.since_start() - 1) < 1e-2
time.sleep(1)
assert abs(timer.since_last_check() - 1) < 1e-2
assert abs(timer.since_start() - 2) < 1e-2
timer = mmcv.Timer(False)
with pytest.raises(mmcv.TimerError):
timer.since_start()
with pytest.raises(mmcv.TimerError):
timer.since_last_check()
def test_timer_context(capsys):
with mmcv.Timer():
time.sleep(1)
out, _ = capsys.readouterr()
assert abs(float(out) - 1) < 1e-2
with mmcv.Timer(print_tmpl='time: {:.1f}s'):
time.sleep(1)
out, _ = capsys.readouterr()
assert out == 'time: 1.0s\n'
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import torch_meshgrid
def test_torch_meshgrid():
# torch_meshgrid should not throw warning
with pytest.warns(None) as record:
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
grid_x, grid_y = torch_meshgrid(x, y)
assert len(record) == 0
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import digit_version, is_jit_tracing
@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.6.0'),
reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
def foo(x):
if is_jit_tracing():
return x
else:
return x.tolist()
x = torch.rand(3)
# test without trace
assert isinstance(foo(x), list)
# test with trace
traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
assert isinstance(traced_foo(x), torch.Tensor)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import pytest
from mmcv import get_git_hash, parse_version_info
from mmcv.utils import digit_version
def test_digit_version():
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
assert digit_version('1.0') == digit_version('1.0.0')
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
assert digit_version('1.0.0a') < digit_version('1.0.0b')
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('1.0.0') < digit_version('1.0.0post')
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
with pytest.raises(AssertionError):
digit_version('a')
with pytest.raises(AssertionError):
digit_version('1x')
with pytest.raises(AssertionError):
digit_version('1.x')
def test_parse_version_info():
assert parse_version_info('0.2.16') == (0, 2, 16, 0, 0, 0)
assert parse_version_info('1.2.3') == (1, 2, 3, 0, 0, 0)
assert parse_version_info('1.2.3rc0') == (1, 2, 3, 0, 'rc', 0)
assert parse_version_info('1.2.3rc1') == (1, 2, 3, 0, 'rc', 1)
assert parse_version_info('1.0rc0') == (1, 0, 0, 0, 'rc', 0)
def _mock_cmd_success(cmd):
return b'3b46d33e90c397869ad5103075838fdfc9812aa0'
def _mock_cmd_fail(cmd):
raise OSError
def test_get_git_hash():
with patch('mmcv.utils.version_utils._minimal_ext_cmd', _mock_cmd_success):
assert get_git_hash() == '3b46d33e90c397869ad5103075838fdfc9812aa0'
assert get_git_hash(digits=6) == '3b46d3'
assert get_git_hash(digits=100) == get_git_hash()
with patch('mmcv.utils.version_utils._minimal_ext_cmd', _mock_cmd_fail):
assert get_git_hash() == 'unknown'
assert get_git_hash(fallback='n/a') == 'n/a'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment