Unverified Commit 0fe1c647 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

Remove fileio from mmcv and use mmengine.fileio instead (#2179)

parent 0b4285d9
......@@ -5,10 +5,10 @@ from math import inf
from typing import Callable, List, Optional
import torch.distributed as dist
from mmengine.fileio import FileClient
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
from mmcv.fileio import FileClient
from mmcv.utils import is_seq_of
from .hook import Hook
from .logger import LoggerHook
......@@ -61,7 +61,7 @@ class EvalHook(Hook):
level directory of `runner.work_dir`.
`New in version 1.3.16.`
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
See :class:`mmengine.fileio.FileClient` for details. Default: None.
`New in version 1.3.16.`
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
......@@ -437,7 +437,7 @@ class DistEvalHook(EvalHook):
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
See :class:`mmengine.fileio.FileClient` for details. Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
......
......@@ -4,6 +4,7 @@ import os
import os.path as osp
from typing import Dict, Optional
import mmengine
import torch
import yaml
......@@ -96,9 +97,9 @@ class PaviLoggerHook(LoggerHook):
config_dict = config_dict.copy()
config_dict.setdefault('max_iter', runner.max_iters)
# non-serializable values are first converted in
# mmcv.dump to json
# mmengine.dump to json
config_dict = json.loads(
mmcv.dump(config_dict, file_format='json'))
mmengine.dump(config_dict, file_format='json'))
session_text = yaml.dump(config_dict)
self.init_kwargs.setdefault('session_text', session_text)
self.writer = SummaryWriter(**self.init_kwargs)
......
......@@ -5,11 +5,11 @@ import os.path as osp
from collections import OrderedDict
from typing import Dict, Optional, Union
import mmengine
import torch
import torch.distributed as dist
from mmengine.fileio.file_client import FileClient
import mmcv
from mmcv.fileio.file_client import FileClient
from mmcv.utils import is_tuple_of, scandir
from ..hook import HOOKS
from .base import LoggerHook
......@@ -48,7 +48,7 @@ class TextLoggerHook(LoggerHook):
removed. Default: True.
`New in version 1.3.16.`
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
"""
......@@ -190,7 +190,7 @@ class TextLoggerHook(LoggerHook):
# only append log at last line
if runner.rank == 0:
with open(self.json_log_path, 'a+') as f:
mmcv.dump(json_log, f, file_format='json')
mmengine.dump(json_log, f, file_format='json')
f.write('\n')
def _round_float(self, items):
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import mmengine
import numpy as np
import mmcv
......@@ -33,7 +34,7 @@ class LoadImageFromFile(BaseTransform):
See :func:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
ignore_empty (bool): Whether to allow loading empty image or file path
not existent. Defaults to False.
......@@ -50,7 +51,7 @@ class LoadImageFromFile(BaseTransform):
self.color_type = color_type
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args)
self.file_client = mmengine.FileClient(**self.file_client_args)
def transform(self, results: dict) -> Optional[dict]:
"""Functions to load image.
......@@ -168,7 +169,7 @@ class LoadAnnotations(BaseTransform):
See :fun:``mmcv.imfrombytes`` for details.
Defaults to 'cv2'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:``mmcv.fileio.FileClient`` for details.
See :class:``mmengine.fileio.FileClient`` for details.
Defaults to ``dict(backend='disk')``.
"""
......@@ -188,7 +189,7 @@ class LoadAnnotations(BaseTransform):
self.with_keypoints = with_keypoints
self.imdecode_backend = imdecode_backend
self.file_client_args = file_client_args.copy()
self.file_client = mmcv.FileClient(**self.file_client_args)
self.file_client = mmengine.FileClient(**self.file_client_args)
def _load_bboxes(self, results: dict) -> None:
"""Private function to load bounding box annotations.
......
......@@ -15,6 +15,7 @@ from collections import abc
from importlib import import_module
from pathlib import Path
import mmengine
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
......@@ -217,8 +218,7 @@ class Config:
# delete imported module
del sys.modules[temp_module_name]
elif filename.endswith(('.yml', '.yaml', '.json')):
import mmcv
cfg_dict = mmcv.load(temp_config_file.name)
cfg_dict = mmengine.load(temp_config_file.name)
# close temp file
temp_config_file.close()
......@@ -583,20 +583,19 @@ class Config:
file (str, optional): Path of the output file where the config
will be dumped. Defaults to None.
"""
import mmcv
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
if file is None:
if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text
else:
file_format = self.filename.split('.')[-1]
return mmcv.dump(cfg_dict, file_format=file_format)
return mmengine.dump(cfg_dict, file_format=file_format)
elif file.endswith('.py'):
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
file_format = file.split('.')[-1]
return mmcv.dump(cfg_dict, file=file, file_format=file_format)
return mmengine.dump(cfg_dict, file=file, file_format=file_format)
def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict.
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import sys
import tempfile
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import mmcv
from mmcv import BaseStorageBackend, FileClient
from mmcv.utils import has_method
sys.modules['ceph'] = MagicMock()
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
sys.modules['mc'] = MagicMock()
@contextmanager
def build_temporary_directory():
"""Build a temporary directory containing many files to test
``FileClient.list_dir_or_file``.
. \n
| -- dir1 \n
| -- | -- text3.txt \n
| -- dir2 \n
| -- | -- dir3 \n
| -- | -- | -- text4.txt \n
| -- | -- img.jpg \n
| -- text1.txt \n
| -- text2.txt \n
"""
with tempfile.TemporaryDirectory() as tmp_dir:
text1 = Path(tmp_dir) / 'text1.txt'
text1.open('w').write('text1')
text2 = Path(tmp_dir) / 'text2.txt'
text2.open('w').write('text2')
dir1 = Path(tmp_dir) / 'dir1'
dir1.mkdir()
text3 = dir1 / 'text3.txt'
text3.open('w').write('text3')
dir2 = Path(tmp_dir) / 'dir2'
dir2.mkdir()
jpg1 = dir2 / 'img.jpg'
jpg1.open('wb').write(b'img')
dir3 = dir2 / 'dir3'
dir3.mkdir()
text4 = dir3 / 'text4.txt'
text4.open('w').write('text4')
yield tmp_dir
@contextmanager
def delete_and_reset_method(obj, method):
method_obj = deepcopy(getattr(type(obj), method))
try:
delattr(type(obj), method)
yield
finally:
setattr(type(obj), method, method_obj)
class MockS3Client:
def __init__(self, enable_mc=True):
self.enable_mc = enable_mc
def Get(self, filepath):
with open(filepath, 'rb') as f:
content = f.read()
return content
class MockPetrelClient:
def __init__(self, enable_mc=True, enable_multi_cluster=False):
self.enable_mc = enable_mc
self.enable_multi_cluster = enable_multi_cluster
def Get(self, filepath):
with open(filepath, 'rb') as f:
content = f.read()
return content
def put(self):
pass
def delete(self):
pass
def contains(self):
pass
def isdir(self):
pass
def list(self, dir_path):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
yield entry.name
elif osp.isdir(entry.path):
yield entry.name + '/'
class MockMemcachedClient:
def __init__(self, server_list_cfg, client_cfg):
pass
def Get(self, filepath, buffer):
with open(filepath, 'rb') as f:
buffer.content = f.read()
class TestFileClient:
@classmethod
def setup_class(cls):
cls.test_data_dir = Path(__file__).parent / 'data'
cls.img_path = cls.test_data_dir / 'color.jpg'
cls.img_shape = (300, 400, 3)
cls.text_path = cls.test_data_dir / 'filelist.txt'
def test_error(self):
with pytest.raises(ValueError):
FileClient('hadoop')
def test_disk_backend(self):
disk_backend = FileClient('disk')
# test `name` attribute
assert disk_backend.name == 'HardDiskBackend'
# test `allow_symlink` attribute
assert disk_backend.allow_symlink
# test `get`
# input path is Path object
img_bytes = disk_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
assert self.img_path.open('rb').read() == img_bytes
assert img.shape == self.img_shape
# input path is str
img_bytes = disk_backend.get(str(self.img_path))
img = mmcv.imfrombytes(img_bytes)
assert self.img_path.open('rb').read() == img_bytes
assert img.shape == self.img_shape
# test `get_text`
# input path is Path object
value_buf = disk_backend.get_text(self.text_path)
assert self.text_path.open('r').read() == value_buf
# input path is str
value_buf = disk_backend.get_text(str(self.text_path))
assert self.text_path.open('r').read() == value_buf
with tempfile.TemporaryDirectory() as tmp_dir:
# test `put`
filepath1 = Path(tmp_dir) / 'test.jpg'
disk_backend.put(b'disk', filepath1)
assert filepath1.open('rb').read() == b'disk'
# test the `mkdir_or_exist` behavior in `put`
_filepath1 = Path(tmp_dir) / 'not_existed_dir1' / 'test.jpg'
disk_backend.put(b'disk', _filepath1)
assert _filepath1.open('rb').read() == b'disk'
# test `put_text`
filepath2 = Path(tmp_dir) / 'test.txt'
disk_backend.put_text('disk', filepath2)
assert filepath2.open('r').read() == 'disk'
# test the `mkdir_or_exist` behavior in `put_text`
_filepath2 = Path(tmp_dir) / 'not_existed_dir2' / 'test.txt'
disk_backend.put_text('disk', _filepath2)
assert _filepath2.open('r').read() == 'disk'
# test `isfile`
assert disk_backend.isfile(filepath2)
assert not disk_backend.isfile(Path(tmp_dir) / 'not/existed/path')
# test `remove`
disk_backend.remove(filepath2)
# test `exists`
assert not disk_backend.exists(filepath2)
# test `get_local_path`
# if the backend is disk, `get_local_path` just return the input
with disk_backend.get_local_path(filepath1) as path:
assert str(filepath1) == path
assert osp.isfile(filepath1)
# test `join_path`
disk_dir = '/path/of/your/directory'
assert disk_backend.join_path(disk_dir, 'file') == \
osp.join(disk_dir, 'file')
assert disk_backend.join_path(disk_dir, 'dir', 'file') == \
osp.join(disk_dir, 'dir', 'file')
# test `list_dir_or_file`
with build_temporary_directory() as tmp_dir:
# 1. list directories and files
assert set(disk_backend.list_dir_or_file(tmp_dir)) == {
'dir1', 'dir2', 'text1.txt', 'text2.txt'
}
# 2. list directories and files recursively
assert set(disk_backend.list_dir_or_file(
tmp_dir, recursive=True)) == {
'dir1',
osp.join('dir1', 'text3.txt'), 'dir2',
osp.join('dir2', 'dir3'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
}
# 3. only list directories
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
with pytest.raises(
TypeError,
match='`suffix` should be None when `list_dir` is True'):
# Exception is raised among the `list_dir_or_file` of client,
# so we need to invode the client to trigger the exception
disk_backend.client.list_dir_or_file(
tmp_dir, list_file=False, suffix='.txt')
# 4. only list directories recursively
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_file=False, recursive=True)) == {
'dir1', 'dir2',
osp.join('dir2', 'dir3')
}
# 5. only list files
assert set(disk_backend.list_dir_or_file(
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
# 6. only list files recursively
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False, recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
}
# 7. only list files ending with suffix
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix='.txt')) == {'text1.txt', 'text2.txt'}
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
with pytest.raises(
TypeError,
match='`suffix` must be a string or tuple of strings'):
disk_backend.client.list_dir_or_file(
tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
# 8. only list files ending with suffix recursively
assert set(
disk_backend.list_dir_or_file(
tmp_dir, list_dir=False, suffix='.txt',
recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
'text2.txt'
}
# 7. only list files ending with suffix
assert set(
disk_backend.list_dir_or_file(
tmp_dir,
list_dir=False,
suffix=('.txt', '.jpg'),
recursive=True)) == {
osp.join('dir1', 'text3.txt'),
osp.join('dir2', 'dir3', 'text4.txt'),
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
}
@patch('ceph.S3Client', MockS3Client)
def test_ceph_backend(self):
ceph_backend = FileClient('ceph')
# test `allow_symlink` attribute
assert not ceph_backend.allow_symlink
# input path is Path object
with pytest.raises(NotImplementedError):
ceph_backend.get_text(self.text_path)
# input path is str
with pytest.raises(NotImplementedError):
ceph_backend.get_text(str(self.text_path))
# input path is Path object
img_bytes = ceph_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# input path is str
img_bytes = ceph_backend.get(str(self.img_path))
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# `path_mapping` is either None or dict
with pytest.raises(AssertionError):
FileClient('ceph', path_mapping=1)
# test `path_mapping`
ceph_path = 's3://user/data'
ceph_backend = FileClient(
'ceph', path_mapping={str(self.test_data_dir): ceph_path})
ceph_backend.client._client.Get = MagicMock(
return_value=ceph_backend.client._client.Get(self.img_path))
img_bytes = ceph_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
ceph_backend.client._client.Get.assert_called_with(
str(self.img_path).replace(str(self.test_data_dir), ceph_path))
@patch('petrel_client.client.Client', MockPetrelClient)
@pytest.mark.parametrize('backend,prefix', [('petrel', None),
(None, 's3')])
def test_petrel_backend(self, backend, prefix):
petrel_backend = FileClient(backend=backend, prefix=prefix)
# test `allow_symlink` attribute
assert not petrel_backend.allow_symlink
# input path is Path object
img_bytes = petrel_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# input path is str
img_bytes = petrel_backend.get(str(self.img_path))
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# `path_mapping` is either None or dict
with pytest.raises(AssertionError):
FileClient('petrel', path_mapping=1)
# test `_map_path`
petrel_dir = 's3://user/data'
petrel_backend = FileClient(
'petrel', path_mapping={str(self.test_data_dir): petrel_dir})
assert petrel_backend.client._map_path(str(self.img_path)) == \
str(self.img_path).replace(str(self.test_data_dir), petrel_dir)
petrel_path = f'{petrel_dir}/test.jpg'
petrel_backend = FileClient('petrel')
# test `_format_path`
assert petrel_backend.client._format_path('s3://user\\data\\test.jpg')\
== petrel_path
# test `get`
with patch.object(
petrel_backend.client._client, 'Get',
return_value=b'petrel') as mock_get:
assert petrel_backend.get(petrel_path) == b'petrel'
mock_get.assert_called_once_with(petrel_path)
# test `get_text`
with patch.object(
petrel_backend.client._client, 'Get',
return_value=b'petrel') as mock_get:
assert petrel_backend.get_text(petrel_path) == 'petrel'
mock_get.assert_called_once_with(petrel_path)
# test `put`
with patch.object(petrel_backend.client._client, 'put') as mock_put:
petrel_backend.put(b'petrel', petrel_path)
mock_put.assert_called_once_with(petrel_path, b'petrel')
# test `put_text`
with patch.object(petrel_backend.client._client, 'put') as mock_put:
petrel_backend.put_text('petrel', petrel_path)
mock_put.assert_called_once_with(petrel_path, b'petrel')
# test `remove`
assert has_method(petrel_backend.client._client, 'delete')
# raise Exception if `delete` is not implemented
with delete_and_reset_method(petrel_backend.client._client, 'delete'):
assert not has_method(petrel_backend.client._client, 'delete')
with pytest.raises(NotImplementedError):
petrel_backend.remove(petrel_path)
with patch.object(petrel_backend.client._client,
'delete') as mock_delete:
petrel_backend.remove(petrel_path)
mock_delete.assert_called_once_with(petrel_path)
# test `exists`
assert has_method(petrel_backend.client._client, 'contains')
assert has_method(petrel_backend.client._client, 'isdir')
# raise Exception if `delete` is not implemented
with delete_and_reset_method(petrel_backend.client._client,
'contains'), delete_and_reset_method(
petrel_backend.client._client,
'isdir'):
assert not has_method(petrel_backend.client._client, 'contains')
assert not has_method(petrel_backend.client._client, 'isdir')
with pytest.raises(NotImplementedError):
petrel_backend.exists(petrel_path)
with patch.object(
petrel_backend.client._client, 'contains',
return_value=True) as mock_contains:
assert petrel_backend.exists(petrel_path)
mock_contains.assert_called_once_with(petrel_path)
# test `isdir`
assert has_method(petrel_backend.client._client, 'isdir')
with delete_and_reset_method(petrel_backend.client._client, 'isdir'):
assert not has_method(petrel_backend.client._client, 'isdir')
with pytest.raises(NotImplementedError):
petrel_backend.isdir(petrel_path)
with patch.object(
petrel_backend.client._client, 'isdir',
return_value=True) as mock_isdir:
assert petrel_backend.isdir(petrel_dir)
mock_isdir.assert_called_once_with(petrel_dir)
# test `isfile`
assert has_method(petrel_backend.client._client, 'contains')
with delete_and_reset_method(petrel_backend.client._client,
'contains'):
assert not has_method(petrel_backend.client._client, 'contains')
with pytest.raises(NotImplementedError):
petrel_backend.isfile(petrel_path)
with patch.object(
petrel_backend.client._client, 'contains',
return_value=True) as mock_contains:
assert petrel_backend.isfile(petrel_path)
mock_contains.assert_called_once_with(petrel_path)
# test `join_path`
assert petrel_backend.join_path(petrel_dir, 'file') == \
f'{petrel_dir}/file'
assert petrel_backend.join_path(f'{petrel_dir}/', 'file') == \
f'{petrel_dir}/file'
assert petrel_backend.join_path(petrel_dir, 'dir', 'file') == \
f'{petrel_dir}/dir/file'
# test `get_local_path`
with patch.object(petrel_backend.client._client, 'Get',
return_value=b'petrel') as mock_get, \
patch.object(petrel_backend.client._client, 'contains',
return_value=True) as mock_contains:
with petrel_backend.get_local_path(petrel_path) as path:
assert Path(path).open('rb').read() == b'petrel'
# exist the with block and path will be released
assert not osp.isfile(path)
mock_get.assert_called_once_with(petrel_path)
mock_contains.assert_called_once_with(petrel_path)
# test `list_dir_or_file`
assert has_method(petrel_backend.client._client, 'list')
with delete_and_reset_method(petrel_backend.client._client, 'list'):
assert not has_method(petrel_backend.client._client, 'list')
with pytest.raises(NotImplementedError):
list(petrel_backend.list_dir_or_file(petrel_dir))
with build_temporary_directory() as tmp_dir:
# 1. list directories and files
assert set(petrel_backend.list_dir_or_file(tmp_dir)) == {
'dir1', 'dir2', 'text1.txt', 'text2.txt'
}
# 2. list directories and files recursively
assert set(
petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
('dir2', 'dir3')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
}
# 3. only list directories
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_file=False)) == {'dir1', 'dir2'}
with pytest.raises(
TypeError,
match=('`list_dir` should be False when `suffix` is not '
'None')):
# Exception is raised among the `list_dir_or_file` of client,
# so we need to invode the client to trigger the exception
petrel_backend.client.list_dir_or_file(
tmp_dir, list_file=False, suffix='.txt')
# 4. only list directories recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_file=False, recursive=True)) == {
'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
}
# 5. only list files
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
# 6. only list files recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False, recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
}
# 7. only list files ending with suffix
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix='.txt')) == {'text1.txt', 'text2.txt'}
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False,
suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
with pytest.raises(
TypeError,
match='`suffix` must be a string or tuple of strings'):
petrel_backend.client.list_dir_or_file(
tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
# 8. only list files ending with suffix recursively
assert set(
petrel_backend.list_dir_or_file(
tmp_dir, list_dir=False, suffix='.txt',
recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), 'text1.txt',
'text2.txt'
}
# 7. only list files ending with suffix
assert set(
petrel_backend.list_dir_or_file(
tmp_dir,
list_dir=False,
suffix=('.txt', '.jpg'),
recursive=True)) == {
'/'.join(('dir1', 'text3.txt')), '/'.join(
('dir2', 'dir3', 'text4.txt')), '/'.join(
('dir2', 'img.jpg')), 'text1.txt', 'text2.txt'
}
@patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
@patch('mc.pyvector', MagicMock)
@patch('mc.ConvertBuffer', lambda x: x.content)
def test_memcached_backend(self):
mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
mc_backend = FileClient('memcached', **mc_cfg)
# test `allow_symlink` attribute
assert not mc_backend.allow_symlink
# input path is Path object
with pytest.raises(NotImplementedError):
mc_backend.get_text(self.text_path)
# input path is str
with pytest.raises(NotImplementedError):
mc_backend.get_text(str(self.text_path))
# input path is Path object
img_bytes = mc_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# input path is str
img_bytes = mc_backend.get(str(self.img_path))
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
def test_lmdb_backend(self):
lmdb_path = self.test_data_dir / 'demo.lmdb'
# db_path is Path object
lmdb_backend = FileClient('lmdb', db_path=lmdb_path)
# test `allow_symlink` attribute
assert not lmdb_backend.allow_symlink
with pytest.raises(NotImplementedError):
lmdb_backend.get_text(self.text_path)
img_bytes = lmdb_backend.get('baboon')
img = mmcv.imfrombytes(img_bytes)
assert img.shape == (120, 125, 3)
# db_path is str
lmdb_backend = FileClient('lmdb', db_path=str(lmdb_path))
with pytest.raises(NotImplementedError):
lmdb_backend.get_text(str(self.text_path))
img_bytes = lmdb_backend.get('baboon')
img = mmcv.imfrombytes(img_bytes)
assert img.shape == (120, 125, 3)
@pytest.mark.parametrize('backend,prefix', [('http', None),
(None, 'http')])
def test_http_backend(self, backend, prefix):
http_backend = FileClient(backend=backend, prefix=prefix)
img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
'master/tests/data/color.jpg'
text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
'master/tests/data/filelist.txt'
# test `allow_symlink` attribute
assert not http_backend.allow_symlink
# input is path or Path object
with pytest.raises(Exception):
http_backend.get(self.img_path)
with pytest.raises(Exception):
http_backend.get(str(self.img_path))
with pytest.raises(Exception):
http_backend.get_text(self.text_path)
with pytest.raises(Exception):
http_backend.get_text(str(self.text_path))
# input url is http image
img_bytes = http_backend.get(img_url)
img = mmcv.imfrombytes(img_bytes)
assert img.shape == self.img_shape
# input url is http text
value_buf = http_backend.get_text(text_url)
assert self.text_path.open('r').read() == value_buf
# test `_get_local_path`
# exist the with block and path will be released
with http_backend.get_local_path(img_url) as path:
assert mmcv.imread(path).shape == self.img_shape
assert not osp.isfile(path)
def test_new_magic_method(self):
class DummyBackend1(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath, encoding='utf-8'):
return filepath
FileClient.register_backend('dummy_backend', DummyBackend1)
client1 = FileClient(backend='dummy_backend')
client2 = FileClient(backend='dummy_backend')
assert client1 is client2
# if a backend is overwrote, it will disable the singleton pattern for
# the backend
class DummyBackend2(BaseStorageBackend):
def get(self, filepath):
pass
def get_text(self, filepath):
pass
FileClient.register_backend('dummy_backend', DummyBackend2, force=True)
client3 = FileClient(backend='dummy_backend')
client4 = FileClient(backend='dummy_backend')
assert client2 is not client3
assert client3 is client4
def test_parse_uri_prefix(self):
# input path is None
with pytest.raises(AssertionError):
FileClient.parse_uri_prefix(None)
# input path is list
with pytest.raises(AssertionError):
FileClient.parse_uri_prefix([])
# input path is Path object
assert FileClient.parse_uri_prefix(self.img_path) is None
# input path is str
assert FileClient.parse_uri_prefix(str(self.img_path)) is None
# input path starts with https
img_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
'master/tests/data/color.jpg'
assert FileClient.parse_uri_prefix(img_url) == 'https'
# input path starts with s3
img_url = 's3://your_bucket/img.png'
assert FileClient.parse_uri_prefix(img_url) == 's3'
# input path starts with clusterName:s3
img_url = 'clusterName:s3://your_bucket/img.png'
assert FileClient.parse_uri_prefix(img_url) == 's3'
def test_infer_client(self):
# HardDiskBackend
file_client_args = {'backend': 'disk'}
client = FileClient.infer_client(file_client_args)
assert client.name == 'HardDiskBackend'
client = FileClient.infer_client(uri=self.img_path)
assert client.name == 'HardDiskBackend'
# PetrelBackend
file_client_args = {'backend': 'petrel'}
client = FileClient.infer_client(file_client_args)
assert client.name == 'PetrelBackend'
uri = 's3://user_data'
client = FileClient.infer_client(uri=uri)
assert client.name == 'PetrelBackend'
def test_register_backend(self):
# name must be a string
with pytest.raises(TypeError):
class TestClass1:
pass
FileClient.register_backend(1, TestClass1)
# module must be a class
with pytest.raises(TypeError):
FileClient.register_backend('int', 0)
# module must be a subclass of BaseStorageBackend
with pytest.raises(TypeError):
class TestClass1:
pass
FileClient.register_backend('TestClass1', TestClass1)
class ExampleBackend(BaseStorageBackend):
def get(self, filepath):
return filepath
def get_text(self, filepath, encoding='utf-8'):
return filepath
FileClient.register_backend('example', ExampleBackend)
example_backend = FileClient('example')
assert example_backend.get(self.img_path) == self.img_path
assert example_backend.get_text(self.text_path) == self.text_path
assert 'example' in FileClient._backends
class Example2Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes2'
def get_text(self, filepath, encoding='utf-8'):
return 'text2'
# force=False
with pytest.raises(KeyError):
FileClient.register_backend('example', Example2Backend)
FileClient.register_backend('example', Example2Backend, force=True)
example_backend = FileClient('example')
assert example_backend.get(self.img_path) == b'bytes2'
assert example_backend.get_text(self.text_path) == 'text2'
@FileClient.register_backend(name='example3')
class Example3Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes3'
def get_text(self, filepath, encoding='utf-8'):
return 'text3'
example_backend = FileClient('example3')
assert example_backend.get(self.img_path) == b'bytes3'
assert example_backend.get_text(self.text_path) == 'text3'
assert 'example3' in FileClient._backends
# force=False
with pytest.raises(KeyError):
@FileClient.register_backend(name='example3')
class Example4Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes4'
def get_text(self, filepath, encoding='utf-8'):
return 'text4'
@FileClient.register_backend(name='example3', force=True)
class Example5Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes5'
def get_text(self, filepath, encoding='utf-8'):
return 'text5'
example_backend = FileClient('example3')
assert example_backend.get(self.img_path) == b'bytes5'
assert example_backend.get_text(self.text_path) == 'text5'
# prefixes is a str
class Example6Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes6'
def get_text(self, filepath, encoding='utf-8'):
return 'text6'
FileClient.register_backend(
'example4',
Example6Backend,
force=True,
prefixes='example4_prefix')
example_backend = FileClient('example4')
assert example_backend.get(self.img_path) == b'bytes6'
assert example_backend.get_text(self.text_path) == 'text6'
example_backend = FileClient(prefix='example4_prefix')
assert example_backend.get(self.img_path) == b'bytes6'
assert example_backend.get_text(self.text_path) == 'text6'
example_backend = FileClient('example4', prefix='example4_prefix')
assert example_backend.get(self.img_path) == b'bytes6'
assert example_backend.get_text(self.text_path) == 'text6'
# prefixes is a list of str
class Example7Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes7'
def get_text(self, filepath, encoding='utf-8'):
return 'text7'
FileClient.register_backend(
'example5',
Example7Backend,
force=True,
prefixes=['example5_prefix1', 'example5_prefix2'])
example_backend = FileClient('example5')
assert example_backend.get(self.img_path) == b'bytes7'
assert example_backend.get_text(self.text_path) == 'text7'
example_backend = FileClient(prefix='example5_prefix1')
assert example_backend.get(self.img_path) == b'bytes7'
assert example_backend.get_text(self.text_path) == 'text7'
example_backend = FileClient(prefix='example5_prefix2')
assert example_backend.get(self.img_path) == b'bytes7'
assert example_backend.get_text(self.text_path) == 'text7'
# backend has a higher priority than prefixes
class Example8Backend(BaseStorageBackend):
def get(self, filepath):
return b'bytes8'
def get_text(self, filepath, encoding='utf-8'):
return 'text8'
FileClient.register_backend(
'example6',
Example8Backend,
force=True,
prefixes='example6_prefix')
example_backend = FileClient('example6')
assert example_backend.get(self.img_path) == b'bytes8'
assert example_backend.get_text(self.text_path) == 'text8'
example_backend = FileClient('example6', prefix='example4_prefix')
assert example_backend.get(self.img_path) == b'bytes8'
assert example_backend.get_text(self.text_path) == 'text8'
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import sys
import tempfile
from unittest.mock import MagicMock, patch
import pytest
import mmcv
from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
def _test_handler(file_format, test_obj, str_checker, mode='r+'):
# dump to a string
dump_str = mmcv.dump(test_obj, file_format=file_format)
str_checker(dump_str)
# load/dump with filenames from disk
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test_dump')
mmcv.dump(test_obj, tmp_filename, file_format=file_format)
assert osp.isfile(tmp_filename)
load_obj = mmcv.load(tmp_filename, file_format=file_format)
assert load_obj == test_obj
os.remove(tmp_filename)
# load/dump with filename from petrel
method = 'put' if 'b' in mode else 'put_text'
with patch.object(PetrelBackend, method, return_value=None) as mock_method:
filename = 's3://path/of/your/file'
mmcv.dump(test_obj, filename, file_format=file_format)
mock_method.assert_called()
# json load/dump with a file-like object
with tempfile.NamedTemporaryFile(mode, delete=False) as f:
tmp_filename = f.name
mmcv.dump(test_obj, f, file_format=file_format)
assert osp.isfile(tmp_filename)
with open(tmp_filename, mode) as f:
load_obj = mmcv.load(f, file_format=file_format)
assert load_obj == test_obj
os.remove(tmp_filename)
# automatically inference the file format from the given filename
tmp_filename = osp.join(tempfile.gettempdir(),
'mmcv_test_dump.' + file_format)
mmcv.dump(test_obj, tmp_filename)
assert osp.isfile(tmp_filename)
load_obj = mmcv.load(tmp_filename)
assert load_obj == test_obj
os.remove(tmp_filename)
obj_for_test = [{'a': 'abc', 'b': 1}, 2, 'c']
def test_json():
def json_checker(dump_str):
assert dump_str in [
'[{"a": "abc", "b": 1}, 2, "c"]', '[{"b": 1, "a": "abc"}, 2, "c"]'
]
_test_handler('json', obj_for_test, json_checker)
def test_yaml():
def yaml_checker(dump_str):
assert dump_str in [
'- {a: abc, b: 1}\n- 2\n- c\n', '- {b: 1, a: abc}\n- 2\n- c\n',
'- a: abc\n b: 1\n- 2\n- c\n', '- b: 1\n a: abc\n- 2\n- c\n'
]
_test_handler('yaml', obj_for_test, yaml_checker)
def test_pickle():
def pickle_checker(dump_str):
import pickle
assert pickle.loads(dump_str) == obj_for_test
_test_handler('pickle', obj_for_test, pickle_checker, mode='rb+')
def test_exception():
test_obj = [{'a': 'abc', 'b': 1}, 2, 'c']
with pytest.raises(ValueError):
mmcv.dump(test_obj)
with pytest.raises(TypeError):
mmcv.dump(test_obj, 'tmp.txt')
def test_register_handler():
@mmcv.register_handler('txt')
class TxtHandler1(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
@mmcv.register_handler(['txt1', 'txt2'])
class TxtHandler2(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write('\n')
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
mmcv.dump(content, tmp_filename)
with open(tmp_filename) as f:
written = f.read()
os.remove(tmp_filename)
assert written == '\n' + content
def test_list_from_file():
# get list from disk
filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg', '4.jpg', '5.jpg']
filelist = mmcv.list_from_file(filename, prefix='a/')
assert filelist == ['a/1.jpg', 'a/2.jpg', 'a/3.jpg', 'a/4.jpg', 'a/5.jpg']
filelist = mmcv.list_from_file(filename, offset=2)
assert filelist == ['3.jpg', '4.jpg', '5.jpg']
filelist = mmcv.list_from_file(filename, max_num=2)
assert filelist == ['1.jpg', '2.jpg']
filelist = mmcv.list_from_file(filename, offset=3, max_num=3)
assert filelist == ['4.jpg', '5.jpg']
# get list from http
with patch.object(
HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
filename = 'http://path/of/your/file'
filelist = mmcv.list_from_file(
filename, file_client_args={'backend': 'http'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(
filename, file_client_args={'prefix': 'http'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
# get list from petrel
with patch.object(
PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'):
filename = 's3://path/of/your/file'
filelist = mmcv.list_from_file(
filename, file_client_args={'backend': 'petrel'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(
filename, file_client_args={'prefix': 's3'})
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
filelist = mmcv.list_from_file(filename)
assert filelist == ['1.jpg', '2.jpg', '3.jpg']
def test_dict_from_file():
# get dict from disk
filename = osp.join(osp.dirname(__file__), 'data/mapping.txt')
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename, key_type=int)
assert mapping == {1: 'cat', 2: ['dog', 'cow'], 3: 'panda'}
# get dict from http
with patch.object(
HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'):
filename = 'http://path/of/your/file'
mapping = mmcv.dict_from_file(
filename, file_client_args={'backend': 'http'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(
filename, file_client_args={'prefix': 'http'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
# get dict from petrel
with patch.object(
PetrelBackend, 'get_text',
return_value='1 cat\n2 dog cow\n3 panda'):
filename = 's3://path/of/your/file'
mapping = mmcv.dict_from_file(
filename, file_client_args={'backend': 'petrel'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(
filename, file_client_args={'prefix': 's3'})
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
mapping = mmcv.dict_from_file(filename)
assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'}
......@@ -7,13 +7,14 @@ from pathlib import Path
from unittest.mock import MagicMock, patch
import cv2
import mmengine
import numpy as np
import pytest
import torch
from mmengine.fileio.file_client import HTTPBackend, PetrelBackend
from numpy.testing import assert_allclose, assert_array_equal
import mmcv
from mmcv.fileio.file_client import HTTPBackend, PetrelBackend
if torch.__version__ == 'parrots':
pytest.skip('not necessary in parrots test', allow_module_level=True)
......@@ -46,7 +47,7 @@ class TestIO:
@classmethod
def teardown_class(cls):
# clean instances avoid to influence other unittest
mmcv.FileClient._instances = {}
mmengine.FileClient._instances = {}
def assert_img_equal(self, img, ref_img, ratio_thr=0.999):
assert img.shape == ref_img.shape
......
......@@ -3,6 +3,7 @@ import os
import os.path as osp
from unittest.mock import patch
import mmengine
import pytest
import torchvision
......@@ -30,7 +31,7 @@ def test_default_mmcv_home():
assert _get_mmcv_home() == os.path.expanduser(
os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
model_urls = get_external_models()
assert model_urls == mmcv.load(
assert model_urls == mmengine.load(
osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
......
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import numpy as np
import pytest
import torch
......@@ -144,9 +145,8 @@ class Testnms:
nms_match(wrong_dets, iou_thr)
def test_batched_nms(self):
import mmcv
from mmcv.ops import batched_nms
results = mmcv.load('./tests/data/batched_nms_data.pkl')
results = mmengine.load('./tests/data/batched_nms_data.pkl')
nms_max_num = 100
nms_cfg = dict(
......
......@@ -3,6 +3,7 @@ import os
from functools import partial
from typing import Callable
import mmengine
import numpy as np
import onnx
import pytest
......@@ -117,7 +118,6 @@ def test_roialign():
def test_nms():
try:
import mmcv
from mmcv.ops import nms
except (ImportError, ModuleNotFoundError):
pytest.skip('test requires compilation')
......@@ -125,7 +125,7 @@ def test_nms():
# trt config
fp16_mode = False
max_workspace_size = 1 << 30
data = mmcv.load('./tests/data/batched_nms_data.pkl')
data = mmengine.load('./tests/data/batched_nms_data.pkl')
boxes = torch.from_numpy(data['boxes']).cuda()
scores = torch.from_numpy(data['scores']).cuda()
nms = partial(
......@@ -188,7 +188,6 @@ def test_nms():
def test_batched_nms():
try:
import mmcv
from mmcv.ops import batched_nms
except (ImportError, ModuleNotFoundError):
pytest.skip('test requires compilation')
......@@ -197,7 +196,7 @@ def test_batched_nms():
os.environ['ONNX_BACKEND'] = 'MMCVTensorRT'
fp16_mode = False
max_workspace_size = 1 << 30
data = mmcv.load('./tests/data/batched_nms_data.pkl')
data = mmengine.load('./tests/data/batched_nms_data.pkl')
nms_cfg = dict(type='nms', iou_threshold=0.7, score_threshold=0.1)
boxes = torch.from_numpy(data['boxes']).cuda()
scores = torch.from_numpy(data['scores']).cuda()
......
# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import mmengine
import pytest
import torch
from torch import nn
import mmcv
from mmcv.cnn.utils.weight_init import update_init_info
from mmcv.runner import BaseModule, ModuleDict, ModuleList, Sequential
from mmcv.utils import Registry, build_from_cfg
......@@ -135,7 +135,7 @@ def test_initilization_info_logger():
# assert initialization information has been dumped
assert os.path.exists(log_file)
lines = mmcv.list_from_file(log_file)
lines = mmengine.list_from_file(log_file)
# check initialization information is right
for i, line in enumerate(lines):
......@@ -210,7 +210,7 @@ def test_initilization_info_logger():
# assert initialization information has been dumped
assert os.path.exists(log_file)
lines = mmcv.list_from_file(log_file)
lines = mmengine.list_from_file(log_file)
# check initialization information is right
for i, line in enumerate(lines):
if 'TopLevelModule' in line and 'init_cfg' not in line:
......
......@@ -8,9 +8,9 @@ import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from mmengine.fileio.file_client import PetrelBackend
from torch.nn.parallel import DataParallel
from mmcv.fileio.file_client import PetrelBackend
from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_checkpoint,
......
......@@ -11,9 +11,9 @@ import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from mmengine.fileio.file_client import PetrelBackend
from torch.utils.data import DataLoader, Dataset
from mmcv.fileio.file_client import PetrelBackend
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EpochBasedRunner
from mmcv.runner import EvalHook as BaseEvalHook
......
......@@ -18,10 +18,10 @@ from unittest.mock import MagicMock, Mock, call, patch
import pytest
import torch
import torch.nn as nn
from mmengine.fileio.file_client import PetrelBackend
from torch.nn.init import constant_
from torch.utils.data import DataLoader
from mmcv.fileio.file_client import PetrelBackend
# yapf: disable
from mmcv.runner import (CheckpointHook, ClearMLLoggerHook, DvcliveLoggerHook,
EMAHook, Fp16OptimizerHook,
......
......@@ -10,8 +10,9 @@ from pathlib import Path
import pytest
import yaml
from mmengine import dump, load
from mmcv import Config, ConfigDict, DictAction, dump, load
from mmcv import Config, ConfigDict, DictAction
data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment