Unverified Commit 32e09f49 authored by Zaida Zhou, committed by GitHub

[Feature] Upload checkpoints and logs to ceph (#1375)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittest

* rename attribute str_like_obj to is_str_like_obj

* [Docs] Upload checkpoint to petrel oss

* add infer_client method

* Support uploading checkpoint to petrel oss

* add check_exist method

* refactor CheckpointHook

* support uploading logs to ceph

* rename var client to file_client

* polish docstring

* enhance load_from_ceph

* refactor load_from_ceph

* refactor TextLoggerHook

* change the meaning of out_dir argument

* fix test_checkpoint_hook.py

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* add a unittest for EvalHook when file backend is petrel

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* add warning in load_from_ceph

* fix type of out_suffix

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* fix CI

* mock check_exist

* improve docstring

* improve docstring

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* rename check_exists to exists

* refactor code and polish docstring

* fix windows ci

* add comment and polish docstring

* polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* fix typo

* fix error when mocking PetrelBackend

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* rename keep_log to keep_local

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* refactor

* polish information

* format information
parent ef022196
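The net effect of this patch: ``FileClient`` can infer its storage backend from the prefix of a file path, so checkpoints and logs can be written to petrel oss (or ceph) as transparently as to disk. A minimal sketch of the prefix-based inference, assuming the petrel SDK (``petrel_client``) is installed and configured; the bucket path is a placeholder:

    from mmcv.fileio import FileClient

    # 's3://' is registered as the prefix of PetrelBackend, so the client
    # can be inferred from the path alone
    file_client = FileClient.infer_client(uri='s3://your_bucket/ckpt.pth')
    assert file_client.name == 'PetrelBackend'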
......@@ -11,6 +11,7 @@ from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from urllib.request import urlopen
import mmcv
from mmcv.utils.misc import has_method
from mmcv.utils.path import is_filepath
......@@ -23,6 +24,17 @@ class BaseStorageBackend(metaclass=ABCMeta):
as texts.
"""
# a flag to indicate whether the backend can create a symlink for a file
_allow_symlink = False
@property
def name(self):
return self.__class__.__name__
@property
def allow_symlink(self):
return self._allow_symlink
@abstractmethod
def get(self, filepath):
pass
......@@ -41,8 +53,8 @@ class CephBackend(BaseStorageBackend):
will be replaced by ``dst``. Default: None.
.. warning::
:class:`CephBackend` will be deprecated, please use
:class:`PetrelBackend` instead
:class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
"""
def __init__(self, path_mapping=None):
......@@ -266,7 +278,7 @@ class PetrelBackend(BaseStorageBackend):
filepath = self._format_path(filepath)
return self._client.contains(filepath)
def concat_paths(self, filepath: Union[str, Path],
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
......@@ -377,7 +389,7 @@ class PetrelBackend(BaseStorageBackend):
# is a directory, because `self.isdir` relies on
# `self._client.list`
if path.endswith('/'): # a directory path
next_dir_path = self.concat_paths(dir_path, path)
next_dir_path = self.join_path(dir_path, path)
if list_dir:
# get the relative path and exclude the last
# character '/'
......@@ -388,7 +400,7 @@ class PetrelBackend(BaseStorageBackend):
list_file, suffix,
recursive)
else: # a file path
absolute_path = self.concat_paths(dir_path, path)
absolute_path = self.join_path(dir_path, path)
rel_path = absolute_path[len(root):]
if (suffix is None
or rel_path.endswith(suffix)) and list_file:
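The recursion above yields paths relative to ``dir_path``. A usage sketch of the new ``list_dir_or_file`` (placeholder bucket path; a configured petrel client is assumed):

    file_client = FileClient(backend='petrel')
    # iterate over all '.pth' files under the directory, recursively
    for rel_path in file_client.list_dir_or_file(
            's3://your_bucket/dir', list_dir=False, suffix='.pth',
            recursive=True):
        print(rel_path)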
......@@ -491,6 +503,8 @@ class LmdbBackend(BaseStorageBackend):
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
_allow_symlink = True
def get(self, filepath: Union[str, Path]) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
......@@ -524,10 +538,15 @@ class HardDiskBackend(BaseStorageBackend):
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``put`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
mmcv.mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'wb') as f:
f.write(obj)
......@@ -537,12 +556,17 @@ class HardDiskBackend(BaseStorageBackend):
encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``put_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
mmcv.mkdir_or_exist(osp.dirname(filepath))
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)
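A quick illustration of the new auto-mkdir behaviour of ``put`` and ``put_text`` (paths are placeholders):

    disk_backend = FileClient('disk')
    # the intermediate directory 'not_existed_dir' is created automatically
    disk_backend.put(b'disk', '/tmp/not_existed_dir/test.jpg')
    disk_backend.put_text('disk', '/tmp/not_existed_dir/test.txt')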
......@@ -579,7 +603,7 @@ class HardDiskBackend(BaseStorageBackend):
return osp.isdir(filepath)
def isfile(self, filepath: Union[str, Path]) -> bool:
"""Check a ``filepath`` whether it is a file.
"""Check whether a file path is a file.
Args:
filepath (str or Path): Path to be checked whether it is a file.
......@@ -590,7 +614,7 @@ class HardDiskBackend(BaseStorageBackend):
"""
return osp.isfile(filepath)
def concat_paths(self, filepath: Union[str, Path],
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
......@@ -714,7 +738,7 @@ class FileClient:
Note that it can also register other backend accessors with a given name,
prefixes, and backend class. In addition, we use the singleton pattern to
avoid repeated object creation. If the arguments are the same, the same
object is returned.
object will be returned.
Args:
backend (str, optional): The storage backend type. Options are "disk",
......@@ -788,18 +812,21 @@ class FileClient:
_instance = super().__new__(cls)
if backend is not None:
_instance.client = cls._backends[backend](**kwargs)
_instance.backend_name = backend
else:
_instance.client = cls._prefix_to_backends[prefix](**kwargs)
# infer the backend name according to the prefix
for backend_name, backend_cls in cls._backends.items():
if isinstance(_instance.client, backend_cls):
_instance.backend_name = backend_name
break
cls._instances[arg_key] = _instance
return _instance
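# The caching above gives singleton semantics, e.g. (a sketch, not part
# of this patch):
#   client1 = FileClient(backend='disk')
#   client2 = FileClient(backend='disk')
#   assert client1 is client2  # same arguments -> same cached instance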
@property
def name(self):
return self.client.name
@property
def allow_symlink(self):
return self.client.allow_symlink
@staticmethod
def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
"""Parse the prefix of a uri.
......@@ -980,6 +1007,10 @@ class FileClient:
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``put`` should create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
......@@ -989,6 +1020,10 @@ class FileClient:
def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``put_text`` should create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
......@@ -1041,7 +1076,7 @@ class FileClient:
"""
return self.client.isfile(filepath)
def concat_paths(self, filepath: Union[str, Path],
def join_path(self, filepath: Union[str, Path],
*filepaths: Union[str, Path]) -> str:
"""Concatenate all file paths.
......@@ -1054,7 +1089,7 @@ class FileClient:
Returns:
str: The result of concatenation.
"""
return self.client.concat_paths(filepath, *filepaths)
return self.client.join_path(filepath, *filepaths)
@contextmanager
def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
......
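``get_local_path`` is a context manager: for non-disk backends it is expected to download the file to a temporary local path and clean it up on exit. A sketch (placeholder path; a configured petrel environment is assumed):

    import torch

    file_client = FileClient(backend='petrel')
    with file_client.get_local_path('s3://your_bucket/ckpt.pth') as path:
        checkpoint = torch.load(path)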
......@@ -7,7 +7,6 @@ class BaseFileHandler(metaclass=ABCMeta):
# str-like object or bytes-like object. Pickle only processes bytes-like
# objects but json only processes str-like objects. If it is a str-like
# object, `StringIO` will be used to process the buffer.
str_like = True
@abstractmethod
......
......@@ -323,7 +323,7 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='ceph'):
def load_from_ceph(filename, map_location=None, backend='petrel'):
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
......@@ -331,19 +331,34 @@ def load_from_ceph(filename, map_location=None, backend='ceph'):
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str): The storage backend type. Options are "disk", "ceph",
"memcached" and "lmdb". Default: 'ceph'
backend (str, optional): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.
.. warning::
:class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
allowed_backends = ['ceph']
allowed_backends = ['ceph', 'petrel']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')
fileclient = FileClient(backend=backend)
buffer = io.BytesIO(fileclient.get(filename))
if backend == 'ceph':
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead')
# CephBackend and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend cannot be instantiated
# successfully, the CephBackend will be chosen.
try:
file_client = FileClient(backend=backend)
except ImportError:
allowed_backends.remove(backend)
file_client = FileClient(backend=allowed_backends[0])
with io.BytesIO(file_client.get(filename)) as buffer:
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
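With the registered 's3://' scheme, such checkpoints also load through the generic entry point, e.g. (placeholder path; ``model`` is any ``nn.Module``):

    from mmcv.runner import load_checkpoint

    load_checkpoint(model, 's3://your_bucket/ckpt.pth', map_location='cpu')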
......@@ -506,7 +521,6 @@ def load_checkpoint(model,
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
......@@ -616,7 +630,11 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
return destination
def save_checkpoint(model, filename, optimizer=None, meta=None):
def save_checkpoint(model,
filename,
optimizer=None,
meta=None,
file_client_args=None):
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
......@@ -627,6 +645,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
"""
if meta is None:
meta = {}
......@@ -654,6 +676,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
checkpoint['optimizer'][name] = optim.state_dict()
if filename.startswith('pavi://'):
if file_client_args is not None:
raise ValueError(
'file_client_args should be "None" if filename starts with '
f'"pavi://", but got {file_client_args}')
try:
from pavi import modelcloud
from pavi import exception
......@@ -674,8 +700,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
mmcv.mkdir_or_exist(osp.dirname(filename))
# immediately flush buffer
with open(filename, 'wb') as f:
file_client = FileClient.infer_client(file_client_args, filename)
with io.BytesIO() as f:
torch.save(checkpoint, f)
f.flush()
file_client.put(f.getvalue(), filename)
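A usage sketch of the extended ``save_checkpoint`` (placeholder paths):

    from mmcv.runner import save_checkpoint

    # backend inferred from the 's3://' prefix
    save_checkpoint(model, 's3://your_bucket/ckpt.pth')
    # or specified explicitly via file_client_args
    save_checkpoint(
        model, 's3://your_bucket/ckpt.pth',
        file_client_args=dict(backend='petrel'))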
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import warnings
from mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook
......@@ -18,16 +20,32 @@ class CheckpointHook(Hook):
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``.
`Changed in version 1.3.16.`
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be saved
regardless of interval.
sync_buffer (bool): Whether to synchronize buffers in different
gpus. Default: False.
save_last (bool, optional): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool, optional): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
.. warning::
Before v1.3.16, the ``out_dir`` argument indicates the path where the
checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
root directory and the final path to save checkpoint is the
concatenation of ``out_dir`` and the last level directory of
``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
and the value of ``runner.work_dir`` is "/path/of/B", then the final
path will be "/path/of/A/B".
"""
def __init__(self,
......@@ -38,6 +56,7 @@ class CheckpointHook(Hook):
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
file_client_args=None,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
......@@ -47,11 +66,39 @@ class CheckpointHook(Hook):
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
# disable the create_symlink option because some file backends do not
# allow creating a symlink
if 'create_symlink' in self.args:
if self.args[
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
('create_symlink is set as True by the user but is changed '
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}'))
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner):
if not self.by_epoch:
return
......@@ -81,7 +128,7 @@ class CheckpointHook(Hook):
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = os.path.join(
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename)
# remove other checkpoints
if self.max_keep_ckpts > 0:
......@@ -96,10 +143,10 @@ class CheckpointHook(Hook):
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = os.path.join(self.out_dir,
filename_tmpl.format(_step))
if os.path.exists(ckpt_path):
os.remove(ckpt_path)
ckpt_path = self.file_client.join_path(
self.out_dir, filename_tmpl.format(_step))
if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path)
else:
break
......
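A configuration sketch for uploading checkpoints with the reworked hook (placeholder bucket; the final directory becomes ``out_dir`` plus the basename of ``runner.work_dir``):

    checkpoint_hook = CheckpointHook(
        interval=1,
        max_keep_ckpts=2,
        out_dir='s3://your_bucket/checkpoints',
        file_client_args=dict(backend='petrel'))
    runner.register_hook(checkpoint_hook)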
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import warnings
from math import inf
......@@ -8,6 +7,7 @@ import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
from mmcv.fileio import FileClient
from mmcv.utils import is_seq_of
from .hook import Hook
from .logger import LoggerHook
......@@ -54,6 +54,14 @@ class EvalHook(Hook):
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
`New in version 1.3.16.`
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
`New in version 1.3.16.`
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
......@@ -84,6 +92,8 @@ class EvalHook(Hook):
test_fn=None,
greater_keys=None,
less_keys=None,
out_dir=None,
file_client_args=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
......@@ -137,6 +147,9 @@ class EvalHook(Hook):
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)
self.out_dir = out_dir
self.file_client_args = file_client_args
def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.
......@@ -187,6 +200,23 @@ class EvalHook(Hook):
self.compare_func = self.rule_map[self.rule]
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}'))
if self.save_best is not None:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating an empty one.')
......@@ -299,15 +329,20 @@ class EvalHook(Hook):
best_score = key_score
runner.meta['hook_msgs']['best_score'] = best_score
if self.best_ckpt_path and osp.isfile(self.best_ckpt_path):
os.remove(self.best_ckpt_path)
if self.best_ckpt_path and self.file_client.isfile(
self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path)
runner.logger.info(
(f'The previous best checkpoint {self.best_ckpt_path} was '
'removed'))
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name)
self.best_ckpt_path = self.file_client.join_path(
self.out_dir, best_ckpt_name)
runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
runner.save_checkpoint(
runner.work_dir, best_ckpt_name, create_symlink=False)
self.out_dir, best_ckpt_name, create_symlink=False)
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(
......@@ -378,6 +413,12 @@ class DistEvalHook(EvalHook):
broadcast_bn_buffer (bool): Whether to broadcast the
buffer(running_mean and running_var) of rank 0 to other rank
before evaluation. Default: True.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
......@@ -395,6 +436,8 @@ class DistEvalHook(EvalHook):
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
out_dir=None,
file_client_args=None,
**eval_kwargs):
if test_fn is None:
......@@ -411,6 +454,8 @@ class DistEvalHook(EvalHook):
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
out_dir=out_dir,
file_client_args=file_client_args,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
......
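A sketch of ``EvalHook`` saving the best checkpoint to petrel oss, mirroring the unit test below (placeholder bucket):

    eval_hook = EvalHook(
        data_loader,
        interval=1,
        save_best='auto',
        out_dir='s3://your_bucket/eval',
        file_client_args=dict(backend='petrel'))
    runner.register_hook(eval_hook)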
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os
import os.path as osp
from collections import OrderedDict
......@@ -7,6 +8,8 @@ import torch
import torch.distributed as dist
import mmcv
from mmcv.fileio.file_client import FileClient
from mmcv.utils import is_tuple_of, scandir
from ..hook import HOOKS
from .base import LoggerHook
......@@ -19,14 +22,34 @@ class TextLoggerHook(LoggerHook):
saved in json file.
Args:
by_epoch (bool): Whether EpochBasedRunner is used.
interval (int): Logging interval (every k iterations).
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
reset_flag (bool): Whether to clear the output buffer after logging.
interval_exp_name (int): Logging interval for experiment name. This
feature is to help users conveniently get the experiment
by_epoch (bool, optional): Whether EpochBasedRunner is used.
Default: True.
interval (int, optional): Logging interval (every k iterations).
Default: 10.
ignore_last (bool, optional): Ignore the log of last iterations in each
epoch if less than :attr:`interval`. Default: True.
reset_flag (bool, optional): Whether to clear the output buffer after
logging. Default: False.
interval_exp_name (int, optional): Logging interval for experiment
name. This feature is to help users conveniently get the experiment
information from screen or log file. Default: 1000.
out_dir (str, optional): Logs are saved in ``runner.work_dir`` by default.
If ``out_dir`` is specified, logs will be copied to a new directory
which is the concatenation of ``out_dir`` and the last level
directory of ``runner.work_dir``. Default: None.
`New in version 1.3.16.`
out_suffix (str or tuple[str], optional): Those filenames ending with
``out_suffix`` will be copied to ``out_dir``.
Default: ('.log.json', '.log', '.py').
`New in version 1.3.16.`
keep_local (bool, optional): Whether to keep local log when
:attr:`out_dir` is specified. If False, the local log will be
removed. Default: True.
`New in version 1.3.16.`
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
"""
def __init__(self,
......@@ -34,15 +57,49 @@ class TextLoggerHook(LoggerHook):
interval=10,
ignore_last=True,
reset_flag=False,
interval_exp_name=1000):
interval_exp_name=1000,
out_dir=None,
out_suffix=('.log.json', '.log', '.py'),
keep_local=True,
file_client_args=None):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
if out_dir is None and file_client_args is not None:
raise ValueError(
'file_client_args should be "None" when `out_dir` is not '
'specified.')
self.out_dir = out_dir
if not (out_dir is None or isinstance(out_dir, str)
or is_tuple_of(out_dir, str)):
raise TypeError('out_dir should be "None" or string or tuple of '
                f'string, but got {out_dir}')
self.out_suffix = out_suffix
self.keep_local = keep_local
self.file_client_args = file_client_args
if self.out_dir is not None:
self.file_client = FileClient.infer_client(file_client_args,
self.out_dir)
def before_run(self, runner):
super(TextLoggerHook, self).before_run(runner)
if self.out_dir is not None:
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# The final `self.out_dir` is the concatenation of `self.out_dir`
# and the last level directory of `runner.work_dir`
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'Text logs will be saved to {self.out_dir} by '
f'{self.file_client.name} after the training process.'))
self.start_iter = runner.iter
self.json_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
......@@ -177,3 +234,23 @@ class TextLoggerHook(LoggerHook):
self._log_info(log_dict, runner)
self._dump_log(log_dict, runner)
return log_dict
def after_run(self, runner):
# copy or upload logs to self.out_dir
if self.out_dir is not None:
for filename in scandir(runner.work_dir, self.out_suffix, True):
local_filepath = osp.join(runner.work_dir, filename)
out_filepath = self.file_client.join_path(
self.out_dir, filename)
with open(local_filepath, 'r') as f:
self.file_client.put_text(f.read(), out_filepath)
runner.logger.info(
(f'The file {local_filepath} has been uploaded to '
f'{out_filepath}.'))
if not self.keep_local:
os.remove(local_filepath)
runner.logger.info(
(f'{local_filepath} was removed due to '
'`self.keep_local=False`'))
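A sketch of the corresponding logger setup (placeholder bucket):

    text_logger_hook = TextLoggerHook(
        out_dir='s3://your_bucket/logs',
        keep_local=False,  # remove local copies after uploading
        file_client_args=dict(backend='petrel'))
    runner.register_hook(text_logger_hook)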
......@@ -132,6 +132,10 @@ class TestFileClient:
def test_disk_backend(self):
disk_backend = FileClient('disk')
# test `name` attribute
assert disk_backend.name == 'HardDiskBackend'
# test `allow_symlink` attribute
assert disk_backend.allow_symlink
# test `get`
# input path is Path object
img_bytes = disk_backend.get(self.img_path)
......@@ -157,11 +161,19 @@ class TestFileClient:
filepath1 = Path(tmp_dir) / 'test.jpg'
disk_backend.put(b'disk', filepath1)
assert filepath1.open('rb').read() == b'disk'
# test the `mkdir_or_exist` behavior in `put`
_filepath1 = Path(tmp_dir) / 'not_existed_dir1' / 'test.jpg'
disk_backend.put(b'disk', _filepath1)
assert _filepath1.open('rb').read() == b'disk'
# test `put_text`
filepath2 = Path(tmp_dir) / 'test.txt'
disk_backend.put_text('disk', filepath2)
assert filepath2.open('r').read() == 'disk'
# test the `mkdir_or_exist` behavior in `put_text`
_filepath2 = Path(tmp_dir) / 'not_existed_dir2' / 'test.txt'
disk_backend.put_text('disk', _filepath2)
assert _filepath2.open('r').read() == 'disk'
# test `isfile`
assert disk_backend.isfile(filepath2)
......@@ -179,11 +191,11 @@ class TestFileClient:
assert str(filepath1) == path
assert osp.isfile(filepath1)
# test `concat_paths`
# test `join_path`
disk_dir = '/path/of/your/directory'
assert disk_backend.concat_paths(disk_dir, 'file') == \
assert disk_backend.join_path(disk_dir, 'file') == \
osp.join(disk_dir, 'file')
assert disk_backend.concat_paths(disk_dir, 'dir', 'file') == \
assert disk_backend.join_path(disk_dir, 'dir', 'file') == \
osp.join(disk_dir, 'dir', 'file')
# test `list_dir_or_file`
......@@ -268,6 +280,9 @@ class TestFileClient:
def test_ceph_backend(self):
ceph_backend = FileClient('ceph')
# test `allow_symlink` attribute
assert not ceph_backend.allow_symlink
# input path is Path object
with pytest.raises(NotImplementedError):
ceph_backend.get_text(self.text_path)
......@@ -305,6 +320,9 @@ class TestFileClient:
def test_petrel_backend(self, backend, prefix):
petrel_backend = FileClient(backend=backend, prefix=prefix)
# test `allow_symlink` attribute
assert not petrel_backend.allow_symlink
# input path is Path object
img_bytes = petrel_backend.get(self.img_path)
img = mmcv.imfrombytes(img_bytes)
......@@ -415,12 +433,12 @@ class TestFileClient:
assert petrel_backend.isfile(petrel_path)
mock_contains.assert_called_once_with(petrel_path)
# test `concat_paths`
assert petrel_backend.concat_paths(petrel_dir, 'file') == \
# test `join_path`
assert petrel_backend.join_path(petrel_dir, 'file') == \
f'{petrel_dir}/file'
assert petrel_backend.concat_paths(f'{petrel_dir}/', 'file') == \
assert petrel_backend.join_path(f'{petrel_dir}/', 'file') == \
f'{petrel_dir}/file'
assert petrel_backend.concat_paths(petrel_dir, 'dir', 'file') == \
assert petrel_backend.join_path(petrel_dir, 'dir', 'file') == \
f'{petrel_dir}/dir/file'
# test `get_local_path`
......@@ -528,6 +546,9 @@ class TestFileClient:
mc_cfg = dict(server_list_cfg='', client_cfg='', sys_path=None)
mc_backend = FileClient('memcached', **mc_cfg)
# test `allow_symlink` attribute
assert not mc_backend.allow_symlink
# input path is Path object
with pytest.raises(NotImplementedError):
mc_backend.get_text(self.text_path)
......@@ -550,6 +571,9 @@ class TestFileClient:
# db_path is Path object
lmdb_backend = FileClient('lmdb', db_path=lmdb_path)
# test `allow_symlink` attribute
assert not lmdb_backend.allow_symlink
with pytest.raises(NotImplementedError):
lmdb_backend.get_text(self.text_path)
......@@ -574,6 +598,9 @@ class TestFileClient:
text_url = 'https://raw.githubusercontent.com/open-mmlab/mmcv/' \
'master/tests/data/filelist.txt'
# test `allow_symlink` attribute
assert not http_backend.allow_symlink
# input is path or Path object
with pytest.raises(Exception):
http_backend.get(self.img_path)
......@@ -659,17 +686,17 @@ class TestFileClient:
# HardDiskBackend
file_client_args = {'backend': 'disk'}
client = FileClient.infer_client(file_client_args)
assert client.backend_name == 'disk'
assert client.name == 'HardDiskBackend'
client = FileClient.infer_client(uri=self.img_path)
assert client.backend_name == 'disk'
assert client.name == 'HardDiskBackend'
# PetrelBackend
file_client_args = {'backend': 'petrel'}
client = FileClient.infer_client(file_client_args)
assert client.backend_name == 'petrel'
assert client.name == 'PetrelBackend'
uri = 's3://user_data'
client = FileClient.infer_client(uri=uri)
assert client.backend_name == 'petrel'
assert client.name == 'PetrelBackend'
def test_register_backend(self):
......
import sys
from collections import OrderedDict
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DataParallel
from mmcv.fileio.file_client import PetrelBackend
from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
get_state_dict, load_checkpoint,
load_from_pavi)
load_from_pavi, save_checkpoint)
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
@MODULE_WRAPPERS.register_module()
......@@ -392,3 +397,36 @@ def test_checkpoint_loader():
filename = 'a/b/c/d'
loader = CheckpointLoader._get_checkpoint_loader(filename)
assert loader.__name__ == 'load_from_abc'
def test_save_checkpoint(tmp_path):
model = Model()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# meta is not a dict
with pytest.raises(TypeError):
save_checkpoint(model, '/path/of/your/filename', meta='invalid type')
# 1. save to disk
filename = str(tmp_path / 'checkpoint1.pth')
save_checkpoint(model, filename)
filename = str(tmp_path / 'checkpoint2.pth')
save_checkpoint(model, filename, optimizer)
filename = str(tmp_path / 'checkpoint3.pth')
save_checkpoint(model, filename, meta={'test': 'test'})
filename = str(tmp_path / 'checkpoint4.pth')
save_checkpoint(model, filename, file_client_args={'backend': 'disk'})
# 2. save to petrel oss
with patch.object(PetrelBackend, 'put') as mock_method:
filename = 's3://path/of/your/checkpoint1.pth'
save_checkpoint(model, filename)
mock_method.assert_called()
with patch.object(PetrelBackend, 'put') as mock_method:
filename = 's3://path//of/your/checkpoint2.pth'
save_checkpoint(
model, filename, file_client_args={'backend': 'petrel'})
mock_method.assert_called()
import json
import os.path as osp
import sys
import tempfile
import unittest.mock as mock
from collections import OrderedDict
......@@ -11,12 +12,16 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from mmcv.fileio.file_client import PetrelBackend
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EpochBasedRunner
from mmcv.runner import EvalHook as BaseEvalHook
from mmcv.runner import IterBasedRunner
from mmcv.utils import get_logger, scandir
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
class ExampleDataset(Dataset):
......@@ -298,6 +303,34 @@ def test_eval_hook():
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3
# test EvalHook with specified `out_dir`
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())
out_dir = 's3://user/data'
eval_hook = EvalHook(
data_loader, interval=1, save_best='auto', out_dir=out_dir)
with patch.object(PetrelBackend, 'put') as mock_put, \
patch.object(PetrelBackend, 'remove') as mock_remove, \
patch.object(PetrelBackend, 'isfile') as mock_isfile, \
tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
ckpt_path = f'{out_dir}/{basename}/best_acc_epoch_4.pth'
assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
assert runner.meta['hook_msgs']['best_score'] == 7
assert mock_put.call_count == 3
assert mock_remove.call_count == 2
assert mock_isfile.call_count == 2
@patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock)
......
......@@ -12,7 +12,7 @@ import re
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch
import pytest
import torch
......@@ -20,6 +20,7 @@ import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader
from mmcv.fileio.file_client import PetrelBackend
from mmcv.runner import (CheckpointHook, DvcliveLoggerHook, EMAHook,
Fp16OptimizerHook,
GradientCumulativeFp16OptimizerHook,
......@@ -34,8 +35,11 @@ from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
OneCycleLrUpdaterHook,
StepLrUpdaterHook)
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
def test_checkpoint_hook():
def test_checkpoint_hook(tmp_path):
"""xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""
# test epoch based runner
......@@ -49,6 +53,25 @@ def test_checkpoint_hook():
runner.work_dir, 'epoch_1.pth')
shutil.rmtree(runner.work_dir)
# test petrel oss when the type of runner is `EpochBasedRunner`
runner = _build_demo_runner('EpochBasedRunner', max_epochs=4)
runner.meta = dict()
out_dir = 's3://user/data'
with patch.object(PetrelBackend, 'put') as mock_put, \
patch.object(PetrelBackend, 'remove') as mock_remove, \
patch.object(PetrelBackend, 'isfile') as mock_isfile:
checkpointhook = CheckpointHook(
interval=1, out_dir=out_dir, by_epoch=True, max_keep_ckpts=2)
runner.register_hook(checkpointhook)
runner.run([loader], [('train', 1)])
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
assert runner.meta['hook_msgs']['last_ckpt'] == \
'/'.join([out_dir, basename, 'epoch_4.pth'])
mock_put.assert_called()
mock_remove.assert_called()
mock_isfile.assert_called()
shutil.rmtree(runner.work_dir)
# test iter based runner
runner = _build_demo_runner(
'IterBasedRunner', max_iters=1, max_epochs=None)
......@@ -60,6 +83,26 @@ def test_checkpoint_hook():
runner.work_dir, 'iter_1.pth')
shutil.rmtree(runner.work_dir)
# test petrel oss when the type of runner is `IterBasedRunner`
runner = _build_demo_runner(
'IterBasedRunner', max_iters=4, max_epochs=None)
runner.meta = dict()
out_dir = 's3://user/data'
with patch.object(PetrelBackend, 'put') as mock_put, \
patch.object(PetrelBackend, 'remove') as mock_remove, \
patch.object(PetrelBackend, 'isfile') as mock_isfile:
checkpointhook = CheckpointHook(
interval=1, out_dir=out_dir, by_epoch=False, max_keep_ckpts=2)
runner.register_hook(checkpointhook)
runner.run([loader], [('train', 1)])
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
assert runner.meta['hook_msgs']['last_ckpt'] == \
'/'.join([out_dir, basename, 'iter_4.pth'])
mock_put.assert_called()
mock_remove.assert_called()
mock_isfile.assert_called()
shutil.rmtree(runner.work_dir)
def test_ema_hook():
"""xdoctest -m tests/test_hooks.py test_ema_hook."""
......