Unverified Commit e333d822 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add FileClient to access file from different backends (#237)

* add file client

* install missing requirements

* use .format() instead of f-string
parent a0618d10
......@@ -25,7 +25,7 @@ python:
- "3.8"
before_script:
- pip install codecov flake8 yapf isort
- pip install codecov flake8 yapf isort lmdb
- flake8 .
- isort -rc --check-only --diff mmcv/ tests/ examples/
- yapf -r -d mmcv/ tests/ examples/
......
# Copyright (c) Open-MMLab. All rights reserved.
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file
# Public names re-exported by this package. Each entry matches one of the
# names imported at the top of this module. Entries are listed exactly once
# and separated by commas so that adjacent string literals are not silently
# concatenated into a single bogus name.
__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]
import inspect
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
    """Interface that every storage backend has to fulfil.

    A concrete backend must override both `get()` and `get_text()`;
    the class cannot be instantiated until it does so.
    """

    @abstractmethod
    def get(self, filepath):
        """Read `filepath` and return its content as a byte stream."""

    @abstractmethod
    def get_text(self, filepath):
        """Read `filepath` and return its content as text."""
class CephBackend(BaseStorageBackend):
    """Storage backend that reads objects through a ceph S3 client.

    The `ceph` package is imported when the backend is constructed, so it
    is only required if this backend is actually used.
    """

    def __init__(self):
        try:
            import ceph
        except ImportError:
            raise ImportError('Please install ceph to enable CephBackend.')

        self._client = ceph.S3Client()

    def get(self, filepath):
        """Fetch `filepath` from the client and return it as a memoryview.

        Args:
            filepath (str | obj:`Path`): Path handed to the client after
                being converted to `str`.
        """
        raw = self._client.Get(str(filepath))
        return memoryview(raw)

    def get_text(self, filepath):
        """Text reading is not implemented for this backend."""
        raise NotImplementedError
