Commit 34eca272 authored by Kai Chen's avatar Kai Chen
Browse files

add register_handler and refactor BaseFileHandler

parent 3e1befe1
from .io import load, dump from .io import load, dump, register_handler
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .parse import list_from_file, dict_from_file from .parse import list_from_file, dict_from_file
__all__ = [ __all__ = [
'load', 'dump', 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'load', 'dump', 'register_handler', 'BaseFileHandler', 'JsonHandler',
'YamlHandler', 'list_from_file', 'dict_from_file' 'PickleHandler', 'YamlHandler', 'list_from_file', 'dict_from_file'
] ]
...@@ -3,29 +3,24 @@ from abc import ABCMeta, abstractmethod ...@@ -3,29 +3,24 @@ from abc import ABCMeta, abstractmethod
class BaseFileHandler(object): class BaseFileHandler(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta # python 2 compatibility
@staticmethod
@abstractmethod @abstractmethod
def load_from_path(filepath, **kwargs): def load_from_fileobj(self, file, **kwargs):
pass pass
@staticmethod
@abstractmethod @abstractmethod
def load_from_fileobj(file, **kwargs): def dump_to_fileobj(self, obj, file, **kwargs):
pass pass
@staticmethod
@abstractmethod @abstractmethod
def dump_to_str(obj, **kwargs): def dump_to_str(self, obj, **kwargs):
pass pass
@staticmethod def load_from_path(self, filepath, mode='r', **kwargs):
@abstractmethod with open(filepath, mode) as f:
def dump_to_path(obj, filepath, **kwargs): return self.load_from_fileobj(f, **kwargs)
pass
@staticmethod def dump_to_path(self, obj, filepath, mode='w', **kwargs):
@abstractmethod with open(filepath, mode) as f:
def dump_to_fileobj(obj, file, **kwargs): self.dump_to_fileobj(obj, f, **kwargs)
pass
...@@ -5,25 +5,11 @@ from .base import BaseFileHandler ...@@ -5,25 +5,11 @@ from .base import BaseFileHandler
class JsonHandler(BaseFileHandler): class JsonHandler(BaseFileHandler):
@staticmethod def load_from_fileobj(self, file):
def load_from_path(filepath):
with open(filepath, 'r') as f:
obj = json.load(f)
return obj
@staticmethod
def load_from_fileobj(file):
return json.load(file) return json.load(file)
@staticmethod def dump_to_fileobj(self, obj, file, **kwargs):
def dump_to_str(obj, **kwargs):
return json.dumps(obj, **kwargs)
@staticmethod
def dump_to_path(obj, filepath, **kwargs):
with open(filepath, 'w') as f:
json.dump(obj, f, **kwargs)
@staticmethod
def dump_to_fileobj(obj, file, **kwargs):
json.dump(obj, file, **kwargs) json.dump(obj, file, **kwargs)
def dump_to_str(self, obj, **kwargs):
return json.dumps(obj, **kwargs)
...@@ -5,28 +5,21 @@ from .base import BaseFileHandler ...@@ -5,28 +5,21 @@ from .base import BaseFileHandler
class PickleHandler(BaseFileHandler): class PickleHandler(BaseFileHandler):
@staticmethod def load_from_fileobj(self, file, **kwargs):
def load_from_path(filepath, **kwargs):
with open(filepath, 'rb') as f:
obj = pickle.load(f, **kwargs)
return obj
@staticmethod
def load_from_fileobj(file, **kwargs):
return pickle.load(file, **kwargs) return pickle.load(file, **kwargs)
@staticmethod def load_from_path(self, filepath, **kwargs):
def dump_to_str(obj, **kwargs): return super(PickleHandler, self).load_from_path(
kwargs.setdefault('protocol', 2) filepath, mode='rb', **kwargs)
return pickle.dumps(obj, **kwargs)
@staticmethod def dump_to_str(self, obj, **kwargs):
def dump_to_path(obj, filepath, **kwargs):
kwargs.setdefault('protocol', 2) kwargs.setdefault('protocol', 2)
with open(filepath, 'wb') as f: return pickle.dumps(obj, **kwargs)
pickle.dump(obj, f, **kwargs)
@staticmethod def dump_to_fileobj(self, obj, file, **kwargs):
def dump_to_fileobj(obj, file, **kwargs):
kwargs.setdefault('protocol', 2) kwargs.setdefault('protocol', 2)
pickle.dump(obj, file, **kwargs) pickle.dump(obj, file, **kwargs)
def dump_to_path(self, obj, filepath, **kwargs):
super(PickleHandler, self).dump_to_path(
obj, filepath, mode='wb', **kwargs)
...@@ -9,30 +9,14 @@ from .base import BaseFileHandler ...@@ -9,30 +9,14 @@ from .base import BaseFileHandler
class YamlHandler(BaseFileHandler): class YamlHandler(BaseFileHandler):
@staticmethod def load_from_fileobj(self, file, **kwargs):
def load_from_path(filepath, **kwargs):
kwargs.setdefault('Loader', Loader)
with open(filepath, 'r') as f:
obj = yaml.load(f, **kwargs)
return obj
@staticmethod
def load_from_fileobj(file, **kwargs):
kwargs.setdefault('Loader', Loader) kwargs.setdefault('Loader', Loader)
return yaml.load(file, **kwargs) return yaml.load(file, **kwargs)
@staticmethod def dump_to_fileobj(self, obj, file, **kwargs):
def dump_to_str(obj, **kwargs):
kwargs.setdefault('Dumper', Dumper)
return yaml.dump(obj, **kwargs)
@staticmethod
def dump_to_path(obj, filepath, **kwargs):
kwargs.setdefault('Dumper', Dumper) kwargs.setdefault('Dumper', Dumper)
with open(filepath, 'w') as f: yaml.dump(obj, file, **kwargs)
yaml.dump(obj, f, **kwargs)
@staticmethod def dump_to_str(self, obj, **kwargs):
def dump_to_fileobj(obj, file, **kwargs):
kwargs.setdefault('Dumper', Dumper) kwargs.setdefault('Dumper', Dumper)
yaml.dump(obj, file, **kwargs) return yaml.dump(obj, **kwargs)
from .handlers import JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from ..utils import is_str from ..utils import is_str, is_list_of
file_handlers = { file_handlers = {
'json': JsonHandler, 'json': JsonHandler(),
'yaml': YamlHandler, 'yaml': YamlHandler(),
'yml': YamlHandler, 'yml': YamlHandler(),
'pickle': PickleHandler, 'pickle': PickleHandler(),
'pkl': PickleHandler 'pkl': PickleHandler()
} }
...@@ -74,3 +74,23 @@ def dump(obj, file=None, file_format=None, **kwargs): ...@@ -74,3 +74,23 @@ def dump(obj, file=None, file_format=None, **kwargs):
handler.dump_to_fileobj(obj, file, **kwargs) handler.dump_to_fileobj(obj, file, **kwargs)
else: else:
raise TypeError('"file" must be a filename str or a file-object') raise TypeError('"file" must be a filename str or a file-object')
def register_handler(handler, file_formats):
"""Register a handler for some file extensions.
Args:
handler (:obj:`BaseFileHandler`): Handler to be registered.
file_formats (str or list[str]): File formats to be handled by this
handler.
"""
if not isinstance(handler, BaseFileHandler):
raise TypeError(
'handler must be a child of BaseFileHandler, not {}'.format(
type(handler)))
if isinstance(file_formats, str):
file_formats = [file_formats]
if not is_list_of(file_formats, str):
raise TypeError('file_formats must be a str or a list of str')
for ext in file_formats:
file_handlers[ext] = handler
...@@ -81,6 +81,33 @@ def test_exception(): ...@@ -81,6 +81,33 @@ def test_exception():
mmcv.dump(test_obj, 'tmp.txt') mmcv.dump(test_obj, 'tmp.txt')
def test_register_handler():
class TxtHandler(mmcv.BaseFileHandler):
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
txt_handler = TxtHandler()
mmcv.register_handler(txt_handler, 'txt')
mmcv.register_handler(txt_handler, ['txt1', 'txt2'])
content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
tmp_filename = osp.join(tempfile.gettempdir(), 'mmcv_test.txt2')
mmcv.dump(content, tmp_filename)
with open(tmp_filename, 'r') as f:
written = f.read()
os.remove(tmp_filename)
assert written == content
def test_list_from_file(): def test_list_from_file():
filename = osp.join(osp.dirname(__file__), 'data/filelist.txt') filename = osp.join(osp.dirname(__file__), 'data/filelist.txt')
filelist = mmcv.list_from_file(filename) filelist = mmcv.list_from_file(filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment