Unverified Commit 569d588e authored by tripleMu's avatar tripleMu Committed by GitHub
Browse files

Add type hints for mmcv/fileio (#1997)



* Add typehint in mmcv/fileio/*

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/parse.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/parse.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/parse.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/fileio/io.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix hint bugs
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent e234c183
...@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta): ...@@ -21,10 +21,10 @@ class BaseFileHandler(metaclass=ABCMeta):
def dump_to_str(self, obj, **kwargs): def dump_to_str(self, obj, **kwargs):
pass pass
def load_from_path(self, filepath, mode='r', **kwargs): def load_from_path(self, filepath: str, mode: str = 'r', **kwargs):
with open(filepath, mode) as f: with open(filepath, mode) as f:
return self.load_from_fileobj(f, **kwargs) return self.load_from_fileobj(f, **kwargs)
def dump_to_path(self, obj, filepath, mode='w', **kwargs): def dump_to_path(self, obj, filepath: str, mode: str = 'w', **kwargs):
with open(filepath, mode) as f: with open(filepath, mode) as f:
self.dump_to_fileobj(obj, f, **kwargs) self.dump_to_fileobj(obj, f, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from io import BytesIO, StringIO from io import BytesIO, StringIO
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from ..utils import is_list_of, is_str from ..utils import is_list_of
from .file_client import FileClient from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
FileLikeObject = Union[StringIO, BytesIO]
file_handlers = { file_handlers = {
'json': JsonHandler(), 'json': JsonHandler(),
'yaml': YamlHandler(), 'yaml': YamlHandler(),
...@@ -15,7 +18,10 @@ file_handlers = { ...@@ -15,7 +18,10 @@ file_handlers = {
} }
def load(file, file_format=None, file_client_args=None, **kwargs): def load(file: Union[str, Path, FileLikeObject],
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Load data from json/yaml/pickle files. """Load data from json/yaml/pickle files.
This method provides a unified api for loading data from serialized files. This method provides a unified api for loading data from serialized files.
...@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs): ...@@ -45,13 +51,14 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
""" """
if isinstance(file, Path): if isinstance(file, Path):
file = str(file) file = str(file)
if file_format is None and is_str(file): if file_format is None and isinstance(file, str):
file_format = file.split('.')[-1] file_format = file.split('.')[-1]
if file_format not in file_handlers: if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}') raise TypeError(f'Unsupported format: {file_format}')
handler = file_handlers[file_format] handler = file_handlers[file_format]
if is_str(file): f: FileLikeObject
if isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file) file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like: if handler.str_like:
with StringIO(file_client.get_text(file)) as f: with StringIO(file_client.get_text(file)) as f:
...@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs): ...@@ -66,7 +73,11 @@ def load(file, file_format=None, file_client_args=None, **kwargs):
return obj return obj
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): def dump(obj: Any,
file: Optional[Union[str, Path, FileLikeObject]] = None,
file_format: Optional[str] = None,
file_client_args: Optional[Dict] = None,
**kwargs):
"""Dump data to json/yaml/pickle strings or files. """Dump data to json/yaml/pickle strings or files.
This method provides a unified api for dumping data as strings or to files, This method provides a unified api for dumping data as strings or to files,
...@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): ...@@ -96,18 +107,18 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
if isinstance(file, Path): if isinstance(file, Path):
file = str(file) file = str(file)
if file_format is None: if file_format is None:
if is_str(file): if isinstance(file, str):
file_format = file.split('.')[-1] file_format = file.split('.')[-1]
elif file is None: elif file is None:
raise ValueError( raise ValueError(
'file_format must be specified since file is None') 'file_format must be specified since file is None')
if file_format not in file_handlers: if file_format not in file_handlers:
raise TypeError(f'Unsupported format: {file_format}') raise TypeError(f'Unsupported format: {file_format}')
f: FileLikeObject
handler = file_handlers[file_format] handler = file_handlers[file_format]
if file is None: if file is None:
return handler.dump_to_str(obj, **kwargs) return handler.dump_to_str(obj, **kwargs)
elif is_str(file): elif isinstance(file, str):
file_client = FileClient.infer_client(file_client_args, file) file_client = FileClient.infer_client(file_client_args, file)
if handler.str_like: if handler.str_like:
with StringIO() as f: with StringIO() as f:
...@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs): ...@@ -123,7 +134,8 @@ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
raise TypeError('"file" must be a filename str or a file-object') raise TypeError('"file" must be a filename str or a file-object')
def _register_handler(handler, file_formats): def _register_handler(handler: BaseFileHandler,
file_formats: Union[str, List[str]]) -> None:
"""Register a handler for some file extensions. """Register a handler for some file extensions.
Args: Args:
...@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats): ...@@ -142,7 +154,7 @@ def _register_handler(handler, file_formats):
file_handlers[ext] = handler file_handlers[ext] = handler
def register_handler(file_formats, **kwargs): def register_handler(file_formats: Union[str, list], **kwargs) -> Callable:
def wrap(cls): def wrap(cls):
_register_handler(cls(**kwargs), file_formats) _register_handler(cls(**kwargs), file_formats)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from io import StringIO from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Union
from .file_client import FileClient from .file_client import FileClient
def list_from_file(filename, def list_from_file(filename: Union[str, Path],
prefix='', prefix: str = '',
offset=0, offset: int = 0,
max_num=0, max_num: int = 0,
encoding='utf-8', encoding: str = 'utf-8',
file_client_args=None): file_client_args: Optional[Dict] = None) -> List:
"""Load a text file and parse the content as a list of strings. """Load a text file and parse the content as a list of strings.
Note: Note:
...@@ -52,10 +54,10 @@ def list_from_file(filename, ...@@ -52,10 +54,10 @@ def list_from_file(filename,
return item_list return item_list
def dict_from_file(filename, def dict_from_file(filename: Union[str, Path],
key_type=str, key_type: type = str,
encoding='utf-8', encoding: str = 'utf-8',
file_client_args=None): file_client_args: Optional[Dict] = None) -> Dict:
"""Load a text file and parse the content as a dict. """Load a text file and parse the content as a dict.
Each line of the text file will be two or more columns split by Each line of the text file will be two or more columns split by
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment