"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "834d5978dfbbb9e6797a4524e4de37f3d3de6cd6"
Commit 6e8d925e authored by Kai Chen's avatar Kai Chen
Browse files

add a class decorator to register file handlers

parent d796c13b
import functools
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from ..utils import is_str, is_list_of from ..utils import is_str, is_list_of
...@@ -76,7 +78,7 @@ def dump(obj, file=None, file_format=None, **kwargs): ...@@ -76,7 +78,7 @@ def dump(obj, file=None, file_format=None, **kwargs):
raise TypeError('"file" must be a filename str or a file-object') raise TypeError('"file" must be a filename str or a file-object')
def register_handler(handler, file_formats): def _register_handler(handler, file_formats):
"""Register a handler for some file extensions. """Register a handler for some file extensions.
Args: Args:
...@@ -94,3 +96,12 @@ def register_handler(handler, file_formats): ...@@ -94,3 +96,12 @@ def register_handler(handler, file_formats):
raise TypeError('file_formats must be a str or a list of str') raise TypeError('file_formats must be a str or a list of str')
for ext in file_formats: for ext in file_formats:
file_handlers[ext] = handler file_handlers[ext] = handler
def register_handler(file_formats, **kwargs):
def wrap(cls):
_register_handler(cls(**kwargs), file_formats)
return cls
return wrap
...@@ -83,7 +83,8 @@ def test_exception(): ...@@ -83,7 +83,8 @@ def test_exception():
def test_register_handler(): def test_register_handler():
class TxtHandler(mmcv.BaseFileHandler): @mmcv.register_handler('txt')
class TxtHandler1(mmcv.BaseFileHandler):
def load_from_fileobj(self, file): def load_from_fileobj(self, file):
return file.read() return file.read()
...@@ -94,9 +95,18 @@ def test_register_handler(): ...@@ -94,9 +95,18 @@ def test_register_handler():
def dump_to_str(self, obj, **kwargs): def dump_to_str(self, obj, **kwargs):
return str(obj) return str(obj)
txt_handler = TxtHandler() @mmcv.register_handler(['txt1', 'txt2'])
mmcv.register_handler(txt_handler, 'txt') class TxtHandler2(mmcv.BaseFileHandler):
mmcv.register_handler(txt_handler, ['txt1', 'txt2'])
def load_from_fileobj(self, file):
return file.read()
def dump_to_fileobj(self, obj, file):
file.write('\n')
file.write(str(obj))
def dump_to_str(self, obj, **kwargs):
return str(obj)
content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt')) content = mmcv.load(osp.join(osp.dirname(__file__), 'data/filelist.txt'))
assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg' assert content == '1.jpg\n2.jpg\n3.jpg\n4.jpg\n5.jpg'
...@@ -105,7 +115,7 @@ def test_register_handler(): ...@@ -105,7 +115,7 @@ def test_register_handler():
with open(tmp_filename, 'r') as f: with open(tmp_filename, 'r') as f:
written = f.read() written = f.read()
os.remove(tmp_filename) os.remove(tmp_filename)
assert written == content assert written == '\n' + content
def test_list_from_file(): def test_list_from_file():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment