Unverified commit 32e09f49 authored by Zaida Zhou, committed by GitHub

[Feature] Upload checkpoints and logs to ceph (#1375)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittest

* rename attribute str_like_obj to is_str_like_obj

* [Docs] Upload checkpoint to petrel oss

* add infer_client method

* Support uploading checkpoint to petrel oss

* add check_exist method

* refactor CheckpointHook

* support uploading logs to ceph

* rename var client to file_client

* polish docstring

* enhance load_from_ceph

* refactor load_from_ceph

* refactor TextLoggerHook

* change the meaning of out_dir argument

* fix test_checkpoint_hook.py

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* add a unittest for EvalHook when file backend is petrel

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* add warning in load_from_ceph

* fix type of out_suffix

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* fix CI

* mock check_exist

* improve docstring

* improve docstring

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* rename check_exists to exists

* refactor code and polish docstring

* fix windows ci

* add comment and polish docstring

* polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* fix typo

* fix error when mocking PetrelBackend

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* rename keep_log to keep_local

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* refactor

* polish information

* format information
parent ef022196
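For orientation before the diff: the feature lets FileClient pick a storage backend from the path prefix, so checkpoints and logs can be written to Petrel/Ceph object storage as easily as to local disk. A minimal sketch (the s3 bucket path is hypothetical, and the petrel backend needs a configured petrel_client):

    from mmcv.fileio import FileClient

    # A plain path maps to HardDiskBackend ...
    disk_client = FileClient.infer_client(uri='/tmp/example.txt')
    print(disk_client.name)  # HardDiskBackend

    # ... while an 's3://' path maps to PetrelBackend.
    petrel_client = FileClient.infer_client(uri='s3://bucket/example.txt')
    print(petrel_client.name)  # PetrelBackend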
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import Iterable, Iterator, Optional, Tuple, Union
 from urllib.request import urlopen

+import mmcv
 from mmcv.utils.misc import has_method
 from mmcv.utils.path import is_filepath
@@ -23,6 +24,17 @@ class BaseStorageBackend(metaclass=ABCMeta):
     as texts.
     """

+    # a flag to indicate whether the backend can create a symlink for a file
+    _allow_symlink = False
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    @property
+    def allow_symlink(self):
+        return self._allow_symlink
+
     @abstractmethod
     def get(self, filepath):
         pass
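The new name property simply reports the backend class name, and allow_symlink stays False unless a backend opts in. A rough illustration (TmpBackend below is hypothetical, and it assumes get and get_text are the abstract methods a subclass must override):

    from mmcv.fileio import BaseStorageBackend

    class TmpBackend(BaseStorageBackend):

        def get(self, filepath):
            pass

        def get_text(self, filepath):
            pass

    backend = TmpBackend()
    print(backend.name)           # TmpBackend (defaults to the class name)
    print(backend.allow_symlink)  # False (backends opt in via _allow_symlink)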
@@ -41,8 +53,8 @@ class CephBackend(BaseStorageBackend):
         will be replaced by ``dst``. Default: None.

     .. warning::
-        :class:`CephBackend` will be deprecated, please use
-        :class:`PetrelBackend` instead
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
     """

     def __init__(self, path_mapping=None):
@@ -266,8 +278,8 @@ class PetrelBackend(BaseStorageBackend):
         filepath = self._format_path(filepath)
         return self._client.contains(filepath)

-    def concat_paths(self, filepath: Union[str, Path],
-                     *filepaths: Union[str, Path]) -> str:
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
         """Concatenate all file paths.

         Args:
@@ -377,7 +389,7 @@ class PetrelBackend(BaseStorageBackend):
                 # is a directory, because `self.isdir` relies on
                 # `self._client.list`
                 if path.endswith('/'):  # a directory path
-                    next_dir_path = self.concat_paths(dir_path, path)
+                    next_dir_path = self.join_path(dir_path, path)
                     if list_dir:
                         # get the relative path and exclude the last
                         # character '/'
@@ -388,7 +400,7 @@ class PetrelBackend(BaseStorageBackend):
                                                      list_file, suffix,
                                                      recursive)
                 else:  # a file path
-                    absolute_path = self.concat_paths(dir_path, path)
+                    absolute_path = self.join_path(dir_path, path)
                     rel_path = absolute_path[len(root):]
                     if (suffix is None
                             or rel_path.endswith(suffix)) and list_file:
@@ -491,6 +503,8 @@ class LmdbBackend(BaseStorageBackend):
 class HardDiskBackend(BaseStorageBackend):
     """Raw hard disks storage backend."""

+    _allow_symlink = True
+
     def get(self, filepath: Union[str, Path]) -> bytes:
         """Read data from a given ``filepath`` with 'rb' mode.
@@ -524,10 +538,15 @@ class HardDiskBackend(BaseStorageBackend):
     def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
         """Write data to a given ``filepath`` with 'wb' mode.

+        Note:
+            ``put`` will create a directory if the directory of ``filepath``
+            does not exist.
+
         Args:
             obj (bytes): Data to be written.
             filepath (str or Path): Path to write data.
         """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
         with open(filepath, 'wb') as f:
             f.write(obj)
@@ -537,12 +556,17 @@ class HardDiskBackend(BaseStorageBackend):
                  encoding: str = 'utf-8') -> None:
         """Write data to a given ``filepath`` with 'w' mode.

+        Note:
+            ``put_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
         Args:
             obj (str): Data to be written.
             filepath (str or Path): Path to write data.
             encoding (str): The encoding format used to open the ``filepath``.
                 Default: 'utf-8'.
         """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
         with open(filepath, 'w', encoding=encoding) as f:
             f.write(obj)
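A quick sketch of the auto-mkdir behavior added above (the /tmp path is illustrative):

    from mmcv.fileio import FileClient

    disk_client = FileClient(backend='disk')
    # The parent directories are now created on demand instead of raising
    # FileNotFoundError when they do not exist yet.
    disk_client.put_text('hello', '/tmp/demo/nested/dir/hello.txt')
    print(disk_client.get_text('/tmp/demo/nested/dir/hello.txt'))  # hello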
@@ -579,7 +603,7 @@ class HardDiskBackend(BaseStorageBackend):
         return osp.isdir(filepath)

     def isfile(self, filepath: Union[str, Path]) -> bool:
-        """Check a ``filepath`` whether it is a file.
+        """Check whether a file path is a file.

         Args:
             filepath (str or Path): Path to be checked whether it is a file.

@@ -590,8 +614,8 @@ class HardDiskBackend(BaseStorageBackend):
         """
         return osp.isfile(filepath)

-    def concat_paths(self, filepath: Union[str, Path],
-                     *filepaths: Union[str, Path]) -> str:
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
         """Concatenate all file paths.

         Join one or more filepath components intelligently. The return value
@@ -714,7 +738,7 @@ class FileClient:
     Note that It can also register other backend accessor with a given name,
     prefixes, and backend class. In addition, We use the singleton pattern to
     avoid repeated object creation. If the arguments are the same, the same
-    object is returned.
+    object will be returned.

     Args:
         backend (str, optional): The storage backend type. Options are "disk",
...@@ -788,18 +812,21 @@ class FileClient: ...@@ -788,18 +812,21 @@ class FileClient:
_instance = super().__new__(cls) _instance = super().__new__(cls)
if backend is not None: if backend is not None:
_instance.client = cls._backends[backend](**kwargs) _instance.client = cls._backends[backend](**kwargs)
_instance.backend_name = backend
else: else:
_instance.client = cls._prefix_to_backends[prefix](**kwargs) _instance.client = cls._prefix_to_backends[prefix](**kwargs)
# infer the backend name according to the prefix
for backend_name, backend_cls in cls._backends.items():
if isinstance(_instance.client, backend_cls):
_instance.backend_name = backend_name
break
cls._instances[arg_key] = _instance cls._instances[arg_key] = _instance
return _instance return _instance
@property
def name(self):
return self.client.name
@property
def allow_symlink(self):
return self.client.allow_symlink
@staticmethod @staticmethod
def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
"""Parse the prefix of a uri. """Parse the prefix of a uri.
@@ -980,6 +1007,10 @@ class FileClient:
     def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
         """Write data to a given ``filepath`` with 'wb' mode.

+        Note:
+            ``put`` should create a directory if the directory of ``filepath``
+            does not exist.
+
         Args:
             obj (bytes): Data to be written.
             filepath (str or Path): Path to write data.
@@ -989,6 +1020,10 @@ class FileClient:
     def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
         """Write data to a given ``filepath`` with 'w' mode.

+        Note:
+            ``put_text`` should create a directory if the directory of
+            ``filepath`` does not exist.
+
         Args:
             obj (str): Data to be written.
             filepath (str or Path): Path to write data.
@@ -1041,8 +1076,8 @@ class FileClient:
         """
         return self.client.isfile(filepath)

-    def concat_paths(self, filepath: Union[str, Path],
-                     *filepaths: Union[str, Path]) -> str:
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
         """Concatenate all file paths.

         Join one or more filepath components intelligently. The return value

@@ -1054,7 +1089,7 @@ class FileClient:
         Returns:
             str: The result of concatenation.
         """
-        return self.client.concat_paths(filepath, *filepaths)
+        return self.client.join_path(filepath, *filepaths)

     @contextmanager
     def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
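The renamed join_path delegates to the active backend, so separators follow the backend rather than the local OS; a sketch (the bucket path is hypothetical):

    from mmcv.fileio import FileClient

    disk_client = FileClient(backend='disk')
    print(disk_client.join_path('/data', 'ckpts', 'epoch_1.pth'))
    # -> '/data/ckpts/epoch_1.pth' (os.path.join semantics)

    # With a configured petrel_client, PetrelBackend always joins with '/':
    # FileClient(backend='petrel').join_path('s3://bucket', 'epoch_1.pth')
    # -> 's3://bucket/epoch_1.pth'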
......
@@ -7,7 +7,6 @@ class BaseFileHandler(metaclass=ABCMeta):
     # str-like object or bytes-like object. Pickle only processes bytes-like
     # objects but json only processes str-like object. If it is str-like
     # object, `StringIO` will be used to process the buffer.
-
     str_like = True

     @abstractmethod
......
@@ -323,7 +323,7 @@ def load_from_pavi(filename, map_location=None):

 @CheckpointLoader.register_scheme(prefixes='s3://')
-def load_from_ceph(filename, map_location=None, backend='ceph'):
+def load_from_ceph(filename, map_location=None, backend='petrel'):
     """load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.

@@ -331,20 +331,35 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
     Args:
         filename (str): checkpoint file path with s3 prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        backend (str): The storage backend type. Options are "disk", "ceph",
-            "memcached" and "lmdb". Default: 'ceph'
+        backend (str, optional): The storage backend type. Options are 'ceph',
+            'petrel'. Default: 'petrel'.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.

     Returns:
         dict or OrderedDict: The loaded checkpoint.
     """
-    allowed_backends = ['ceph']
+    allowed_backends = ['ceph', 'petrel']
     if backend not in allowed_backends:
         raise ValueError(f'Load from Backend {backend} is not supported.')
-    fileclient = FileClient(backend=backend)
-    buffer = io.BytesIO(fileclient.get(filename))
-    checkpoint = torch.load(buffer, map_location=map_location)
+
+    if backend == 'ceph':
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+
+    # CephClient and PetrelBackend have the same prefix 's3://' and the latter
+    # will be chosen as default. If PetrelBackend can not be instantiated
+    # successfully, the CephClient will be chosen.
+    try:
+        file_client = FileClient(backend=backend)
+    except ImportError:
+        allowed_backends.remove(backend)
+        file_client = FileClient(backend=allowed_backends[0])
+
+    with io.BytesIO(file_client.get(filename)) as buffer:
+        checkpoint = torch.load(buffer, map_location=map_location)
     return checkpoint
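Since the loader is registered for the 's3://' scheme, loading from object storage goes through the regular entry point; a sketch (hypothetical bucket path, configured petrel_client assumed):

    import torch.nn as nn
    from mmcv.runner import load_checkpoint

    model = nn.Linear(3, 3)
    # 's3://' routes to load_from_ceph, which defaults to PetrelBackend and
    # falls back to CephBackend when petrel_client cannot be imported.
    load_checkpoint(model, 's3://bucket/ckpts/epoch_12.pth',
                    map_location='cpu')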
@@ -506,7 +521,6 @@ def load_checkpoint(model,
             pair of the regular expression operations. Default: strip
             the prefix 'module.' by [(r'^module\\.', '')].
-
     Returns:
         dict or OrderedDict: The loaded checkpoint.
     """
@@ -616,7 +630,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
     return destination

-def save_checkpoint(model, filename, optimizer=None, meta=None):
+def save_checkpoint(model,
+                    filename,
+                    optimizer=None,
+                    meta=None,
+                    file_client_args=None):
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and

@@ -627,6 +645,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
         filename (str): Checkpoint filename.
         optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
         meta (dict, optional): Metadata to be saved in checkpoint.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
     """
     if meta is None:
         meta = {}

@@ -654,6 +676,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
         checkpoint['optimizer'][name] = optim.state_dict()

     if filename.startswith('pavi://'):
+        if file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" if filename starts with'
+                f'"pavi://", but got {file_client_args}')
         try:
             from pavi import modelcloud
             from pavi import exception

@@ -674,8 +700,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
             f.flush()
         model.create_file(checkpoint_file, name=model_name)
     else:
-        mmcv.mkdir_or_exist(osp.dirname(filename))
-        # immediately flush buffer
-        with open(filename, 'wb') as f:
+        file_client = FileClient.infer_client(file_client_args, filename)
+        with io.BytesIO() as f:
             torch.save(checkpoint, f)
-            f.flush()
+            file_client.put(f.getvalue(), filename)
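A sketch of the new save path (hypothetical bucket path; petrel_client must be configured for the remote case):

    import torch.nn as nn
    from mmcv.runner import save_checkpoint

    model = nn.Linear(3, 3)
    # Local save: backend inferred from the plain path (HardDiskBackend).
    save_checkpoint(model, '/tmp/demo/epoch_1.pth')
    # Remote save: the checkpoint is serialized into an in-memory buffer and
    # uploaded with file_client.put instead of being written to disk.
    save_checkpoint(model, 's3://bucket/demo/epoch_1.pth',
                    file_client_args=dict(backend='petrel'))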
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
+import os.path as osp
+import warnings

+from mmcv.fileio import FileClient
 from ..dist_utils import allreduce_params, master_only
 from .hook import HOOKS, Hook
@@ -18,16 +20,32 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
             Default: True.
-        out_dir (str, optional): The directory to save checkpoints. If not
-            specified, ``runner.work_dir`` will be used by default.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, ``runner.work_dir`` will be used by default. If
+            specified, the ``out_dir`` will be the concatenation of ``out_dir``
+            and the last level directory of ``runner.work_dir``.
+            `Changed in version 1.3.16.`
         max_keep_ckpts (int, optional): The maximum checkpoints to keep.
             In some cases we want only the latest few checkpoints and would
             like to delete old ones to save the disk space.
             Default: -1, which means unlimited.
-        save_last (bool): Whether to force the last checkpoint to be saved
-            regardless of interval.
-        sync_buffer (bool): Whether to synchronize buffers in different
-            gpus. Default: False.
+        save_last (bool, optional): Whether to force the last checkpoint to be
+            saved regardless of interval. Default: True.
+        sync_buffer (bool, optional): Whether to synchronize buffers in
+            different gpus. Default: False.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+
+    .. warning::
+        Before v1.3.16, the ``out_dir`` argument indicates the path where the
+        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
+        root directory and the final path to save checkpoint is the
+        concatenation of ``out_dir`` and the last level directory of
+        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
+        and the value of ``runner.work_dir`` is "/path/of/B", then the final
+        path will be "/path/of/A/B".
     """

     def __init__(self,
@@ -38,6 +56,7 @@ class CheckpointHook(Hook):
                  max_keep_ckpts=-1,
                  save_last=True,
                  sync_buffer=False,
+                 file_client_args=None,
                  **kwargs):
         self.interval = interval
         self.by_epoch = by_epoch

@@ -47,11 +66,39 @@ class CheckpointHook(Hook):
         self.save_last = save_last
         self.args = kwargs
         self.sync_buffer = sync_buffer
+        self.file_client_args = file_client_args

     def before_run(self, runner):
         if not self.out_dir:
             self.out_dir = runner.work_dir

+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+                            f'{self.file_client.name}.'))
+
+        # disable the create_symlink option because some file backends do not
+        # allow to create a symlink
+        if 'create_symlink' in self.args:
+            if self.args[
+                    'create_symlink'] and not self.file_client.allow_symlink:
+                self.args['create_symlink'] = False
+                warnings.warn(
+                    ('create_symlink is set as True by the user but is changed'
+                     'to be False because creating symbolic link is not '
+                     f'allowed in {self.file_client.name}'))
+        else:
+            self.args['create_symlink'] = self.file_client.allow_symlink
+
     def after_train_epoch(self, runner):
         if not self.by_epoch:
             return
@@ -81,7 +128,7 @@ class CheckpointHook(Hook):
             cur_ckpt_filename = self.args.get(
                 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
             runner.meta.setdefault('hook_msgs', dict())
-            runner.meta['hook_msgs']['last_ckpt'] = os.path.join(
+            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                 self.out_dir, cur_ckpt_filename)
         # remove other checkpoints
         if self.max_keep_ckpts > 0:

@@ -96,10 +143,10 @@ class CheckpointHook(Hook):
                                        -self.interval)
             filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
-                ckpt_path = os.path.join(self.out_dir,
-                                         filename_tmpl.format(_step))
-                if os.path.exists(ckpt_path):
-                    os.remove(ckpt_path)
+                ckpt_path = self.file_client.join_path(
+                    self.out_dir, filename_tmpl.format(_step))
+                if self.file_client.isfile(ckpt_path):
+                    self.file_client.remove(ckpt_path)
                 else:
                     break
......
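Putting the CheckpointHook changes together, a configuration sketch (bucket path hypothetical; petrel_client must be configured):

    from mmcv.runner import CheckpointHook

    # With out_dir set, checkpoints land under
    # 's3://bucket/ckpts/<basename of runner.work_dir>/', stale ones are
    # removed through the same FileClient, and create_symlink is disabled
    # automatically because PetrelBackend does not allow symlinks.
    checkpoint_hook = CheckpointHook(
        interval=1,
        out_dir='s3://bucket/ckpts',
        max_keep_ckpts=2,
        file_client_args=dict(backend='petrel'))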
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
 import os.path as osp
 import warnings
 from math import inf

@@ -8,6 +7,7 @@ import torch.distributed as dist
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.utils.data import DataLoader

+from mmcv.fileio import FileClient
 from mmcv.utils import is_seq_of

 from .hook import Hook
 from .logger import LoggerHook
@@ -54,6 +54,14 @@ class EvalHook(Hook):
         less_keys (List[str] | None, optional): Metric keys that will be
             inferred by 'less' comparison rule. If ``None``, _default_less_keys
             will be used. (default: ``None``)
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+            `New in version 1.3.16.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.16.`
         **eval_kwargs: Evaluation arguments fed into the evaluate function of
             the dataset.

@@ -84,6 +92,8 @@ class EvalHook(Hook):
                  test_fn=None,
                  greater_keys=None,
                  less_keys=None,
+                 out_dir=None,
+                 file_client_args=None,
                  **eval_kwargs):
         if not isinstance(dataloader, DataLoader):
             raise TypeError(f'dataloader must be a pytorch DataLoader, '

@@ -137,6 +147,9 @@ class EvalHook(Hook):
         self.best_ckpt_path = None
         self._init_rule(rule, self.save_best)

+        self.out_dir = out_dir
+        self.file_client_args = file_client_args
+
     def _init_rule(self, rule, key_indicator):
         """Initialize rule, key_indicator, comparison_func, and best score.
@@ -187,6 +200,23 @@ class EvalHook(Hook):
         self.compare_func = self.rule_map[self.rule]

     def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+        runner.logger.info(
+            (f'The best checkpoint will be saved to {self.out_dir} by '
+             f'{self.file_client.name}'))
+
         if self.save_best is not None:
             if runner.meta is None:
                 warnings.warn('runner.meta is None. Creating an empty one.')
@@ -299,15 +329,20 @@ class EvalHook(Hook):
             best_score = key_score
             runner.meta['hook_msgs']['best_score'] = best_score

-            if self.best_ckpt_path and osp.isfile(self.best_ckpt_path):
-                os.remove(self.best_ckpt_path)
+            if self.best_ckpt_path and self.file_client.isfile(
+                    self.best_ckpt_path):
+                self.file_client.remove(self.best_ckpt_path)
+                runner.logger.info(
+                    (f'The previous best checkpoint {self.best_ckpt_path} was '
+                     'removed'))

             best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
-            self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name)
+            self.best_ckpt_path = self.file_client.join_path(
+                self.out_dir, best_ckpt_name)
             runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path

             runner.save_checkpoint(
-                runner.work_dir, best_ckpt_name, create_symlink=False)
+                self.out_dir, best_ckpt_name, create_symlink=False)
             runner.logger.info(
                 f'Now best checkpoint is saved as {best_ckpt_name}.')
             runner.logger.info(
@@ -378,6 +413,12 @@ class DistEvalHook(EvalHook):
         broadcast_bn_buffer (bool): Whether to broadcast the
             buffer(running_mean and running_var) of rank 0 to other rank
             before evaluation. Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
         **eval_kwargs: Evaluation arguments fed into the evaluate function of
             the dataset.
     """

@@ -395,6 +436,8 @@ class DistEvalHook(EvalHook):
                  broadcast_bn_buffer=True,
                  tmpdir=None,
                  gpu_collect=False,
+                 out_dir=None,
+                 file_client_args=None,
                  **eval_kwargs):

         if test_fn is None:

@@ -411,6 +454,8 @@ class DistEvalHook(EvalHook):
             test_fn=test_fn,
             greater_keys=greater_keys,
             less_keys=less_keys,
+            out_dir=out_dir,
+            file_client_args=file_client_args,
             **eval_kwargs)

         self.broadcast_bn_buffer = broadcast_bn_buffer
......
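A corresponding sketch for the evaluation hooks (the loader below is a placeholder; a real dataset must implement an evaluate() method, and the bucket path is hypothetical):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from mmcv.runner import EvalHook

    # Placeholder validation loader, just enough to construct the hook.
    val_loader = DataLoader(TensorDataset(torch.zeros(4, 3)))

    # The best checkpoint is saved to out_dir/<basename of runner.work_dir>
    # and the previous best is removed via the inferred FileClient.
    eval_hook = EvalHook(
        val_loader,
        interval=1,
        save_best='auto',
        out_dir='s3://bucket/eval',
        file_client_args=dict(backend='petrel'))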
 # Copyright (c) OpenMMLab. All rights reserved.
 import datetime
+import os
 import os.path as osp
 from collections import OrderedDict

@@ -7,6 +8,8 @@ import torch
 import torch.distributed as dist

 import mmcv
+from mmcv.fileio.file_client import FileClient
+from mmcv.utils import is_tuple_of, scandir
 from ..hook import HOOKS
 from .base import LoggerHook
@@ -19,14 +22,34 @@ class TextLoggerHook(LoggerHook):
     saved in json file.

     Args:
-        by_epoch (bool): Whether EpochBasedRunner is used.
-        interval (int): Logging interval (every k iterations).
-        ignore_last (bool): Ignore the log of last iterations in each epoch
-            if less than `interval`.
-        reset_flag (bool): Whether to clear the output buffer after logging.
-        interval_exp_name (int): Logging interval for experiment name. This
-            feature is to help users conveniently get the experiment
-            information from screen or log file. Default: 1000.
+        by_epoch (bool, optional): Whether EpochBasedRunner is used.
+            Default: True.
+        interval (int, optional): Logging interval (every k iterations).
+            Default: 10.
+        ignore_last (bool, optional): Ignore the log of last iterations in each
+            epoch if less than :attr:`interval`. Default: True.
+        reset_flag (bool, optional): Whether to clear the output buffer after
+            logging. Default: False.
+        interval_exp_name (int, optional): Logging interval for experiment
+            name. This feature is to help users conveniently get the experiment
+            information from screen or log file. Default: 1000.
+        out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
+            If ``out_dir`` is specified, logs will be copied to a new directory
+            which is the concatenation of ``out_dir`` and the last level
+            directory of ``runner.work_dir``. Default: None.
+            `New in version 1.3.16.`
+        out_suffix (str or tuple[str], optional): Those filenames ending with
+            ``out_suffix`` will be copied to ``out_dir``.
+            Default: ('.log.json', '.log', '.py').
+            `New in version 1.3.16.`
+        keep_local (bool, optional): Whether to keep local log when
+            :attr:`out_dir` is specified. If False, the local log will be
+            removed. Default: True.
+            `New in version 1.3.16.`
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
     """

     def __init__(self,
@@ -34,15 +57,49 @@ class TextLoggerHook(LoggerHook):
                  interval=10,
                  ignore_last=True,
                  reset_flag=False,
-                 interval_exp_name=1000):
+                 interval_exp_name=1000,
+                 out_dir=None,
+                 out_suffix=('.log.json', '.log', '.py'),
+                 keep_local=True,
+                 file_client_args=None):
         super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
                                              by_epoch)
         self.by_epoch = by_epoch
         self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
+        if out_dir is None and file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not'
+                'specified.')
+        self.out_dir = out_dir
+
+        if not (out_dir is None or isinstance(out_dir, str)
+                or is_tuple_of(out_dir, str)):
+            raise TypeError('out_dir should be "None" or string or tuple of '
+                            'string, but got {out_dir}')
+        self.out_suffix = out_suffix
+
+        self.keep_local = keep_local
+        self.file_client_args = file_client_args
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)

     def before_run(self, runner):
         super(TextLoggerHook, self).before_run(runner)
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(self.file_client_args,
+                                                       self.out_dir)
+            # The final `self.out_dir` is the concatenation of `self.out_dir`
+            # and the last level directory of `runner.work_dir`
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'Text logs will be saved to {self.out_dir} by '
+                 f'{self.file_client.name} after the training process.'))
+
         self.start_iter = runner.iter
         self.json_log_path = osp.join(runner.work_dir,
                                       f'{runner.timestamp}.log.json')
@@ -177,3 +234,23 @@ class TextLoggerHook(LoggerHook):
             self._log_info(log_dict, runner)
         self._dump_log(log_dict, runner)
         return log_dict
+
+    def after_run(self, runner):
+        # copy or upload logs to self.out_dir
+        if self.out_dir is not None:
+            for filename in scandir(runner.work_dir, self.out_suffix, True):
+                local_filepath = osp.join(runner.work_dir, filename)
+                out_filepath = self.file_client.join_path(
+                    self.out_dir, filename)
+                with open(local_filepath, 'r') as f:
+                    self.file_client.put_text(f.read(), out_filepath)
+
+                runner.logger.info(
+                    (f'The file {local_filepath} has been uploaded to '
+                     f'{out_filepath}.'))
+
+                if not self.keep_local:
+                    os.remove(local_filepath)
+                    runner.logger.info(
+                        (f'{local_filepath} was removed due to the '
+                         '`self.keep_local=False`'))
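And a sketch of the log upload enabled above (bucket path hypothetical):

    from mmcv.runner import TextLoggerHook

    # After training finishes, files in runner.work_dir matching out_suffix
    # are uploaded through put_text; keep_local=False removes the local copy.
    logger_hook = TextLoggerHook(
        out_dir='s3://bucket/logs',
        out_suffix=('.log.json', '.log', '.py'),
        keep_local=False,
        file_client_args=dict(backend='petrel'))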
@@ -132,6 +132,10 @@ class TestFileClient:
     def test_disk_backend(self):
         disk_backend = FileClient('disk')

+        # test `name` attribute
+        assert disk_backend.name == 'HardDiskBackend'
+        # test `allow_symlink` attribute
+        assert disk_backend.allow_symlink
+
         # test `get`
         # input path is Path object
         img_bytes = disk_backend.get(self.img_path)

@@ -157,11 +161,19 @@ class TestFileClient:
             filepath1 = Path(tmp_dir) / 'test.jpg'
             disk_backend.put(b'disk', filepath1)
             assert filepath1.open('rb').read() == b'disk'
+            # test the `mkdir_or_exist` behavior in `put`
+            _filepath1 = Path(tmp_dir) / 'not_existed_dir1' / 'test.jpg'
+            disk_backend.put(b'disk', _filepath1)
+            assert _filepath1.open('rb').read() == b'disk'

             # test `put_text`
             filepath2 = Path(tmp_dir) / 'test.txt'
             disk_backend.put_text('disk', filepath2)
             assert filepath2.open('r').read() == 'disk'
+            # test the `mkdir_or_exist` behavior in `put_text`
+            _filepath2 = Path(tmp_dir) / 'not_existed_dir2' / 'test.txt'
+            disk_backend.put_text('disk', _filepath2)
+            assert _filepath2.open('r').read() == 'disk'

             # test `isfile`
             assert disk_backend.isfile(filepath2)

@@ -179,11 +191,11 @@ class TestFileClient:
             assert str(filepath1) == path
         assert osp.isfile(filepath1)

-        # test `concat_paths`
+        # test `join_path`
         disk_dir = '/path/of/your/directory'
-        assert disk_backend.concat_paths(disk_dir, 'file') == \
+        assert disk_backend.join_path(disk_dir, 'file') == \
             osp.join(disk_dir, 'file')
-        assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \
+        assert disk_backend.join_path(disk_dir, 'dir', 'file') == \
             osp.join(disk_dir, 'dir', 'file')

         # test `list_dir_or_file`
@@ -268,6 +280,9 @@ class TestFileClient:
     def test_ceph_backend(self):
         ceph_backend = FileClient('ceph')

+        # test `allow_symlink` attribute
+        assert not ceph_backend.allow_symlink
+
         # input path is Path object
         with pytest.raises(NotImplementedError):
             ceph_backend.get_text(self.text_path)
@@ -305,6 +320,9 @@ class TestFileClient:
     def test_petrel_backend(self, backend, prefix):
         petrel_backend = FileClient(backend=backend, prefix=prefix)

+        # test `allow_symlink` attribute
+        assert not petrel_backend.allow_symlink
+
         # input path is Path object
         img_bytes = petrel_backend.get(self.img_path)
         img = mmcv.imfrombytes(img_bytes)

@@ -415,12 +433,12 @@ class TestFileClient:
         assert petrel_backend.isfile(petrel_path)
         mock_contains.assert_called_once_with(petrel_path)

-        # test `concat_paths`
-        assert petrel_backend.concat_paths(petrel_dir, 'file') == \
+        # test `join_path`
+        assert petrel_backend.join_path(petrel_dir, 'file') == \
             f'{petrel_dir}/file'
-        assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \
+        assert petrel_backend.join_path(f'{petrel_dir}/', 'file') == \
             f'{petrel_dir}/file'
-        assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \
+        assert petrel_backend.join_path(petrel_dir, 'dir', 'file') == \
             f'{petrel_dir}/dir/file'

         # test `get_local_path`
@@ -528,6 +546,9 @@ class TestFileClient:
         mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
         mc_backend = FileClient('memcached', **mc_cfg)

+        # test `allow_symlink` attribute
+        assert not mc_backend.allow_symlink
+
         # input path is Path object
         with pytest.raises(NotImplementedError):
             mc_backend.get_text(self.text_path)

@@ -550,6 +571,9 @@ class TestFileClient:
         # db_path is Path object
         lmdb_backend = FileClient('lmdb', db_path=lmdb_path)

+        # test `allow_symlink` attribute
+        assert not lmdb_backend.allow_symlink
+
         with pytest.raises(NotImplementedError):
             lmdb_backend.get_text(self.text_path)
text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \ text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
'master/tests/data/filelist.txt' 'master/tests/data/filelist.txt'
# test `allow_symlink` attribute
assert not http_backend.allow_symlink
# input is path or Path object # input is path or Path object
with pytest.raises(Exception): with pytest.raises(Exception):
http_backend.get(self.img_path) http_backend.get(self.img_path)
@@ -659,17 +686,17 @@ class TestFileClient:
         # HardDiskBackend
         file_client_args = {'backend': 'disk'}
         client = FileClient.infer_client(file_client_args)
-        assert client.backend_name == 'disk'
+        assert client.name == 'HardDiskBackend'
         client = FileClient.infer_client(uri=self.img_path)
-        assert client.backend_name == 'disk'
+        assert client.name == 'HardDiskBackend'

         # PetrelBackend
         file_client_args = {'backend': 'petrel'}
         client = FileClient.infer_client(file_client_args)
-        assert client.backend_name == 'petrel'
+        assert client.name == 'PetrelBackend'
         uri = 's3://user_data'
         client = FileClient.infer_client(uri=uri)
-        assert client.backend_name == 'petrel'
+        assert client.name == 'PetrelBackend'

     def test_register_backend(self):
......
 import sys
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import pytest
 import torch
 import torch.nn as nn
+import torch.optim as optim
 from torch.nn.parallel import DataParallel

+from mmcv.fileio.file_client import PetrelBackend
 from mmcv.parallel.registry import MODULE_WRAPPERS
 from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
                                     get_state_dict, load_checkpoint,
-                                    load_from_pavi)
+                                    load_from_pavi, save_checkpoint)
+
+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()

 @MODULE_WRAPPERS.register_module()
@@ -392,3 +397,36 @@ def test_checkpoint_loader():
     filename = 'a/b/c/d'
     loader = CheckpointLoader._get_checkpoint_loader(filename)
     assert loader.__name__ == 'load_from_abc'
+
+
+def test_save_checkpoint(tmp_path):
+    model = Model()
+    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+    # meta is not a dict
+    with pytest.raises(TypeError):
+        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')
+
+    # 1. save to disk
+    filename = str(tmp_path / 'checkpoint1.pth')
+    save_checkpoint(model, filename)
+
+    filename = str(tmp_path / 'checkpoint2.pth')
+    save_checkpoint(model, filename, optimizer)
+
+    filename = str(tmp_path / 'checkpoint3.pth')
+    save_checkpoint(model, filename, meta={'test': 'test'})
+
+    filename = str(tmp_path / 'checkpoint4.pth')
+    save_checkpoint(model, filename, file_client_args={'backend': 'disk'})
+
+    # 2. save to petrel oss
+    with patch.object(PetrelBackend, 'put') as mock_method:
+        filename = 's3://path/of/your/checkpoint1.pth'
+        save_checkpoint(model, filename)
+    mock_method.assert_called()
+
+    with patch.object(PetrelBackend, 'put') as mock_method:
+        filename = 's3://path//of/your/checkpoint2.pth'
+        save_checkpoint(
+            model, filename, file_client_args={'backend': 'petrel'})
+    mock_method.assert_called()
 import json
 import os.path as osp
+import sys
 import tempfile
 import unittest.mock as mock
 from collections import OrderedDict

@@ -11,12 +12,16 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader, Dataset

+from mmcv.fileio.file_client import PetrelBackend
 from mmcv.runner import DistEvalHook as BaseDistEvalHook
 from mmcv.runner import EpochBasedRunner
 from mmcv.runner import EvalHook as BaseEvalHook
 from mmcv.runner import IterBasedRunner
 from mmcv.utils import get_logger, scandir

+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()

 class ExampleDataset(Dataset):
@@ -298,6 +303,34 @@ def test_eval_hook():
         assert osp.exists(ckpt_path)
         assert runner.meta['hook_msgs']['best_score'] == -3

+    # test EvalHook with specified `out_dir`
+    loader = DataLoader(EvalDataset())
+    model = Model()
+    data_loader = DataLoader(EvalDataset())
+    out_dir = 's3://user/data'
+    eval_hook = EvalHook(
+        data_loader, interval=1, save_best='auto', out_dir=out_dir)
+
+    with patch.object(PetrelBackend, 'put') as mock_put, \
+            patch.object(PetrelBackend, 'remove') as mock_remove, \
+            patch.object(PetrelBackend, 'isfile') as mock_isfile, \
+            tempfile.TemporaryDirectory() as tmpdir:
+        logger = get_logger('test_eval')
+        runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
+        runner.register_checkpoint_hook(dict(interval=1))
+        runner.register_hook(eval_hook)
+        runner.run([loader], [('train', 1)], 8)
+
+        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+        ckpt_path = f'{out_dir}/{basename}/best_acc_epoch_4.pth'
+
+        assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
+        assert runner.meta['hook_msgs']['best_score'] == 7
+        assert mock_put.call_count == 3
+        assert mock_remove.call_count == 2
+        assert mock_isfile.call_count == 2
+

 @patch('mmcv.engine.single_gpu_test', MagicMock)
 @patch('mmcv.engine.multi_gpu_test', MagicMock)
......
@@ -12,7 +12,7 @@ import re
 import shutil
 import sys
 import tempfile
-from unittest.mock import MagicMock, call
+from unittest.mock import MagicMock, call, patch

 import pytest
 import torch

@@ -20,6 +20,7 @@ import torch.nn as nn
 from torch.nn.init import constant_
 from torch.utils.data import DataLoader

+from mmcv.fileio.file_client import PetrelBackend
 from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
                          Fp16OptimizerHook,
                          GradientCumulativeFp16OptimizerHook,

@@ -34,8 +35,11 @@ from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
                                           OneCycleLrUpdaterHook,
                                           StepLrUpdaterHook)

+sys.modules['petrel_client'] = MagicMock()
+sys.modules['petrel_client.client'] = MagicMock()
-def test_checkpoint_hook():
+def test_checkpoint_hook(tmp_path):
     """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""

     # test epoch based runner

@@ -49,6 +53,25 @@ def test_checkpoint_hook(tmp_path):
         runner.work_dir, 'epoch_1.pth')
     shutil.rmtree(runner.work_dir)

+    # test petrel oss when the type of runner is `EpochBasedRunner`
+    runner = _build_demo_runner('EpochBasedRunner', max_epochs=4)
+    runner.meta = dict()
+    out_dir = 's3://user/data'
+    with patch.object(PetrelBackend, 'put') as mock_put, \
+            patch.object(PetrelBackend, 'remove') as mock_remove, \
+            patch.object(PetrelBackend, 'isfile') as mock_isfile:
+        checkpointhook = CheckpointHook(
+            interval=1, out_dir=out_dir, by_epoch=True, max_keep_ckpts=2)
+        runner.register_hook(checkpointhook)
+        runner.run([loader], [('train', 1)])
+        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+        assert runner.meta['hook_msgs']['last_ckpt'] == \
+            '/'.join([out_dir, basename, 'epoch_4.pth'])
+    mock_put.assert_called()
+    mock_remove.assert_called()
+    mock_isfile.assert_called()
+    shutil.rmtree(runner.work_dir)
+
     # test iter based runner
     runner = _build_demo_runner(
         'IterBasedRunner', max_iters=1, max_epochs=None)
@@ -60,6 +83,26 @@ def test_checkpoint_hook(tmp_path):
         runner.work_dir, 'iter_1.pth')
     shutil.rmtree(runner.work_dir)

+    # test petrel oss when the type of runner is `IterBasedRunner`
+    runner = _build_demo_runner(
+        'IterBasedRunner', max_iters=4, max_epochs=None)
+    runner.meta = dict()
+    out_dir = 's3://user/data'
+    with patch.object(PetrelBackend, 'put') as mock_put, \
+            patch.object(PetrelBackend, 'remove') as mock_remove, \
+            patch.object(PetrelBackend, 'isfile') as mock_isfile:
+        checkpointhook = CheckpointHook(
+            interval=1, out_dir=out_dir, by_epoch=False, max_keep_ckpts=2)
+        runner.register_hook(checkpointhook)
+        runner.run([loader], [('train', 1)])
+        basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+        assert runner.meta['hook_msgs']['last_ckpt'] == \
+            '/'.join([out_dir, basename, 'iter_4.pth'])
+    mock_put.assert_called()
+    mock_remove.assert_called()
+    mock_isfile.assert_called()
+    shutil.rmtree(runner.work_dir)
+

 def test_ema_hook():
     """xdoctest -m tests/test_hooks.py test_ema_hook."""
......