Unverified commit 0fe1c647 authored by Zaida Zhou, committed by GitHub

Remove fileio from mmcv and use mmengine.fileio instead (#2179)

parent 0b4285d9
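
A minimal sketch of what this change means at a call site, assuming mmengine is installed (the temporary path and toy object below are illustrative, not taken from the commit):

    import os.path as osp
    import tempfile

    import mmengine
    from mmengine.fileio import FileClient  # was: from mmcv.fileio import FileClient

    obj = [{'a': 'abc', 'b': 1}, 2, 'c']  # toy object, mirrors the removed fileio test below
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = osp.join(tmp_dir, 'demo.json')    # illustrative file name
        mmengine.dump(obj, path)                 # was: mmcv.dump(obj, path)
        assert mmengine.load(path) == obj        # was: mmcv.load(path)
        client = FileClient(backend='disk')      # instantiation unchanged; only the import moves
        print(client.get_text(path))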
@@ -5,10 +5,10 @@ from math import inf
 from typing import Callable, List, Optional
 import torch.distributed as dist
+from mmengine.fileio import FileClient
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader
-from mmcv.fileio import FileClient
 from mmcv.utils import is_seq_of
 from .hook import Hook
 from .logger import LoggerHook
@@ -61,7 +61,7 @@ class EvalHook(Hook):
 level directory of `runner.work_dir`.
 `New in version 1.3.16.`
 file_client_args (dict): Arguments to instantiate a FileClient.
-See :class:`mmcv.fileio.FileClient` for details. Default: None.
+See :class:`mmengine.fileio.FileClient` for details. Default: None.
 `New in version 1.3.16.`
 **eval_kwargs: Evaluation arguments fed into the evaluate function of
 the dataset.
@@ -437,7 +437,7 @@ class DistEvalHook(EvalHook):
 the `out_dir` will be the concatenation of `out_dir` and the last
 level directory of `runner.work_dir`.
 file_client_args (dict): Arguments to instantiate a FileClient.
-See :class:`mmcv.fileio.FileClient` for details. Default: None.
+See :class:`mmengine.fileio.FileClient` for details. Default: None.
 **eval_kwargs: Evaluation arguments fed into the evaluate function of
 the dataset.
 """
......
@@ -4,6 +4,7 @@ import os
 import os.path as osp
 from typing import Dict, Optional
+import mmengine
 import torch
 import yaml
@@ -96,9 +97,9 @@ class PaviLoggerHook(LoggerHook):
 config_dict = config_dict.copy()
 config_dict.setdefault('max_iter', runner.max_iters)
 # non-serializable values are first converted in
-# mmcv.dump to json
+# mmengine.dump to json
 config_dict = json.loads(
-mmcv.dump(config_dict, file_format='json'))
+mmengine.dump(config_dict, file_format='json'))
 session_text = yaml.dump(config_dict)
 self.init_kwargs.setdefault('session_text', session_text)
 self.writer = SummaryWriter(**self.init_kwargs)
......
@@ -5,11 +5,11 @@ import os.path as osp
 from collections import OrderedDict
 from typing import Dict, Optional, Union
+import mmengine
 import torch
 import torch.distributed as dist
+from mmengine.fileio.file_client import FileClient
-import mmcv
-from mmcv.fileio.file_client import FileClient
 from mmcv.utils import is_tuple_of, scandir
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -48,7 +48,7 @@ class TextLoggerHook(LoggerHook):
 removed. Default: True.
 `New in version 1.3.16.`
 file_client_args (dict, optional): Arguments to instantiate a
-FileClient. See :class:`mmcv.fileio.FileClient` for details.
+FileClient. See :class:`mmengine.fileio.FileClient` for details.
 Default: None.
 `New in version 1.3.16.`
 """
@@ -190,7 +190,7 @@ class TextLoggerHook(LoggerHook):
 # only append log at last line
 if runner.rank == 0:
 with open(self.json_log_path, 'a+') as f:
-mmcv.dump(json_log, f, file_format='json')
+mmengine.dump(json_log, f, file_format='json')
 f.write('\n')
 def _round_float(self, items):
......
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Optional
+import mmengine
 import numpy as np
 import mmcv
@@ -33,7 +34,7 @@ class LoadImageFromFile(BaseTransform):
 See :func:``mmcv.imfrombytes`` for details.
 Defaults to 'cv2'.
 file_client_args (dict): Arguments to instantiate a FileClient.
-See :class:`mmcv.fileio.FileClient` for details.
+See :class:`mmengine.fileio.FileClient` for details.
 Defaults to ``dict(backend='disk')``.
 ignore_empty (bool): Whether to allow loading empty image or file path
 not existent. Defaults to False.
@@ -50,7 +51,7 @@
 self.color_type = color_type
 self.imdecode_backend = imdecode_backend
 self.file_client_args = file_client_args.copy()
-self.file_client = mmcv.FileClient(**self.file_client_args)
+self.file_client = mmengine.FileClient(**self.file_client_args)
 def transform(self, results: dict) -> Optional[dict]:
 """Functions to load image.
@@ -168,7 +169,7 @@ class LoadAnnotations(BaseTransform):
 See :fun:``mmcv.imfrombytes`` for details.
 Defaults to 'cv2'.
 file_client_args (dict): Arguments to instantiate a FileClient.
-See :class:``mmcv.fileio.FileClient`` for details.
+See :class:``mmengine.fileio.FileClient`` for details.
 Defaults to ``dict(backend='disk')``.
 """
@@ -188,7 +189,7 @@
 self.with_keypoints = with_keypoints
 self.imdecode_backend = imdecode_backend
 self.file_client_args = file_client_args.copy()
-self.file_client = mmcv.FileClient(**self.file_client_args)
+self.file_client = mmengine.FileClient(**self.file_client_args)
 def _load_bboxes(self, results: dict) -> None:
 """Private function to load bounding box annotations.
......
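
As a usage sketch of the pattern above: the transform builds a client from `file_client_args` and decodes the returned bytes with mmcv. The image path below is hypothetical, and the flags mirror the documented defaults:

    import mmcv
    import mmengine

    file_client_args = dict(backend='disk')                # documented default
    file_client = mmengine.FileClient(**file_client_args)  # as in LoadImageFromFile.__init__
    img_bytes = file_client.get('demo.jpg')                # hypothetical path; returns raw bytes
    img = mmcv.imfrombytes(img_bytes, flag='color', backend='cv2')
    print(img.shape)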
@@ -15,6 +15,7 @@ from collections import abc
 from importlib import import_module
 from pathlib import Path
+import mmengine
 from addict import Dict
 from yapf.yapflib.yapf_api import FormatCode
@@ -217,8 +218,7 @@ class Config:
 # delete imported module
 del sys.modules[temp_module_name]
 elif filename.endswith(('.yml', '.yaml', '.json')):
-import mmcv
-cfg_dict = mmcv.load(temp_config_file.name)
+cfg_dict = mmengine.load(temp_config_file.name)
 # close temp file
 temp_config_file.close()
@@ -583,20 +583,19 @@ class Config:
 file (str, optional): Path of the output file where the config
 will be dumped. Defaults to None.
 """
-import mmcv
 cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
 if file is None:
 if self.filename is None or self.filename.endswith('.py'):
 return self.pretty_text
 else:
 file_format = self.filename.split('.')[-1]
-return mmcv.dump(cfg_dict, file_format=file_format)
+return mmengine.dump(cfg_dict, file_format=file_format)
 elif file.endswith('.py'):
 with open(file, 'w', encoding='utf-8') as f:
 f.write(self.pretty_text)
 else:
 file_format = file.split('.')[-1]
-return mmcv.dump(cfg_dict, file=file, file_format=file_format)
+return mmengine.dump(cfg_dict, file=file, file_format=file_format)
 def merge_from_dict(self, options, allow_list_keys=True):
 """Merge list into cfg_dict.
......
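
A small round-trip sketch of the dump path above, assuming mmengine is installed; the config keys are made up for illustration:

    import os.path as osp
    import tempfile

    import mmengine
    from mmcv import Config

    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg_path = osp.join(tmp_dir, 'toy_cfg.json')   # hypothetical config file
        mmengine.dump(dict(model=dict(type='ToyModel'), lr=0.01), cfg_path)
        cfg = Config.fromfile(cfg_path)
        print(cfg.dump())  # non-.py configs are now serialized via mmengine.dump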
This diff is collapsed.
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import sys
import tempfile
from unittest.mock import MagicMock, patch
import pytest
import mmcv
from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
def _test_handler(file_format, test_obj, str_checker, mode='r+'):
# dump to a string
dump_str = mmcv.dump(test_obj, file_format=file_format)
str_checker(dump_str)
# load/dump with filenames from disk
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump')
mmcv.dump(test_obj, tmp_filename, file_format=file_format)
assert osp.isfile(tmp_filename)
load_obj = mmcv.load(tmp_filename, file_format=file_format)
assert load_obj == test_obj
os.remove(tmp_filename)
# load/dump with filename from petrel
method = 'put' if 'b' in mode else 'put_text'
with patch.object(PetrelBackend, method, return_value=None) as mock_method:
filename = 's3://path/of/your/file'
mmcv.dump(test_obj, filename, file_format=file_format)
mock_method.assert_called()
# json load/dump with a file-like object
with tempfile.NamedTemporaryFile(mode, delete=False) as f:
tmp_filename = f.name
mmcv.dump(test_obj, f, file_format=file_format)
assert osp.isfile(tmp_filename)
with open(tmp_filename, mode) as f:
load_obj = mmcv.load(f, file_format=file_format)
assert load_obj == test_obj
os.remove(tmp_filename)
# automatically inference the file format from the given filename
tmp_filename = osp.join(tempfile.gettempdir(),
'mmcv_test_dump.' + file_format)
mmcv.dump(test_obj, tmp_filename)
assert osp.isfile(tmp_filename)
load_obj = mmcv.load(tmp_filename)
assert load_obj == test_obj
os.remove(tmp_filename)
obj_for_test = [{'a': 'abc', 'b': 1}, 2, 'c']
def test_json():
def json_checker(dump_str):
assert dump_str in [
'[{"a": "abc", "b": 1}, 2, "c"]', '[{"b": 1, "a": "abc"}, 2, "c"]'
]
_test_handler('json', obj_for_test, json_checker)
def test_yaml():
def yaml_checker(dump_str):
assert dump_str in [
'- {a: abc, b: 1}\n- 2\n- c\n', '- {b: 1, a: abc}\n- 2\n- c\n',
'- a: abc\n b: 1\n- 2\n- c\n', '- b: 1\n a: abc\n- 2\n- c\n'
]
_test_handler('yaml', obj_for_test, yaml_checker)
def test_pickle():
def pickle_checker(dump_str):
import pickle
assert pickle.loads(dump_str) == obj_for_test
_test_handler('pickle', obj_for_test, pickle_checker, mode='rb+')
def test_exception():
test_obj = [{'a': 'abc', 'b': 1}, 2, 'c']
with pytest.raises(ValueError):
mmcv.dump(test_obj)
with pytest.raises(TypeError):
mmcv.dump(test_obj, 'tmp.txt')
def test_register_handler():
@mmcv.register_handler('txt')
class TxtHandler1(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
@mmcv.register_handler(['txt1', 'txt2'])
class TxtHandler2(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write('\n')
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
mmcv.dump(content, tmp_filename)
with open(tmp_filename) as f:
written = f.read()
os.remove(tmp_filename)
assert written == '\n' + content
def test_list_from_file():
# get list from disk
filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg']
filelist = mmcv.list_from_file(filename, prefix='a/')
assert filelist == ['a/1.jpg', 'a/2.jpg', 'a/3.jpg', 'a/4.jpg', 'a/5.jpg']
filelist = mmcv.list_from_file(filename, offset=2)
assert filelist == ['3.jpg', '4.jpg', '5.jpg']
filelist = mmcv.list_from_file(filename, max_num=2)
assert filelist == ['1.jpg', '2.jpg']
filelist = mmcv.list_from_file(filename, offset=3, max_num=3)
assert filelist == ['4.jpg', '5.jpg']
# get list from http
with patch.object(
HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
filename = 'http://path/of/your/file'
filelist = mmcv.list_from_file(
filename, file_client_args={'backend': 'http'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(
filename, file_client_args={'prefix': 'http'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
# get list from petrel
with patch.object(
PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
filename = 's3://path/of/your/file'
filelist = mmcv.list_from_file(
filename, file_client_args={'backend': 'petrel'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(
filename, file_client_args={'prefix': 's3'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
def test_dict_from_file():
# get dict from disk
filename = osp.join(osp.dirname(__file__), 'data/mapping.txt')
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename, key_type=int)
assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
# get dict from http
with patch.object(
HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'):
filename = 'http://path/of/your/file'
mapping = mmcv.dict_from_file(
filename, file_client_args={'backend': 'http'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(
filename, file_client_args={'prefix': 'http'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
# get dict from petrel
with patch.object(
PetrelBackend, 'get_text',
return_value='1 cat\n2 dog cow\n3 panda'):
filename = 's3://path/of/your/file'
mapping = mmcv.dict_from_file(
filename, file_client_args={'backend': 'petrel'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(
filename, file_client_args={'prefix': 's3'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
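
The parsing helpers exercised above remain available from mmengine (list_from_file is already called through mmengine elsewhere in this diff); a minimal sketch with a stand-in for `data/filelist.txt`:

    import os.path as osp
    import tempfile

    import mmengine

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = osp.join(tmp_dir, 'filelist.txt')  # stand-in for tests/data/filelist.txt
        with open(path, 'w') as f:
            f.write('1.jpg\n2.jpg\n3.jpg\n')
        assert mmengine.list_from_file(path) == ['1.jpg', '2.jpg', '3.jpg']
        assert mmengine.list_from_file(path, prefix='a/') == ['a/1.jpg', 'a/2.jpg', 'a/3.jpg']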
@@ -7,13 +7,14 @@ from pathlib import Path
 from unittest.mock import MagicMock, patch
 import cv2
+import mmengine
 import numpy as np
 import pytest
 import torch
+from mmengine.fileio.file_client import HTTPBackend, PetrelBackend
 from numpy.testing import assert_allclose, assert_array_equal
 import mmcv
-from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
 if torch.__version__ == 'parrots':
 pytest.skip('not necessary in parrots test', allow_module_level=True)
@@ -46,7 +47,7 @@ class TestIO:
 @classmethod
 def teardown_class(cls):
 # clean instances avoid to influence other unittest
-mmcv.FileClient._instances = {}
+mmengine.FileClient._instances = {}
 def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
 assert img.shape == ref_img.shape
......
@@ -3,6 +3,7 @@ import os
 import os.path as osp
 from unittest.mock import patch
+import mmengine
 import pytest
 import torchvision
@@ -30,7 +31,7 @@ def test_default_mmcv_home():
 assert _get_mmcv_home() == os.path.expanduser(
 os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
 model_urls = get_external_models()
-assert model_urls == mmcv.load(
+assert model_urls == mmengine.load(
 osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
......
 # Copyright (c) OpenMMLab. All rights reserved.
+import mmengine
 import numpy as np
 import pytest
 import torch
@@ -144,9 +145,8 @@ class Testnms:
 nms_match(wrong_dets, iou_thr)
 def test_batched_nms(self):
-import mmcv
 from mmcv.ops import batched_nms
-results = mmcv.load('./tests/data/batched_nms_data.pkl')
+results = mmengine.load('./tests/data/batched_nms_data.pkl')
 nms_max_num = 100
 nms_cfg = dict(
......
@@ -3,6 +3,7 @@ import os
 from functools import partial
 from typing import Callable
+import mmengine
 import numpy as np
 import onnx
 import pytest
@@ -117,7 +118,6 @@ def test_roialign():
 def test_nms():
 try:
-import mmcv
 from mmcv.ops import nms
 except (ImportError, ModuleNotFoundError):
 pytest.skip('test requires compilation')
@@ -125,7 +125,7 @@ def test_nms():
 # trt config
 fp16_mode = False
 max_workspace_size = 1 << 30
-data = mmcv.load('./tests/data/batched_nms_data.pkl')
+data = mmengine.load('./tests/data/batched_nms_data.pkl')
 boxes = torch.from_numpy(data['boxes']).cuda()
 scores = torch.from_numpy(data['scores']).cuda()
 nms = partial(
@@ -188,7 +188,6 @@ def test_nms():
 def test_batched_nms():
 try:
-import mmcv
 from mmcv.ops import batched_nms
 except (ImportError, ModuleNotFoundError):
 pytest.skip('test requires compilation')
@@ -197,7 +196,7 @@ def test_batched_nms():
 os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
 fp16_mode = False
 max_workspace_size = 1 << 30
-data = mmcv.load('./tests/data/batched_nms_data.pkl')
+data = mmengine.load('./tests/data/batched_nms_data.pkl')
 nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
 boxes = torch.from_numpy(data['boxes']).cuda()
 scores = torch.from_numpy(data['scores']).cuda()
......
 # Copyright (c) OpenMMLab. All rights reserved.
 import tempfile
+import mmengine
 import pytest
 import torch
 from torch import nn
-import mmcv
 from mmcv.cnn.utils.weight_init import update_init_info
 from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
 from mmcv.utils import Registry, build_from_cfg
@@ -135,7 +135,7 @@ def test_initilization_info_logger():
 # assert initialization information has been dumped
 assert os.path.exists(log_file)
-lines = mmcv.list_from_file(log_file)
+lines = mmengine.list_from_file(log_file)
 # check initialization information is right
 for i, line in enumerate(lines):
@@ -210,7 +210,7 @@ def test_initilization_info_logger():
 # assert initialization information has been dumped
 assert os.path.exists(log_file)
-lines = mmcv.list_from_file(log_file)
+lines = mmengine.list_from_file(log_file)
 # check initialization information is right
 for i, line in enumerate(lines):
 if 'TopLevelModule' in line and 'init_cfg' not in line:
......
@@ -8,9 +8,9 @@ import pytest
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from mmengine.fileio.file_client import PetrelBackend
 from torch.nn.parallel import DataParallel
-from mmcv.fileio.file_client import PetrelBackend
 from mmcv.parallel.registry import MODULE_WRAPPERS
 from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
 get_state_dict, load_checkpoint,
......
@@ -11,9 +11,9 @@ import pytest
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from mmengine.fileio.file_client import PetrelBackend
 from torch.utils.data import DataLoader, Dataset
-from mmcv.fileio.file_client import PetrelBackend
 from mmcv.runner import DistEvalHook as BaseDistEvalHook
 from mmcv.runner import EpochBasedRunner
 from mmcv.runner import EvalHook as BaseEvalHook
......
@@ -18,10 +18,10 @@ from unittest.mock import MagicMock, Mock, call, patch
 import pytest
 import torch
 import torch.nn as nn
+from mmengine.fileio.file_client import PetrelBackend
 from torch.nn.init import constant_
 from torch.utils.data import DataLoader
-from mmcv.fileio.file_client import PetrelBackend
 # yapf: disable
 from mmcv.runner import (CheckpointHook, ClearMLLoggerHook, DvcliveLoggerHook,
 EMAHook, Fp16OptimizerHook,
......
@@ -10,8 +10,9 @@ from pathlib import Path
 import pytest
 import yaml
+from mmengine import dump, load
-from mmcv import Config, ConfigDict, DictAction, dump, load
+from mmcv import Config, ConfigDict, DictAction
 data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data')
......