class MemcachedBackend(BaseStorageBackend):
    """Storage backend that reads values through a memcached client.

    Args:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`
            before `mc` is imported. Default: None.

    Attributes:
        server_list_cfg (str): The `server_list_cfg` argument as given.
        client_cfg (str): The `client_cfg` argument as given.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        # Extend the module search path first so `import mc` can find a
        # client living in a caller-supplied location.
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)

        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(
            self.server_list_cfg, self.client_cfg)
        # Buffer object the client fills on every `Get()` call; it is
        # created once and reused.
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch `filepath` from the client and return the converted buffer.

        Args:
            filepath (str | obj:`Path`): Key handed to the client after
                being converted to `str`.
        """
        key = str(filepath)
        import mc
        self._client.Get(key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        """Text reading is not implemented for this backend."""
        raise NotImplementedError
class LmdbBackend(BaseStorageBackend):
    """Storage backend that looks values up in an lmdb database.

    Args:
        db_path (str | obj:`Path`): Lmdb database path.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database.
            Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.
        **kwargs: Any further keyword arguments are forwarded to
            `lmdb.open()` unchanged.

    Attributes:
        db_path (str): Lmdb database path, stored as `str`.
    """

    def __init__(self,
                 db_path,
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        self.db_path = str(db_path)
        env_options = dict(readonly=readonly, lock=lock, readahead=readahead)
        self._client = lmdb.open(self.db_path, **env_options, **kwargs)

    def get(self, filepath):
        """Return the value stored under the key `filepath`.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key. It
                is converted to `str` and encoded as ASCII for the lookup.
        """
        key_text = str(filepath)
        with self._client.begin(write=False) as txn:
            return txn.get(key_text.encode('ascii'))

    def get_text(self, filepath):
        """Text reading is not implemented for this backend."""
        raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
    """Storage backend that reads files from the local file system."""

    def get(self, filepath):
        """Return the whole content of `filepath` as bytes.

        Args:
            filepath (str | obj:`Path`): File to read in binary mode.
        """
        with open(str(filepath), 'rb') as stream:
            return stream.read()

    def get_text(self, filepath):
        """Return the whole content of `filepath` as text.

        Args:
            filepath (str | obj:`Path`): File to read in text mode.
        """
        with open(str(filepath), 'r') as stream:
            return stream.read()
class FileClient(object):
    """Uniform front end for reading files from different storage backends.

    The client instantiates the backend selected by name and forwards
    `get()` and `get_text()` calls to it. Additional backends can be made
    available under a new name with `register_backend()`.

    Args:
        backend (str): Name of the storage backend. Default: 'disk'.
        **kwargs: Keyword arguments passed to the backend constructor.

    Attributes:
        backend (str): The storage backend type. Options are "disk", "ceph",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Registry shared by all clients: backend name -> backend class.
    _backends = {
        'disk': HardDiskBackend,
        'ceph': CephBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        registry = self._backends
        if backend in registry:
            self.backend = backend
            self.client = registry[backend](**kwargs)
            return

        message = (
            'Backend {} is not supported. Currently supported ones are {}'.
            format(backend, list(registry)))
        raise ValueError(message)

    @classmethod
    def register_backend(cls, name, backend):
        """Make `backend` selectable under `name`.

        Args:
            name (str): Name used to select the backend later on.
            backend (type): Subclass of `BaseStorageBackend`.

        Raises:
            TypeError: If `backend` is not a class, or is a class that does
                not derive from `BaseStorageBackend`.
        """
        is_class = inspect.isclass(backend)
        if not is_class:
            raise TypeError('backend should be a class but got {}'.format(
                type(backend)))

        if issubclass(backend, BaseStorageBackend):
            cls._backends[name] = backend
            return

        raise TypeError(
            'backend {} is not a subclass of BaseStorageBackend'.format(
                backend))

    def get(self, filepath):
        """Return what the backend's `get()` returns for `filepath`."""
        return self.client.get(filepath)

    def get_text(self, filepath):
        """Return what the backend's `get_text()` returns for `filepath`."""
        return self.client.get_text(filepath)
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import mmcv
from mmcv import BaseStorageBackend, FileClient
# Register stand-in modules under the names `ceph` and `mc` so that
# `import ceph` / `import mc` resolve to these mocks without requiring the
# real packages; individual tests patch the attributes they need.
sys.modules['ceph'] = MagicMock()
sys.modules['mc'] = MagicMock()
class MockS3Client(object):
    """Stand-in for the ceph S3 client that reads from the local disk."""

    def Get(self, filepath):
        """Return the raw bytes of the local file `filepath`."""
        with open(filepath, 'rb') as stream:
            return stream.read()
class MockMemcachedClient(object):
    """Stand-in for the memcached client that reads from the local disk."""

    def __init__(self, server_list_cfg, client_cfg):
        """Accept the two config arguments and ignore them."""

    def Get(self, filepath, buffer):
        """Store the raw bytes of `filepath` in `buffer.content`."""
        with open(filepath, 'rb') as stream:
            payload = stream.read()
        buffer.content = payload
class TestFileClient(object):
    """Tests for `FileClient` with each built-in backend and a custom one."""

    @classmethod
    def setup_class(cls):
        """Resolve the fixture paths shared by all tests of this class."""
        cls.test_data_dir = Path(__file__).parent / 'data'
        cls.img_path = cls.test_data_dir / 'color.jpg'
        cls.img_shape = (300, 400, 3)
        cls.text_path = cls.test_data_dir / 'filelist.txt'

    def test_disk_backend(self):
        """`get()`/`get_text()` return the file content for Path and str."""
        disk_backend = FileClient('disk')

        # The reference content is read with read_bytes()/read_text(), which
        # close the file again, instead of a bare open().read() that leaves
        # the handle to the garbage collector.

        # input path is Path object
        img_bytes = disk_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert self.img_path.read_bytes() == img_bytes
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = disk_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert self.img_path.read_bytes() == img_bytes
        assert img.shape == self.img_shape

        # input path is Path object
        value_buf = disk_backend.get_text(self.text_path)
        assert self.text_path.read_text() == value_buf
        # input path is str
        value_buf = disk_backend.get_text(str(self.text_path))
        assert self.text_path.read_text() == value_buf

    @patch('ceph.S3Client', MockS3Client)
    def test_ceph_backend(self):
        """`get()` works through the mocked client; `get_text()` raises."""
        ceph_backend = FileClient('ceph')

        # input path is Path object
        with pytest.raises(NotImplementedError):
            ceph_backend.get_text(self.text_path)
        # input path is str
        with pytest.raises(NotImplementedError):
            ceph_backend.get_text(str(self.text_path))

        # input path is Path object
        img_bytes = ceph_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = ceph_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape

    @patch('mc.MemcachedClient.GetInstance', MockMemcachedClient)
    @patch('mc.pyvector', MagicMock)
    @patch('mc.ConvertBuffer', lambda x: x.content)
    def test_memcached_backend(self):
        """`get()` works through the mocked client; `get_text()` raises."""
        mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
        mc_backend = FileClient('memcached', **mc_cfg)

        # input path is Path object
        with pytest.raises(NotImplementedError):
            mc_backend.get_text(self.text_path)
        # input path is str
        with pytest.raises(NotImplementedError):
            mc_backend.get_text(str(self.text_path))

        # input path is Path object
        img_bytes = mc_backend.get(self.img_path)
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape
        # input path is str
        img_bytes = mc_backend.get(str(self.img_path))
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == self.img_shape

    def test_lmdb_backend(self):
        """The 'baboon' key is readable whether db_path is Path or str."""
        lmdb_path = self.test_data_dir / 'demo.lmdb'

        # db_path is Path object
        lmdb_backend = FileClient('lmdb', db_path=lmdb_path)

        with pytest.raises(NotImplementedError):
            lmdb_backend.get_text(self.text_path)

        img_bytes = lmdb_backend.get('baboon')
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == (120, 125, 3)

        # db_path is str
        lmdb_backend = FileClient('lmdb', db_path=str(lmdb_path))

        with pytest.raises(NotImplementedError):
            lmdb_backend.get_text(str(self.text_path))

        img_bytes = lmdb_backend.get('baboon')
        img = mmcv.imfrombytes(img_bytes)
        assert img.shape == (120, 125, 3)

    def test_register_backend(self):
        """Invalid backends are rejected; a valid one becomes usable."""

        class TestClass1(object):
            pass

        # Only the call expected to fail lives inside `pytest.raises`, so an
        # unrelated error cannot satisfy the check by accident.
        # a class that is not a BaseStorageBackend subclass
        with pytest.raises(TypeError):
            FileClient.register_backend('TestClass1', TestClass1)

        # an object that is not a class at all
        with pytest.raises(TypeError):
            FileClient.register_backend('int', 0)

        class ExampleBackend(BaseStorageBackend):

            def get(self, filepath):
                return filepath

            def get_text(self, filepath):
                return filepath

        FileClient.register_backend('example', ExampleBackend)
        try:
            example_backend = FileClient('example')
            assert example_backend.get(self.img_path) == self.img_path
            assert example_backend.get_text(self.text_path) == self.text_path
            assert 'example' in FileClient._backends
        finally:
            # The registry is shared class state; remove the temporary
            # entry so it does not leak into other tests.
            FileClient._backends.pop('example', None)

    def test_error(self):
        """An unknown backend name raises ValueError."""
        with pytest.raises(ValueError):
            FileClient('hadoop')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment