Commit 34eca272 authored by Kai Chen's avatar Kai Chen
Browse files

add register_handler and refactor BaseFileHandler

parent 3e1befe1
from .io import load, dump
from .io import load, dump, register_handler
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .parse import list_from_file, dict_from_file
__all__ = [
'load', 'dump', 'BaseFileHandler', 'JsonHandler', 'PickleHandler',
'YamlHandler', 'list_from_file', 'dict_from_file'
'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler',
'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file'
]
......@@ -3,29 +3,24 @@ from abc import ABCMeta, abstractmethod
class BaseFileHandler(object):
__metaclass__ = ABCMeta
__metaclass__ = ABCMeta # python 2 compatibility
@staticmethod
@abstractmethod
def load_from_path(filepath, **kwargs):
def load_from_fileobj(self, file, **kwargs):
pass
@staticmethod
@abstractmethod
def load_from_fileobj(file, **kwargs):
def dump_to_fileobj(self, obj, file, **kwargs):
pass
@staticmethod
@abstractmethod
def dump_to_str(obj, **kwargs):
def dump_to_str(self, obj, **kwargs):
pass
@staticmethod
@abstractmethod
def dump_to_path(obj, filepath, **kwargs):
pass
def load_from_path(self, filepath, mode='r', **kwargs):
with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs)
@staticmethod
@abstractmethod
def dump_to_fileobj(obj, file, **kwargs):
pass
def dump_to_path(self, obj, filepath, mode='w', **kwargs):
with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs)
......@@ -5,25 +5,11 @@ from .base import BaseFileHandler
class JsonHandler(BaseFileHandler):
@staticmethod
def load_from_path(filepath):
with open(filepath, 'r') as f:
obj = json.load(f)
return obj
@staticmethod
def load_from_fileobj(file):
def load_from_fileobj(self, file):
return json.load(file)
@staticmethod
def dump_to_str(obj, **kwargs):
return json.dumps(obj, **kwargs)
@staticmethod
def dump_to_path(obj, filepath, **kwargs):
with open(filepath, 'w') as f:
json.dump(obj, f, **kwargs)
@staticmethod
def dump_to_fileobj(obj, file, **kwargs):
def dump_to_fileobj(self, obj, file, **kwargs):
json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
return json.dumps(obj, **kwargs)
......@@ -5,28 +5,21 @@ from .base import BaseFileHandler
class PickleHandler(BaseFileHandler):
@staticmethod
def load_from_path(filepath, **kwargs):
with open(filepath, 'rb') as f:
obj = pickle.load(f, **kwargs)
return obj
@staticmethod
def load_from_fileobj(file, **kwargs):
def load_from_fileobj(self, file, **kwargs):
return pickle.load(file, **kwargs)
@staticmethod
def dump_to_str(obj, **kwargs):
kwargs.setdefault('protocol', 2)
return pickle.dumps(obj, **kwargs)
def load_from_path(self, filepath, **kwargs):
return super(PickleHandler, self).load_from_path(
filepath, mode='rb', **kwargs)
@staticmethod
def dump_to_path(obj, filepath, **kwargs):
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('protocol', 2)
with open(filepath, 'wb') as f:
pickle.dump(obj, f, **kwargs)
return pickle.dumps(obj, **kwargs)
@staticmethod
def dump_to_fileobj(obj, file, **kwargs):
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('protocol', 2)
pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
......@@ -9,30 +9,14 @@ from .base import BaseFileHandler
class YamlHandler(BaseFileHandler):
@staticmethod
def load_from_path(filepath, **kwargs):
kwargs.setdefault('Loader', Loader)
with open(filepath, 'r') as f:
obj = yaml.load(f, **kwargs)
return obj
@staticmethod
def load_from_fileobj(file, **kwargs):
def load_from_fileobj(self, file, **kwargs):
kwargs.setdefault('Loader', Loader)
return yaml.load(file, **kwargs)
@staticmethod
def dump_to_str(obj, **kwargs):
kwargs.setdefault('Dumper', Dumper)
return yaml.dump(obj, **kwargs)
@staticmethod
def dump_to_path(obj, filepath, **kwargs):
def dump_to_fileobj(self, obj, file, **kwargs):
kwargs.setdefault('Dumper', Dumper)
with open(filepath, 'w') as f:
yaml.dump(obj, f, **kwargs)
yaml.dump(obj, file, **kwargs)
@staticmethod
def dump_to_fileobj(obj, file, **kwargs):
def dump_to_str(self, obj, **kwargs):
kwargs.setdefault('Dumper', Dumper)
yaml.dump(obj, file, **kwargs)
return yaml.dump(obj, **kwargs)
from .handlers import JsonHandler, PickleHandler, YamlHandler
from ..utils import is_str
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from ..utils import is_str, is_list_of
file_handlers = {
'json': JsonHandler,
'yaml': YamlHandler,
'yml': YamlHandler,
'pickle': PickleHandler,
'pkl': PickleHandler
'json': JsonHandler(),
'yaml': YamlHandler(),
'yml': YamlHandler(),
'pickle': PickleHandler(),
'pkl': PickleHandler()
}
......@@ -74,3 +74,23 @@ def dump(obj, file=None, file_format=None, **kwargs):
handler.dump_to_fileobj(obj, file, **kwargs)
else:
raise TypeError('"file" must be a filename str or a file-object')
def register_handler(handler, file_formats):
"""Register a handler for some file extensions.
Args:
handler (:obj:`BaseFileHandler`): Handler to be registered.
file_formats (str or list[str]): File formats to be handled by this
handler.
"""
if not isinstance(handler, BaseFileHandler):
raise TypeError(
'handler must be a child of BaseFileHandler, not {}'.format(
type(handler)))
if isinstance(file_formats, str):
file_formats = [file_formats]
if not is_list_of(file_formats, str):
raise TypeError('file_formats must be a str or a list of str')
for ext in file_formats:
file_handlers[ext] = handler
......@@ -81,6 +81,33 @@ def test_exception():
mmcv.dump(test_obj, 'tmp.txt')
def test_register_handler():
class TxtHandler(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
txt_handler = TxtHandler()
mmcv.register_handler(txt_handler, 'txt')
mmcv.register_handler(txt_handler, ['txt1', 'txt2'])
content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
mmcv.dump(content, tmp_filename)
with open(tmp_filename, 'r') as f:
written = f.read()
os.remove(tmp_filename)
assert written == content
def test_list_from_file():
filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
filelist = mmcv.list_from_file(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment