Unverified Commit 9185eee8 authored by Zaida Zhou, committed by GitHub

Remove runner, parallel, engine and device (#2216)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Type, Union
import numpy as np
import torch
def assert_tensor_type(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not isinstance(args[0].data, torch.Tensor):
raise AttributeError(
f'{args[0].__class__.__name__} has no attribute '
f'{func.__name__} for type {args[0].datatype}')
return func(*args, **kwargs)
return wrapper
class DataContainer:
"""A container for any type of objects.
Typically tensors will be stacked in the collate function and sliced along
some dimension in the scatter function. This behavior has some limitations.
1. All tensors have to be the same size.
2. Types are limited (numpy array or Tensor).
We design `DataContainer` and `MMDataParallel` to overcome these
limitations. The behavior can be either of the following.
- copy to GPU, pad all tensors to the same size and stack them
- copy to GPU without stacking
- leave the objects as-is and pass them to the model
- pad_dims specifies the number of last few dimensions to do padding
"""
def __init__(self,
data: Union[torch.Tensor, np.ndarray],
stack: bool = False,
padding_value: int = 0,
cpu_only: bool = False,
pad_dims: int = 2):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
self._padding_value = padding_value
assert pad_dims in [None, 1, 2, 3]
self._pad_dims = pad_dims
def __repr__(self) -> str:
return f'{self.__class__.__name__}({repr(self.data)})'
def __len__(self) -> int:
return len(self._data)
@property
def data(self) -> Union[torch.Tensor, np.ndarray]:
return self._data
@property
def datatype(self) -> Union[Type, str]:
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self) -> bool:
return self._cpu_only
@property
def stack(self) -> bool:
return self._stack
@property
def padding_value(self) -> int:
return self._padding_value
@property
def pad_dims(self) -> int:
return self._pad_dims
@assert_tensor_type
def size(self, *args, **kwargs) -> torch.Size:
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self) -> int:
return self.data.dim()
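# --- Hypothetical usage sketch (not part of the original file) ---
# A DataContainer marks how each field of a sample should be handled by the
# collate/scatter machinery: image tensors are stacked and padded over the
# last two dims, while metadata stays on CPU and is passed through as-is.
# The shapes and the metadata dict below are illustrative assumptions.
if __name__ == '__main__':
    img = DataContainer(torch.rand(3, 600, 800), stack=True, pad_dims=2)
    img_metas = DataContainer(
        dict(ori_shape=(600, 800, 3), flip=False), cpu_only=True)
    print(img.datatype, img.size())  # torch.FloatTensor torch.Size([3, 600, 800])
    print(img_metas.cpu_only, img_metas.datatype)  # True <class 'dict'>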
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain
from typing import List, Tuple
from torch.nn.parallel import DataParallel
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDataParallel(DataParallel):
"""The DataParallel module that supports DataContainer.
MMDataParallel has two main differences from PyTorch DataParallel:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data during both GPU and CPU inference.
- It implements two more APIs ``train_step()`` and ``val_step()``.
.. warning::
MMDataParallel only supports single GPU training. If you need to
train with multiple GPUs, please use MMDistributedDataParallel
instead. If you have multiple GPUs and just want to use
MMDataParallel, you can set the environment variable
``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
``device_ids=[0]``.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
device_ids (list[int]): Device IDs of modules to be scattered to.
Defaults to None when GPU is not available.
output_device (str | int): Device ID for output. Defaults to None.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim: int = 0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
def forward(self, *inputs, **kwargs):
"""Override the original forward function.
The main difference lies in CPU inference, where the data in
:class:`DataContainer` will still be gathered.
"""
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module(*inputs[0], **kwargs[0])
else:
return super().forward(*inputs, **kwargs)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.train_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training. If you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.train_step(*inputs[0], **kwargs[0])
def val_step(self, *inputs, **kwargs):
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.val_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training. If you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.val_step(*inputs[0], **kwargs[0])
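# --- Hypothetical usage sketch (not part of the original file) ---
# A toy module exposing the ``train_step()`` interface, wrapped by
# MMDataParallel. Assuming a CPU-only environment (no visible CUDA device),
# ``device_ids`` is empty and the CPU branch above is taken, so inputs are
# still scattered and any DataContainer fields would be unwrapped exactly as
# in GPU inference.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):

        def forward(self, x):
            return x * 2

        def train_step(self, x):
            return dict(loss=x.sum())

    model = MMDataParallel(ToyModel())
    print(model(torch.ones(2, 3)))          # forward() on CPU
    print(model.train_step(torch.ones(4)))  # {'loss': tensor(4.)}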
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences from PyTorch DDP:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data.
- It implements two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_id: int) -> Tuple[tuple, tuple]:
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
"""Processes inputs and runs ``self.module.forward``.
PyTorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
and deprecates using ``DistributedDataParallel.to_kwargs`` to
process inputs, which means the inputs can no longer be processed by
:meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
``MMDistributedDataParallel`` overrides this method to call
:meth:`to_kwargs` explicitly.
See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_. # noqa: E501
Returns:
Any: Forward result of :attr:`module`.
"""
module_to_run = self._replicated_tensor_module if \
self._use_replicated_tensor_module else self.module
if self.device_ids:
inputs, kwargs = self.to_kwargs( # type: ignore
inputs, kwargs, self.device_ids[0])
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore
else:
return module_to_run(*inputs, **kwargs)
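# --- Hypothetical usage sketch (not part of the original file) ---
# A single-process gloo group so the sketch can run without a launcher; in a
# real job the group is set up by torch.distributed.launch/torchrun and the
# wrapped model lives on one GPU per process. The toy module, address and
# port below are illustrative assumptions.
if __name__ == '__main__':
    import os

    import torch.distributed as dist
    import torch.nn as nn

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

        def train_step(self, x):
            return dict(loss=self.forward(x).sum())

    ddp_model = MMDistributedDataParallel(ToyModel())  # CPU model + gloo
    print(ddp_model.train_step(torch.ones(4, 2)))
    dist.destroy_process_group()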
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import ScatterInputs, scatter_kwargs
@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
def __init__(self,
module: nn.Module,
dim: int = 0,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25):
super().__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers
self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
self._sync_params()
def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
buffer_size: int) -> None:
for tensors in _take_tensors(tensors, buffer_size):
flat_tensors = _flatten_dense_tensors(tensors)
dist.broadcast(flat_tensors, 0)
for tensor, synced in zip(
tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
tensor.copy_(synced)
def _sync_params(self) -> None:
module_states = list(self.module.state_dict().values())
if len(module_states) > 0:
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
if len(buffers) > 0:
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmcv.utils import Registry
MODULE_WRAPPERS = Registry('module wrapper')
MODULE_WRAPPERS.register_module(module=DataParallel)
MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
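# --- Hypothetical usage sketch (not part of the original file) ---
# Downstream codebases can register their own wrapper class so that helpers
# such as ``is_module_wrapper`` treat it like DataParallel/DDP. The wrapper
# class below is a made-up example, not an MMCV class.
if __name__ == '__main__':
    import torch.nn as nn

    @MODULE_WRAPPERS.register_module()
    class MyModelWrapper(nn.Module):

        def __init__(self, module):
            super().__init__()
            self.module = module

    print(MODULE_WRAPPERS.get('MyModelWrapper'))  # the registered class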
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
from torch import Tensor
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
def scatter(inputs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> list:
"""Scatter inputs to target gpus.
The only difference from original :func:`scatter` is to add support for
:class:`~mmcv.parallel.DataContainer`.
"""
def scatter_map(obj):
if isinstance(obj, Tensor):
if target_gpus != [-1]:
return OrigScatter.apply(target_gpus, None, dim, obj)
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_gpus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_gpus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None # type: ignore
def scatter_kwargs(inputs: ScatterInputs,
kwargs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> Tuple[tuple, tuple]:
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
length = len(kwargs) - len(inputs)
inputs.extend([() for _ in range(length)]) # type: ignore
elif len(kwargs) < len(inputs):
length = len(inputs) - len(kwargs)
kwargs.extend([{} for _ in range(length)]) # type: ignore
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
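# --- Hypothetical usage sketch (not part of the original file) ---
# On CPU (``target_gpus=[-1]``) scatter_kwargs produces a single "chunk":
# DataContainer payloads are unwrapped, plain objects are passed through,
# and the missing kwargs slot is padded with an empty dict.
if __name__ == '__main__':
    import torch

    img = DataContainer(torch.zeros(2, 3), stack=True)
    inputs, kwargs = scatter_kwargs((img, 'some meta'), {}, target_gpus=[-1])
    print(inputs)  # ((tensor([[0., 0., 0.], [0., 0., 0.]]), 'some meta'),)
    print(kwargs)  # ({},)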
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from .registry import MODULE_WRAPPERS
def is_module_wrapper(module: nn.Module) -> bool:
"""Check if a module is a module wrapper.
The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add your own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its children registries.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: True if the input module is a module wrapper.
"""
def is_module_in_wrapper(module, module_wrapper):
module_wrappers = tuple(module_wrapper.module_dict.values())
if isinstance(module, module_wrappers):
return True
for child in module_wrapper.children.values():
if is_module_in_wrapper(module, child):
return True
return False
return is_module_in_wrapper(module, MODULE_WRAPPERS)
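# --- Hypothetical usage sketch (not part of the original file) ---
# is_module_wrapper() is typically used to unwrap a model before accessing
# custom attributes or saving a state_dict without the 'module.' prefix.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.nn.parallel import DataParallel

    model = nn.Linear(2, 2)
    wrapped = DataParallel(model)
    print(is_module_wrapper(model), is_module_wrapper(wrapped))  # False True
    unwrapped = wrapped.module if is_module_wrapper(wrapped) else wrapped
    print(unwrapped is model)  # True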
# Copyright (c) OpenMMLab. All rights reserved.
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu)
from .default_constructor import DefaultRunnerConstructor
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
EMAHook, EvalHook, Fp16OptimizerHook,
GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, Hook, IterTimerHook,
LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .hooks.lr_updater import StepLrUpdaterHook # noqa
from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook,
InvLrUpdaterHook, LinearAnnealingLrUpdaterHook,
LrUpdaterHook, OneCycleLrUpdaterHook,
PolyLrUpdaterHook)
from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook,
LinearAnnealingMomentumUpdaterHook,
MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
# initialize ipu to register the ipu runner to RUNNERS
from mmcv.device import ipu # isort:skip # noqa
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook',
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'DefaultRunnerConstructor', 'SegmindLoggerHook',
'LinearAnnealingMomentumUpdaterHook', 'LinearAnnealingLrUpdaterHook',
'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
no_type_check)
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str
class BaseRunner(metaclass=ABCMeta):
"""The base class of Runner, a training helper for PyTorch.
All subclasses should implement the following APIs:
- ``run()``
- ``train()``
- ``val()``
- ``save_checkpoint()``
Args:
model (:obj:`torch.nn.Module`): The model to be run.
batch_processor (callable): A callable method that processes a data
batch. The interface of this method should be
`batch_processor(model, data, train_mode) -> dict`
optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
optimizer (in most cases) or a dict of optimizers (in models that
require more than one optimizer, e.g., GAN).
work_dir (str, optional): The working directory to save checkpoints
and logs. Defaults to None.
logger (:obj:`logging.Logger`): Logger used during training.
Defaults to None. (The default value is just for backward
compatibility)
meta (dict | None): A dict that records some important information such as
environment info and seed, which will be logged in logger hook.
Defaults to None.
max_epochs (int, optional): Total training epochs.
max_iters (int, optional): Total training iterations.
"""
def __init__(self,
model: torch.nn.Module,
batch_processor: Optional[Callable] = None,
optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
work_dir: Optional[str] = None,
logger: Optional[logging.Logger] = None,
meta: Optional[Dict] = None,
max_iters: Optional[int] = None,
max_epochs: Optional[int] = None) -> None:
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}')
warnings.warn(
'batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.',
DeprecationWarning)
# raise an error if `batch_processor` is not None and
# `model.train_step()` exists.
if is_module_wrapper(model):
_model = model.module
else:
_model = model
if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
raise RuntimeError(
'batch_processor and model.train_step()/model.val_step() '
'cannot be both available.')
else:
assert hasattr(model, 'train_step')
# check the type of `optimizer`
if isinstance(optimizer, dict):
for name, optim in optimizer.items():
if not isinstance(optim, Optimizer):
raise TypeError(
f'optimizer must be a dict of torch.optim.Optimizers, '
f'but optimizer["{name}"] is a {type(optim)}')
elif not isinstance(optimizer, Optimizer) and optimizer is not None:
raise TypeError(
f'optimizer must be a torch.optim.Optimizer object '
f'or dict or None, but got {type(optimizer)}')
# check the type of `logger`
if not isinstance(logger, logging.Logger):
raise TypeError(f'logger must be a logging.Logger object, '
f'but got {type(logger)}')
# check the type of `meta`
if meta is not None and not isinstance(meta, dict):
raise TypeError(
f'meta must be a dict or None, but got {type(meta)}')
self.model = model
self.batch_processor = batch_processor
self.optimizer = optimizer
self.logger = logger
self.meta = meta
# create work_dir
if isinstance(work_dir, str):
self.work_dir: Optional[str] = osp.abspath(work_dir)
mmcv.mkdir_or_exist(self.work_dir)
elif work_dir is None:
self.work_dir = None
else:
raise TypeError('"work_dir" must be a str or None')
# get model name from the model class
if hasattr(self.model, 'module'):
self._model_name = self.model.module.__class__.__name__
else:
self._model_name = self.model.__class__.__name__
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode: Optional[str] = None
self._hooks: List[Hook] = []
self._epoch = 0
self._iter = 0
self._inner_iter = 0
if max_epochs is not None and max_iters is not None:
raise ValueError(
'Only one of `max_epochs` or `max_iters` can be set.')
self._max_epochs = max_epochs
self._max_iters = max_iters
# TODO: Redesign LogBuffer, it is not flexible and elegant enough
self.log_buffer = LogBuffer()
@property
def model_name(self) -> str:
"""str: Name of the model, usually the module class name."""
return self._model_name
@property
def rank(self) -> int:
"""int: Rank of current process. (distributed training)"""
return self._rank
@property
def world_size(self) -> int:
"""int: Number of processes participating in the job.
(distributed training)"""
return self._world_size
@property
def hooks(self) -> List[Hook]:
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
def epoch(self) -> int:
"""int: Current epoch."""
return self._epoch
@property
def iter(self) -> int:
"""int: Current iteration."""
return self._iter
@property
def inner_iter(self) -> int:
"""int: Iteration in an epoch."""
return self._inner_iter
@property
def max_epochs(self):
"""int: Maximum training epochs."""
return self._max_epochs
@property
def max_iters(self):
"""int: Maximum training iterations."""
return self._max_iters
@abstractmethod
def train(self):
pass
@abstractmethod
def val(self):
pass
@abstractmethod
def run(self, data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]], **kwargs) -> Any:
pass
@abstractmethod
def save_checkpoint(self,
out_dir: str,
filename_tmpl: str,
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
pass
def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current learning rates.
Returns:
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
lr: Union[List[float], Dict[str, List[float]]]
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
elif isinstance(self.optimizer, dict):
lr = dict()
for name, optim in self.optimizer.items():
lr[name] = [group['lr'] for group in optim.param_groups]
else:
raise RuntimeError(
'lr is not applicable because optimizer does not exist.')
return lr
def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current momentums.
Returns:
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
def _get_momentum(optimizer):
momentums = []
for group in optimizer.param_groups:
if 'momentum' in group.keys():
momentums.append(group['momentum'])
elif 'betas' in group.keys():
momentums.append(group['betas'][0])
else:
momentums.append(0)
return momentums
if self.optimizer is None:
raise RuntimeError(
'momentum is not applicable because optimizer does not exist.')
elif isinstance(self.optimizer, torch.optim.Optimizer):
momentums = _get_momentum(self.optimizer)
elif isinstance(self.optimizer, dict):
momentums = dict()
for name, optim in self.optimizer.items():
momentums[name] = _get_momentum(optim)
return momentums
def register_hook(self,
hook: Hook,
priority: Union[int, str, Priority] = 'NORMAL') -> None:
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
priority (See :class:`Priority` for details of priorities).
For hooks with the same priority, they will be triggered in the same
order as they are registered.
Args:
hook (:obj:`Hook`): The hook to be registered.
priority (int or str or :obj:`Priority`): Hook priority.
Lower value means higher priority.
"""
assert isinstance(hook, Hook)
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority # type: ignore
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority: # type: ignore
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
"""Register a hook from its cfg.
Args:
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
Note:
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
"""
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
self.register_hook(hook, priority=priority)
def call_hook(self, fn_name: str) -> None:
"""Call all hooks.
Args:
fn_name (str): The function name in each hook to be called, such as
"before_train_epoch".
"""
for hook in self._hooks:
getattr(hook, fn_name)(self)
def get_hook_info(self) -> str:
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)
stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
def load_checkpoint(
self,
filename: str,
map_location: Union[str, Callable] = 'cpu',
strict: bool = False,
revise_keys: List = [(r'^module.', '')],
) -> Union[Dict, OrderedDict]:
return load_checkpoint(
self.model,
filename,
map_location,
strict,
self.logger,
revise_keys=revise_keys)
@no_type_check
def resume(self,
checkpoint: str,
resume_optimizer: bool = True,
map_location: Union[str, Callable] = 'default') -> None:
if map_location == 'default':
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
checkpoint = self.load_checkpoint(
checkpoint,
map_location=lambda storage, loc: storage.cuda(device_id))
else:
checkpoint = self.load_checkpoint(checkpoint)
else:
checkpoint = self.load_checkpoint(
checkpoint, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
if self.meta is None:
self.meta = {}
self.meta.setdefault('hook_msgs', {})
# load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
# Re-calculate the number of iterations when resuming
# models with a different number of GPUs
if 'config' in checkpoint['meta']:
config = mmcv.Config.fromstring(
checkpoint['meta']['config'], file_format='.py')
previous_gpu_ids = config.get('gpu_ids', None)
if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
previous_gpu_ids) != self.world_size:
self._iter = int(self._iter * len(previous_gpu_ids) /
self.world_size)
self.logger.info('the iteration number is changed due to '
'change of GPU number')
# resume meta information
self.meta = checkpoint['meta']
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k])
else:
raise TypeError(
'Optimizer should be dict or torch.optim.Optimizer '
f'but got {type(self.optimizer)}')
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
if lr_config is None:
return
elif isinstance(lr_config, dict):
assert 'policy' in lr_config
policy_type = lr_config.pop('policy')
# If the type of policy is all in lower case, e.g., 'cyclic',
# then its first letter will be capitalized, e.g., to be 'Cyclic'.
# This is for the convenient usage of Lr updater.
# Since this is not applicable for `CosineAnnealingLrUpdater`,
# the string will not be changed if it contains capital letters.
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'LrUpdaterHook'
lr_config['type'] = hook_type
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(
self, momentum_config: Union[Dict, Hook, None]) -> None:
if momentum_config is None:
return
if isinstance(momentum_config, dict):
assert 'policy' in momentum_config
policy_type = momentum_config.pop('policy')
# If the type of policy is all in lower case, e.g., 'cyclic',
# then its first letter will be capitalized, e.g., to be 'Cyclic'.
# This is for the convenient usage of momentum updater.
# Since this is not applicable for
# `CosineAnnealingMomentumUpdater`,
# the string will not be changed if it contains capital letters.
if policy_type == policy_type.lower():
policy_type = policy_type.title()
hook_type = policy_type + 'MomentumUpdaterHook'
momentum_config['type'] = hook_type
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(
self, optimizer_config: Union[Dict, Hook, None]) -> None:
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
optimizer_config.setdefault('type', 'OptimizerHook')
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(
self, checkpoint_config: Union[Dict, Hook, None]) -> None:
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
checkpoint_config.setdefault('type', 'CheckpointHook')
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
if log_config is None:
return
log_interval = log_config['interval']
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(
self,
timer_config: Union[Dict, Hook, None],
) -> None:
if timer_config is None:
return
if isinstance(timer_config, dict):
timer_config_ = copy.deepcopy(timer_config)
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook, priority='LOW')
def register_custom_hooks(
self, custom_config: Union[List, Dict, Hook, None]) -> None:
if custom_config is None:
return
if not isinstance(custom_config, list):
custom_config = [custom_config]
for item in custom_config:
if isinstance(item, dict):
self.register_hook_from_cfg(item)
else:
self.register_hook(item, priority='NORMAL')
def register_profiler_hook(
self,
profiler_config: Union[Dict, Hook, None],
) -> None:
if profiler_config is None:
return
if isinstance(profiler_config, dict):
profiler_config.setdefault('type', 'ProfilerHook')
hook = mmcv.build_from_cfg(profiler_config, HOOKS)
else:
hook = profiler_config
self.register_hook(hook)
def register_training_hooks(
self,
lr_config: Union[Dict, Hook, None],
optimizer_config: Union[Dict, Hook, None] = None,
checkpoint_config: Union[Dict, Hook, None] = None,
log_config: Optional[Dict] = None,
momentum_config: Union[Dict, Hook, None] = None,
timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
"""Register default and custom hooks for training.
Default and custom hooks include:
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have the same priority as default hooks, custom hooks
will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from ..utils import Registry
RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')
def build_runner_constructor(cfg: dict):
return RUNNER_BUILDERS.build(cfg)
def build_runner(cfg: dict, default_args: Optional[dict] = None):
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
runner_constructor = build_runner_constructor(
dict(
type=constructor_type,
runner_cfg=runner_cfg,
default_args=default_args))
runner = runner_constructor()
return runner
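# --- Hypothetical usage sketch (not part of the original file) ---
# Building an EpochBasedRunner from a config dict and registering the default
# training hooks. The toy model, optimizer, hook configs and work_dir below
# are illustrative assumptions; real projects pass in their own.
if __name__ == '__main__':
    import logging
    import tempfile

    import torch
    import torch.nn as nn
    from mmcv.runner import build_runner

    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def train_step(self, data, optimizer=None):
            loss = self.fc(data).sum()
            return dict(loss=loss, log_vars=dict(loss=float(loss)),
                        num_samples=data.size(0))

    model = ToyModel()
    runner = build_runner(
        dict(type='EpochBasedRunner', max_epochs=2),
        default_args=dict(
            model=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
            work_dir=tempfile.mkdtemp(),
            logger=logging.getLogger('toy')))
    runner.register_training_hooks(
        lr_config=dict(policy='step', step=[1]),
        optimizer_config=dict(grad_clip=None),
        checkpoint_config=dict(interval=1),
        log_config=dict(interval=1, hooks=[dict(type='TextLoggerHook')]))
    print(runner.max_epochs, len(runner.hooks))  # 2 5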
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Tuple, Union
import mmengine
import torch
import torch.nn as nn
import torchvision
from mmengine.fileio import FileClient
from mmengine.fileio import load as load_file
from torch.optim import Optimizer
import mmcv
from ..parallel import is_module_wrapper
from ..utils import digit_version, load_url, mkdir_or_exist
from .dist_utils import get_dist_info
ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home() -> str:
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
mkdir_or_exist(mmcv_home)
return mmcv_home
def load_state_dict(module: nn.Module,
state_dict: Union[dict, OrderedDict],
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
Default value for ``strict`` is set to ``False`` and the message for
param mismatch will be shown even if strict is False.
Args:
module (Module): Module that receives the state_dict.
state_dict (dict or OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys: List[str] = []
all_missing_keys: List[str] = []
err_msg: List[str] = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy() # type: ignore
if metadata is not None:
state_dict._metadata = metadata # type: ignore
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
local_metadata = {} if metadata is None else metadata.get(
prefix[:-1], {})
module._load_from_state_dict(state_dict, prefix, local_metadata, True,
all_missing_keys, unexpected_keys,
err_msg)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(module)
# break load->load reference cycle
load = None # type: ignore
# ignore "num_batches_tracked" of BN layers
missing_keys = [
key for key in all_missing_keys if 'num_batches_tracked' not in key
]
if unexpected_keys:
err_msg.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msg.append(
f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
rank, _ = get_dist_info()
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg) # type: ignore
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
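# --- Hypothetical usage sketch (not part of the original file) ---
# With ``strict=False`` (the default) a partial state_dict is still loaded
# and the mismatch is reported instead of raising, which is convenient for
# fine-tuning. The toy module below is an illustrative assumption.
if __name__ == '__main__':
    toy = nn.Linear(2, 2)
    # 'bias' is deliberately missing; expect a "missing keys" message.
    load_state_dict(toy, {'weight': torch.zeros(2, 2)}, strict=False)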
def get_torchvision_models():
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.models.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
else:
# Since torchvision v0.13, the weight loading logic, model keys and
# model urls have changed. Here the URLs of the old version are loaded
# to avoid breaking backward compatibility. If the torchvision version
# is >= 0.13.0, new URLs will be added. Users can get the resnet50
# checkpoint by setting 'resnet50.imagenet1k_v1', 'resnet50' or
# 'ResNet50_Weights.IMAGENET1K_V1' in the config.
json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json')
model_urls = mmengine.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items():
# The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights`, do not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights')
or not hasattr(cls, 'DEFAULT')):
continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly.
cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
cls_key = cls_name.replace('_Weights', '').lower()
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
model_urls[cls_key] = weight_enum.url
return model_urls
def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)
return default_urls
def get_mmcls_models():
mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
mmcls_urls = load_file(mmcls_json_path)
return mmcls_urls
def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0],
'model_zoo/deprecated.json')
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls
def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
# Some checkpoints converted from 3rd-party repo don't
# have the "state_dict" key.
state_dict = checkpoint
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('backbone.'):
new_state_dict[k[9:]] = v
new_checkpoint = dict(state_dict=new_state_dict)
return new_checkpoint
class CheckpointLoader:
"""A general checkpoint loader to manage all schemes."""
_schemes: dict = {}
@classmethod
def _register_scheme(cls,
prefixes: Union[str, List, Tuple],
loader: Callable,
force: bool = False) -> None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
assert isinstance(prefixes, (list, tuple))
for prefix in prefixes:
if (prefix not in cls._schemes) or force:
cls._schemes[prefix] = loader
else:
raise KeyError(
f'{prefix} is already registered as a loader backend, '
'add "force=True" if you want to override it')
# sort, longer prefixes take priority
cls._schemes = OrderedDict(
sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
@classmethod
def register_scheme(cls,
prefixes: Union[str, List[str], Tuple[str, ...]],
loader: Optional[Callable] = None,
force: bool = False) -> Callable:
"""Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator.
Args:
prefixes (str or Sequence[str]):
The prefix of the registered loader.
loader (function, optional): The loader function to be registered.
When this method is used as a decorator, loader is None.
Defaults to None.
force (bool, optional): Whether to override the loader
if the prefix has already been registered. Defaults to False.
"""
if loader is not None:
cls._register_scheme(prefixes, loader, force=force)
return # type: ignore
def _register(loader_cls):
cls._register_scheme(prefixes, loader_cls, force=force)
return loader_cls
return _register
@classmethod
def _get_checkpoint_loader(cls, path: str):
"""Finds a loader that supports the given path. Falls back to the local
loader if no other loader is found.
Args:
path (str): checkpoint path
Returns:
callable: checkpoint loader
"""
for p in cls._schemes:
# use a regular-expression match to handle cases where the loader
# prefix itself has a prefix. For example, both 's3://path' and
# 'open-mmlab:s3://path' should return `load_from_ceph`
if re.match(p, path) is not None:
return cls._schemes[p]
@classmethod
def load_checkpoint(
cls,
filename: str,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None
) -> Union[dict, OrderedDict]:
"""load checkpoint through URL scheme path.
Args:
filename (str): checkpoint file name with given prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
logger (:mod:`logging.Logger`, optional): The logger for message.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint_loader = cls._get_checkpoint_loader(filename)
class_name = checkpoint_loader.__name__ # type: ignore
mmcv.print_log(
f'load checkpoint from {class_name[10:]} path: {filename}', logger)
return checkpoint_loader(filename, map_location) # type: ignore
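# --- Hypothetical usage sketch (not part of the original file) ---
# ``register_scheme()`` can also be used as a decorator to add a
# project-specific URI scheme. The ``toy://`` prefix and the in-memory "hub"
# below are made up purely for illustration.
if __name__ == '__main__':
    _TOY_HUB = {'toy://linear': dict(state_dict={'weight': torch.zeros(2, 2)})}

    @CheckpointLoader.register_scheme(prefixes='toy://')
    def load_from_toy_hub(filename, map_location=None):
        return _TOY_HUB[filename]

    print(CheckpointLoader.load_checkpoint('toy://linear').keys())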
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint by local file path.
Args:
filename (str): local checkpoint file path
map_location (str, optional): Same as :func:`torch.load`.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
filename = osp.expanduser(filename)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(
filename: str,
map_location: Union[str, Callable, None] = None,
model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
model_dir (str, optional): directory in which to save the object,
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
rank, world_size = get_dist_info()
if rank == 0:
checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = load_url(
filename, model_dir=model_dir, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Args:
filename (str): checkpoint file path with pavi prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
assert filename.startswith('pavi://'), \
f'Expected filename to start with `pavi://`, but got {filename}'
model_path = filename[7:]
try:
from pavi import modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model = modelcloud.get(model_path)
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
checkpoint = torch.load(downloaded_file, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename: str,
map_location: Union[str, Callable, None] = None,
backend: str = 'petrel') -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Note:
Since v1.4.1, the registered scheme prefixes have been enhanced to
support bucket names in the path prefix, e.g. 's3://xx.xx/xx.path',
'bucket1:s3://xx.xx/xx.path'.
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.
.. warning::
:class:`mmengine.fileio.file_client.CephBackend` will be deprecated,
please use :class:`mmengine.fileio.file_client.PetrelBackend` instead.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
allowed_backends = ['ceph', 'petrel']
if backend not in allowed_backends:
raise ValueError(f'Load from Backend {backend} is not supported.')
if backend == 'ceph':
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
# CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated
# successfully, the CephClient will be chosen.
try:
file_client = FileClient(backend=backend)
except ImportError:
allowed_backends.remove(backend)
file_client = FileClient(backend=allowed_backends[0])
with io.BytesIO(file_client.get(filename)) as buffer:
checkpoint = torch.load(buffer, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.
Args:
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
model_urls = get_torchvision_models()
if filename.startswith('modelzoo://'):
warnings.warn(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead', DeprecationWarning)
model_name = filename[11:]
else:
model_name = filename[14:]
# Support getting model urls in the same way as torchvision
# `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
# resnet50.imagenet1k_v1.
model_name = model_name.lower().replace('_weights', '')
return load_from_http(model_urls[model_name], map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
Args:
filename (str): checkpoint file path with open-mmlab or
openmmlab prefix
map_location (str, optional): Same as :func:`torch.load`.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
model_urls = get_external_models()
prefix_str = 'open-mmlab://'
if filename.startswith(prefix_str):
model_name = filename[13:]
else:
model_name = filename[12:]
prefix_str = 'openmmlab://'
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(
f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}',
DeprecationWarning)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_from_http(model_url, map_location=map_location)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with mmcls.
Args:
filename (str): checkpoint file path with mmcls prefix
map_location (str, optional): Same as :func:`torch.load`.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
model_urls = get_mmcls_models()
model_name = filename[8:]
checkpoint = load_from_http(
model_urls[model_name], map_location=map_location)
checkpoint = _process_mmcls_checkpoint(checkpoint)
return checkpoint
def _load_checkpoint(
filename: str,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str, optional): Same as :func:`torch.load`.
Default: None.
logger (:mod:`logging.Logger`, optional): The logger for error message.
Default: None
Returns:
dict or OrderedDict: The loaded checkpoint. It can be either an
OrderedDict storing model weights or a dict containing other
information, which depends on the checkpoint.
"""
return CheckpointLoader.load_checkpoint(filename, map_location, logger)
def _load_checkpoint_with_prefix(
prefix: str,
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""Load partial pretrained model with specific prefix.
Args:
prefix (str): The prefix of sub-module.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location=map_location)
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if not prefix.endswith('.'):
prefix += '.'
prefix_len = len(prefix)
state_dict = {
k[prefix_len:]: v
for k, v in state_dict.items() if k.startswith(prefix)
}
assert state_dict, f'{prefix} is not in the pretrained model'
return state_dict
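# --- Hypothetical usage sketch (not part of the original file) ---
# Save a toy checkpoint whose keys carry a 'backbone.' prefix, then load only
# that sub-module's weights back with the prefix stripped.
if __name__ == '__main__':
    import tempfile

    ckpt_path = osp.join(tempfile.mkdtemp(), 'toy.pth')
    torch.save(dict(state_dict={'backbone.fc.weight': torch.zeros(2, 2)}),
               ckpt_path)
    sub_state_dict = _load_checkpoint_with_prefix('backbone', ckpt_path)
    print(list(sub_state_dict.keys()))  # ['fc.weight']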
def load_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Union[str, Callable, None] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
"""Load checkpoint from a file or URI.
Args:
model (Module): Module to load checkpoint.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.
logger (:mod:`logging.Logger` or None): The logger for error message.
revise_keys (list): A list of customized keywords to modify the
state_dict in checkpoint. Each item is a (pattern, replacement)
pair of the regular expression operations. Default: strip
the prefix 'module.' by [(r'^module\\.', '')].
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
checkpoint = _load_checkpoint(filename, map_location, logger)
# OrderedDict is a subclass of dict
if not isinstance(checkpoint, dict):
raise RuntimeError(
f'No state_dict found in checkpoint file {filename}')
# get state_dict from checkpoint
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# strip prefix of state_dict
metadata = getattr(state_dict, '_metadata', OrderedDict())
for p, r in revise_keys:
state_dict = OrderedDict(
{re.sub(p, r, k): v
for k, v in state_dict.items()})
# Keep metadata in state_dict
state_dict._metadata = metadata
# load state_dict
load_state_dict(model, state_dict, strict, logger)
return checkpoint
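# Usage sketch (illustrative): load a checkpoint into a model and strip the
# ``module.`` prefix as well as a hypothetical ``backbone.`` prefix from the
# stored keys via ``revise_keys``.
#
# >>> checkpoint = load_checkpoint(
# ...     model, 'checkpoint.pth', map_location='cpu', strict=False,
# ...     revise_keys=[(r'^module\.', ''), (r'^backbone\.', '')])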
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
"""Copy a model state_dict to cpu.
Args:
state_dict (OrderedDict): Model weights on GPU.
Returns:
OrderedDict: Model weights on CPU.
"""
state_dict_cpu = OrderedDict()
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr( # type: ignore
state_dict, '_metadata', OrderedDict())
return state_dict_cpu
def _save_to_state_dict(module: torch.nn.Module, destination: dict,
prefix: str, keep_vars: bool) -> None:
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
Args:
module (nn.Module): The module to generate state_dict.
destination (dict): A dict where state will be stored.
prefix (str): The prefix for parameters and buffers used in this
module.
keep_vars (bool): Whether to keep the parameters and buffers attached
to the autograd graph. If False, detached copies are stored.
"""
for name, param in module._parameters.items():
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in module._buffers.items():
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
if buf is not None:
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module: torch.nn.Module,
destination: Optional[OrderedDict] = None,
prefix: str = '',
keep_vars: bool = False) -> OrderedDict:
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
This method is modified from :meth:`torch.nn.Module.state_dict` to
recursively check parallel module in case that the model has a complicated
structure, e.g., nn.Module(nn.Module(DDP)).
Args:
module (nn.Module): The module to generate state_dict.
destination (OrderedDict): Returned dict for the state of the
module.
prefix (str): Prefix of the key.
keep_vars (bool): Whether to keep the variable property of the
parameters. Default: False.
Returns:
dict: A dictionary containing a whole state of the module.
"""
# recursively check parallel module in case that the model has a
# complicated structure, e.g., nn.Module(nn.Module(DDP))
if is_module_wrapper(module):
module = module.module
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars) # type: ignore
for name, child in module._modules.items():
if child is not None:
get_state_dict(
child, destination, prefix + name + '.', keep_vars=keep_vars)
for hook in module._state_dict_hooks.values():
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination # type: ignore
def save_checkpoint(model: torch.nn.Module,
filename: str,
optimizer: Optional[Optimizer] = None,
meta: Optional[dict] = None,
file_client_args: Optional[dict] = None) -> None:
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
``optimizer``. By default ``meta`` will contain version and time info.
Args:
model (Module): Module whose params are to be saved.
filename (str): Checkpoint filename.
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
meta (dict, optional): Metadata to be saved in checkpoint.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
if is_module_wrapper(model):
model = model.module
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)) # type: ignore
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
checkpoint['optimizer'] = optimizer.state_dict()
elif isinstance(optimizer, dict):
checkpoint['optimizer'] = {}
for name, optim in optimizer.items():
checkpoint['optimizer'][name] = optim.state_dict()
if filename.startswith('pavi://'):
if file_client_args is not None:
raise ValueError(
'file_client_args should be "None" if filename starts with '
f'"pavi://", but got {file_client_args}')
try:
from pavi import exception, modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
model_path = filename[7:]
root = modelcloud.Folder()
model_dir, model_name = osp.split(model_path)
try:
model = modelcloud.get(model_dir)
except exception.NodeNotFoundError:
model = root.create_training_model(model_dir)
with TemporaryDirectory() as tmp_dir:
checkpoint_file = osp.join(tmp_dir, model_name)
with open(checkpoint_file, 'wb') as f:
torch.save(checkpoint, f)
f.flush()
model.create_file(checkpoint_file, name=model_name)
else:
file_client = FileClient.infer_client(file_client_args, filename)
with io.BytesIO() as f:
torch.save(checkpoint, f)
file_client.put(f.getvalue(), filename)
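# Usage sketch (illustrative): save a model together with its optimizer state
# and custom metadata; ``work_dirs/epoch_1.pth`` is a hypothetical path.
#
# >>> save_checkpoint(model, 'work_dirs/epoch_1.pth', optimizer=optimizer,
# ...                 meta=dict(epoch=1, iter=1000))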
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from .builder import RUNNER_BUILDERS, RUNNERS
@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
"""Default constructor for runners.
Customize an existing ``Runner`` such as ``EpochBasedRunner`` through a
``RunnerConstructor``. For example, we can inject new properties and
functions into the ``Runner``.
Example:
>>> from mmcv.runner import RUNNER_BUILDERS, build_runner
>>> # Define a new RunnerReconstructor
>>> @RUNNER_BUILDERS.register_module()
>>> class MyRunnerConstructor:
... def __init__(self, runner_cfg, default_args=None):
... if not isinstance(runner_cfg, dict):
... raise TypeError('runner_cfg should be a dict, '
... f'but got {type(runner_cfg)}')
... self.runner_cfg = runner_cfg
... self.default_args = default_args
...
... def __call__(self):
... runner = RUNNERS.build(self.runner_cfg,
... default_args=self.default_args)
... # Add new properties for existing runner
... runner.my_name = 'my_runner'
... runner.my_function = lambda self: print(self.my_name)
... ...
>>> # build your runner
>>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
... constructor='MyRunnerConstructor')
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict, '
f'but got {type(runner_cfg)}')
self.runner_cfg = runner_cfg
self.default_args = default_args
def __call__(self):
return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import IS_MLU_AVAILABLE
def _find_free_port() -> int:
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(('', 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def _is_free_port(port: int) -> bool:
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend: str, **kwargs) -> None:
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
if IS_MLU_AVAILABLE:
import torch_mlu # noqa: F401
torch.mlu.set_device(rank)
dist.init_process_group(
backend='cncl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs) -> None:
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
if 'MASTER_ADDR' not in os.environ:
raise KeyError('The environment variable MASTER_ADDR is not set')
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# if torch.distributed default port(29500) is available
# then use it, else find a free port
if _is_free_port(29500):
os.environ['MASTER_PORT'] = '29500'
else:
os.environ['MASTER_PORT'] = str(_find_free_port())
# use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
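# Usage sketch (illustrative): ``master_only`` restricts a function so that it
# only runs on rank 0; on other ranks the decorated function returns ``None``.
#
# >>> @master_only
# ... def log_to_disk(msg):
# ...     print(msg)
# >>> log_to_disk('only the master process prints this')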
def allreduce_params(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce parameters.
Args:
params (list[torch.nn.Parameter]): List of parameters or buffers
of a model.
coalesce (bool, optional): Whether to allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
"""
_, world_size = get_dist_info()
if world_size == 1:
return
params = [param.data for param in params]
if coalesce:
_allreduce_coalesced(params, world_size, bucket_size_mb)
else:
for tensor in params:
dist.all_reduce(tensor.div_(world_size))
def allreduce_grads(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce gradients.
Args:
params (list[torch.nn.Parameter]): List of parameters of a model.
coalesce (bool, optional): Whether to allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
Defaults to -1.
"""
grads = [
param.grad.data for param in params
if param.requires_grad and param.grad is not None
]
_, world_size = get_dist_info()
if world_size == 1:
return
if coalesce:
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
def _allreduce_coalesced(tensors: List[torch.Tensor],
world_size: int,
bucket_size_mb: int = -1) -> None:
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
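# Usage sketch (illustrative, assuming a process group has been initialized via
# ``init_dist``): average model buffers and gradients across all processes.
#
# >>> allreduce_params(list(model.buffers()), coalesce=True, bucket_size_mb=-1)
# >>> allreduce_grads(list(model.parameters()), coalesce=True)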
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .utils import get_host_info
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner trains models epoch by epoch.
"""
def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
del self.data_batch
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
del self.data_batch
self.call_hook('after_val_epoch')
def run(self,
data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]],
max_epochs: Optional[int] = None,
**kwargs) -> None:
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn(
'setting max_epochs in run is deprecated, '
'please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, (
'max_epochs must be specified during instantiation')
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir: str,
filename_tmpl: str = 'epoch_{}.pth',
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
# Note: meta.update(self.meta) should be done before
# meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
# there will be problems with resumed checkpoints.
# More details in https://github.com/open-mmlab/mmcv/pull/1108
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, 'latest.pth')
if platform.system() != 'Windows':
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead',
DeprecationWarning)
super().__init__(*args, **kwargs)
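# Usage sketch (illustrative): alternate two training epochs with one validation
# epoch until ``max_epochs`` is reached; ``train_loader`` and ``val_loader`` are
# hypothetical dataloaders.
#
# >>> import logging
# >>> runner = EpochBasedRunner(model, optimizer=optimizer, work_dir='./work_dir',
# ...                           logger=logging.getLogger(), max_epochs=12)
# >>> runner.run([train_loader, val_loader], [('train', 2), ('val', 1)])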
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import warnings
from collections import abc
from inspect import getfullargspec
from typing import Callable, Iterable, List, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads
try:
# If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
# and used; otherwise, auto fp16 will adopt mmcv's implementation.
# Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
# manually, so the behavior may not be consistent with real amp.
from torch.cuda.amp import autocast
except ImportError:
pass
def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
"""Recursively convert Tensor in inputs from src_type to dst_type.
Note:
In v1.4.4 and later, ``cast_tensor_type`` will only convert the
torch.Tensor whose dtype is consistent with ``src_type`` to the
``dst_type``. Before v1.4.4, it ignored the ``src_type`` argument,
leading to some potential problems. For example,
``cast_tensor_type(inputs, torch.float, torch.half)`` would convert all
tensors in inputs to ``torch.half``, including those originally in
``torch.int`` or other types, which is not expected.
Args:
inputs: Inputs to be cast.
src_type (torch.dtype): Source type.
dst_type (torch.dtype): Destination type.
Returns:
The same type with inputs, but all contained Tensors have been cast.
"""
if isinstance(inputs, nn.Module):
return inputs
elif isinstance(inputs, torch.Tensor):
# we need to ensure that the type of inputs to be casted are the same
# as the argument `src_type`.
return inputs.to(dst_type) if inputs.dtype == src_type else inputs
elif isinstance(inputs, str):
return inputs
elif isinstance(inputs, np.ndarray):
return inputs
elif isinstance(inputs, abc.Mapping):
return type(inputs)({ # type: ignore
k: cast_tensor_type(v, src_type, dst_type)
for k, v in inputs.items()
})
elif isinstance(inputs, abc.Iterable):
return type(inputs)( # type: ignore
cast_tensor_type(item, src_type, dst_type) for item in inputs)
else:
return inputs
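# Usage sketch (illustrative): only tensors whose dtype matches ``src_type`` are
# converted; containers are traversed recursively and other objects are returned
# unchanged.
#
# >>> data = dict(img=torch.rand(1, 3), label=torch.tensor([1]), name='a')
# >>> out = cast_tensor_type(data, torch.float, torch.half)
# >>> out['img'].dtype, out['label'].dtype, out['name']
# (torch.float16, torch.int64, 'a')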
def auto_fp16(
apply_to: Optional[Iterable] = None,
out_fp32: bool = False,
supported_types: tuple = (nn.Module, ),
) -> Callable:
"""Decorator to enable fp16 training automatically.
This decorator is useful when you write custom modules and want to support
mixed precision training. If input arguments are fp32 tensors, they will
be converted to fp16 automatically. Arguments other than fp32 tensors are
ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
backend, otherwise, original mmcv implementation will be adopted.
Args:
apply_to (Iterable, optional): The argument names to be converted.
`None` indicates all arguments.
out_fp32 (bool): Whether to convert the output back to fp32.
supported_types (tuple): Classes can be decorated by ``auto_fp16``.
`New in version 1.5.0.`
Example:
>>> import torch.nn as nn
>>> class MyModule1(nn.Module):
>>>
>>> # Convert x and y to fp16
>>> @auto_fp16()
>>> def forward(self, x, y):
>>> pass
>>> import torch.nn as nn
>>> class MyModule2(nn.Module):
>>>
>>> # convert pred to fp16
>>> @auto_fp16(apply_to=('pred', ))
>>> def do_something(self, pred, others):
>>> pass
"""
def auto_fp16_wrapper(old_func: Callable) -> Callable:
@functools.wraps(old_func)
def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if not isinstance(args[0], supported_types):
raise TypeError('@auto_fp16 can only be used to decorate the '
f'method of those classes {supported_types}')
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
return old_func(*args, **kwargs)
# get the arg spec of the decorated method
args_info = getfullargspec(old_func)
# get the argument names to be casted
args_to_cast = args_info.args if apply_to is None else apply_to
# convert the args that need to be processed
new_args = []
# NOTE: default args are not taken into consideration
if args:
arg_names = args_info.args[:len(args)]
for i, arg_name in enumerate(arg_names):
if arg_name in args_to_cast:
new_args.append(
cast_tensor_type(args[i], torch.float, torch.half))
else:
new_args.append(args[i])
# convert the kwargs that need to be processed
new_kwargs = {}
if kwargs:
for arg_name, arg_value in kwargs.items():
if arg_name in args_to_cast:
new_kwargs[arg_name] = cast_tensor_type(
arg_value, torch.float, torch.half)
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots' and
digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
output = old_func(*new_args, **new_kwargs)
# cast the results back to fp32 if necessary
if out_fp32:
output = cast_tensor_type(output, torch.half, torch.float)
return output
return new_func
return auto_fp16_wrapper
def force_fp32(apply_to: Optional[Iterable] = None,
out_fp16: bool = False) -> Callable:
"""Decorator to convert input arguments to fp32 in force.
This decorator is useful when you write custom modules and want to support
mixed precision training. If there are some inputs that must be processed
in fp32 mode, then this decorator can handle it. If inputs arguments are
fp16 tensors, they will be converted to fp32 automatically. Arguments other
than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
torch.cuda.amp is used as the backend, otherwise, original mmcv
implementation will be adopted.
Args:
apply_to (Iterable, optional): The argument names to be converted.
`None` indicates all arguments.
out_fp16 (bool): Whether to convert the output back to fp16.
Example:
>>> import torch.nn as nn
>>> class MyModule1(nn.Module):
>>>
>>> # Convert x and y to fp32
>>> @force_fp32()
>>> def loss(self, x, y):
>>> pass
>>> import torch.nn as nn
>>> class MyModule2(nn.Module):
>>>
>>> # convert pred to fp32
>>> @force_fp32(apply_to=('pred', ))
>>> def post_process(self, pred, others):
>>> pass
"""
def force_fp32_wrapper(old_func):
@functools.wraps(old_func)
def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if not isinstance(args[0], torch.nn.Module):
raise TypeError('@force_fp32 can only be used to decorate the '
'method of nn.Module')
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
return old_func(*args, **kwargs)
# get the arg spec of the decorated method
args_info = getfullargspec(old_func)
# get the argument names to be casted
args_to_cast = args_info.args if apply_to is None else apply_to
# convert the args that need to be processed
new_args = []
if args:
arg_names = args_info.args[:len(args)]
for i, arg_name in enumerate(arg_names):
if arg_name in args_to_cast:
new_args.append(
cast_tensor_type(args[i], torch.half, torch.float))
else:
new_args.append(args[i])
# convert the kwargs that need to be processed
new_kwargs = dict()
if kwargs:
for arg_name, arg_value in kwargs.items():
if arg_name in args_to_cast:
new_kwargs[arg_name] = cast_tensor_type(
arg_value, torch.half, torch.float)
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots' and
digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
output = old_func(*new_args, **new_kwargs)
# cast the results back to fp32 if necessary
if out_fp16:
output = cast_tensor_type(output, torch.float, torch.half)
return output
return new_func
return force_fp32_wrapper
def allreduce_grads(params: List[Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
warnings.warn(
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads"',
DeprecationWarning)
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
def wrap_fp16_model(model: nn.Module) -> None:
"""Wrap the FP32 model to FP16.
If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
backend, otherwise, original mmcv implementation will be adopted.
For PyTorch >= 1.6, this function will
1. Set fp16 flag inside the model to True.
Otherwise:
1. Convert FP32 model to FP16.
2. Keep some necessary layers in FP32, e.g., normalization layers.
3. Set `fp16_enabled` flag inside the model to True.
Args:
model (nn.Module): Model in FP32.
"""
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
patch_norm_fp32(model)
# set `fp16_enabled` flag
for m in model.modules():
if hasattr(m, 'fp16_enabled'):
m.fp16_enabled = True
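# Usage sketch (illustrative): enable fp16 for a model whose ``forward`` is
# decorated with ``@auto_fp16()``; the hypothetical module is assumed to define
# ``self.fp16_enabled = False`` in ``__init__`` so that ``wrap_fp16_model`` can
# flip the flag (with PyTorch >= 1.6 no weight conversion is performed).
#
# >>> model = MyFP16Module()  # hypothetical module
# >>> wrap_fp16_model(model)
# >>> model.fp16_enabled
# True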
def patch_norm_fp32(module: nn.Module) -> nn.Module:
"""Recursively convert normalization layers from FP16 to FP32.
Args:
module (nn.Module): The FP16 module whose normalization layers should
be converted to FP32.
Returns:
nn.Module: The converted module, the normalization layers have been
converted to FP32.
"""
if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
module.float()
if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
module.forward = patch_forward_method(module.forward, torch.half,
torch.float)
for child in module.children():
patch_norm_fp32(child)
return module
def patch_forward_method(func: Callable,
src_type: torch.dtype,
dst_type: torch.dtype,
convert_output: bool = True) -> Callable:
"""Patch the forward method of a module.
Args:
func (callable): The original forward method.
src_type (torch.dtype): Type of input arguments to be converted from.
dst_type (torch.dtype): Type of input arguments to be converted to.
convert_output (bool): Whether to convert the output back to src_type.
Returns:
callable: The patched forward method.
"""
def new_forward(*args, **kwargs):
output = func(*cast_tensor_type(args, src_type, dst_type),
**cast_tensor_type(kwargs, src_type, dst_type))
if convert_output:
output = cast_tensor_type(output, dst_type, src_type)
return output
return new_forward
class LossScaler:
"""Class that manages loss scaling in mixed precision training which
supports both dynamic or static mode.
The implementation refers to
https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
It's important to understand how :class:`LossScaler` operates.
Loss scaling is designed to combat the problem of underflowing
gradients encountered at long times when training fp16 networks.
Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients.
If overflowing gradients are encountered, :class:`FP16_Optimizer` then
skips the update step for this particular iteration/minibatch,
and :class:`LossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients
being detected, :class:`LossScaler` increases the loss scale once more.
In this way :class:`LossScaler` attempts to "ride the edge" of always
using the highest loss scale possible without incurring overflow.
Args:
init_scale (float): Initial loss scale value, default: 2**32.
scale_factor (float): Factor used when adjusting the loss scale.
Default: 2.
mode (str): Loss scaling mode. 'dynamic' or 'static'
scale_window (int): Number of consecutive iterations without an
overflow to wait before increasing the loss scale. Default: 1000.
"""
def __init__(self,
init_scale: float = 2**32,
mode: str = 'dynamic',
scale_factor: float = 2.,
scale_window: int = 1000):
self.cur_scale = init_scale
self.cur_iter = 0
assert mode in ('dynamic',
'static'), 'mode can only be dynamic or static'
self.mode = mode
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
def has_overflow(self, params: List[Parameter]) -> bool:
"""Check if params contain overflow."""
if self.mode != 'dynamic':
return False
for p in params:
if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
return True
return False
def _has_inf_or_nan(x: torch.Tensor) -> bool:
"""Check if params contain NaN."""
try:
cpu_sum = float(x.float().sum())
except RuntimeError as instance:
if 'value cannot be converted' not in instance.args[0]:
raise
return True
else:
if cpu_sum == float('inf') or cpu_sum == -float('inf') \
or cpu_sum != cpu_sum:
return True
return False
def update_scale(self, overflow: bool) -> None:
"""update the current loss scale value when overflow happens."""
if self.mode != 'dynamic':
return
if overflow:
self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % \
self.scale_window == 0:
self.cur_scale *= self.scale_factor
self.cur_iter += 1
def state_dict(self) -> dict:
"""Returns the state of the scaler as a :class:`dict`."""
return dict(
cur_scale=self.cur_scale,
cur_iter=self.cur_iter,
mode=self.mode,
last_overflow_iter=self.last_overflow_iter,
scale_factor=self.scale_factor,
scale_window=self.scale_window)
def load_state_dict(self, state_dict: dict) -> None:
"""Loads the loss_scaler state dict.
Args:
state_dict (dict): scaler state.
"""
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
self.mode = state_dict['mode']
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
@property
def loss_scale(self) -> float:
return self.cur_scale
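# Usage sketch (illustrative): a dynamic loss scaler halves the scale after an
# overflow and doubles it again after ``scale_window`` clean iterations.
#
# >>> scaler = LossScaler(init_scale=2**16, mode='dynamic', scale_window=2)
# >>> scaler.update_scale(overflow=True)    # scale drops to 2**15
# >>> scaler.update_scale(overflow=False)
# >>> scaler.update_scale(overflow=False)   # two clean iters: back to 2**16
# >>> scaler.loss_scale
# 65536.0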
# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (ClearMLLoggerHook, DvcliveLoggerHook, LoggerHook,
MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
SegmindLoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
LinearAnnealingLrUpdaterHook, LrUpdaterHook,
OneCycleLrUpdaterHook, PolyLrUpdaterHook,
StepLrUpdaterHook)
from .memory import EmptyCacheHook
from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook,
LinearAnnealingMomentumUpdaterHook,
MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook',
'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook',
'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook',
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'SegmindLoggerHook', 'LinearAnnealingLrUpdaterHook',
'LinearAnnealingMomentumUpdaterHook', 'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional
from mmengine.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``.
`Changed in version 1.3.16.`
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
save_last (bool, optional): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool, optional): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
.. warning::
Before v1.3.16, the ``out_dir`` argument indicates the path where the
checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
root directory and the final path to save checkpoint is the
concatenation of ``out_dir`` and the last level directory of
``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
and the value of ``runner.work_dir`` is "/path/of/B", then the final
path will be "/path/of/A/B".
"""
def __init__(self,
interval: int = -1,
by_epoch: bool = True,
save_optimizer: bool = True,
out_dir: Optional[str] = None,
max_keep_ckpts: int = -1,
save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not
# allow to create a symlink
if 'create_symlink' in self.args:
if self.args[
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
'create_symlink is set as True by the user but is changed '
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner):
if not self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
@master_only
def _save_checkpoint(self, runner):
"""Save the current checkpoint and delete unwanted checkpoint."""
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None:
if self.by_epoch:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename)
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1
else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = self.file_client.join_path(
self.out_dir, filename_tmpl.format(_step))
if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path)
else:
break
def after_train_iter(self, runner):
if self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
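# Usage sketch (illustrative): save a checkpoint every 4 epochs, keep only the
# 3 most recent ones and skip the optimizer state; the hook is assumed to be
# registered on a runner via ``runner.register_hook``.
#
# >>> ckpt_hook = CheckpointHook(interval=4, by_epoch=True, save_optimizer=False,
# ...                            max_keep_ckpts=3)
# >>> runner.register_hook(ckpt_hook)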
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable
from .hook import HOOKS, Hook
@HOOKS.register_module()
class ClosureHook(Hook):
def __init__(self, fn_name: str, fn: Callable):
assert hasattr(self, fn_name)
assert callable(fn)
setattr(self, fn_name, fn)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook
@HOOKS.register_module()
class EMAHook(Hook):
r"""Exponential Moving Average Hook.
Use Exponential Moving Average on all parameters of the model during
training. Each parameter has an EMA backup, which is updated by the
formula below. EMAHook takes priority over EvalHook and CheckpointHook.
.. math::
Xema\_{t+1} = (1 - \text{momentum}) \times
Xema\_{t} + \text{momentum} \times X_t
Args:
momentum (float): The momentum used for updating ema parameter.
Defaults to 0.0002.
interval (int): Update ema parameter every interval iteration.
Defaults to 1.
warm_up (int): During first warm_up steps, we may use smaller momentum
to update ema parameters more slowly. Defaults to 100.
resume_from (str, optional): The checkpoint path. Defaults to None.
"""
def __init__(self,
momentum: float = 0.0002,
interval: int = 1,
warm_up: int = 100,
resume_from: Optional[str] = None):
assert isinstance(interval, int) and interval > 0
self.warm_up = warm_up
self.interval = interval
assert momentum > 0 and momentum < 1
self.momentum = momentum**interval
self.checkpoint = resume_from
def before_run(self, runner):
"""To resume model with it's ema parameters more friendly.
Register ema parameter as ``named_buffer`` to model
"""
model = runner.model
if is_module_wrapper(model):
model = model.module
self.param_ema_buffer = {}
self.model_parameters = dict(model.named_parameters(recurse=True))
for name, value in self.model_parameters.items():
# "." is not allowed in module's buffer name
buffer_name = f"ema_{name.replace('.', '_')}"
self.param_ema_buffer[name] = buffer_name
model.register_buffer(buffer_name, value.data.clone())
self.model_buffers = dict(model.named_buffers(recurse=True))
if self.checkpoint is not None:
runner.resume(self.checkpoint)
def after_train_iter(self, runner):
"""Update ema parameter every self.interval iterations."""
curr_step = runner.iter
# We warm up the momentum considering the instability at beginning
momentum = min(self.momentum,
(1 + curr_step) / (self.warm_up + curr_step))
if curr_step % self.interval != 0:
return
for name, parameter in self.model_parameters.items():
buffer_name = self.param_ema_buffer[name]
buffer_parameter = self.model_buffers[buffer_name]
buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
def after_train_epoch(self, runner):
"""We load parameter values from ema backup to model before the
EvalHook."""
self._swap_ema_parameters()
def before_train_epoch(self, runner):
"""We recover model's parameter from ema backup after last epoch's
EvalHook."""
self._swap_ema_parameters()
def _swap_ema_parameters(self):
"""Swap the parameter of model with parameter in ema_buffer."""
for name, value in self.model_parameters.items():
temp = value.data.clone()
ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
value.data.copy_(ema_buffer.data)
ema_buffer.data.copy_(temp)
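# Usage sketch (illustrative): keep an EMA copy of the weights that is updated
# every 4 iterations; the hook is assumed to be registered on a runner via
# ``runner.register_hook``.
#
# >>> ema_hook = EMAHook(momentum=0.001, interval=4, warm_up=100)
# >>> runner.register_hook(ema_hook)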
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from math import inf
from typing import Callable, List, Optional
import torch.distributed as dist
from mmengine.fileio import FileClient
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader
from mmcv.utils import is_seq_of
from .hook import Hook
from .logger import LoggerHook
class EvalHook(Hook):
"""Non-Distributed evaluation hook.
This hook will regularly perform evaluation in a given interval when
running in a non-distributed environment.
Args:
dataloader (DataLoader): A PyTorch dataloader, whose dataset has
implemented ``evaluate`` function.
start (int | None, optional): Evaluation starting epoch. It enables
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
``interval``. Default: None.
interval (int): Evaluation interval. Default: 1.
by_epoch (bool): Whether to perform evaluation by epoch or by iteration.
If set to True, it will be performed by epoch; otherwise, by iteration.
Default: True.
save_best (str, optional): If a metric is specified, it would measure
the best checkpoint during evaluation. The information about best
checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
best score value and best checkpoint path, which will also be
loaded when resuming from a checkpoint. Options are the evaluation metrics
on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
detection and instance segmentation. ``AR@100`` for proposal
recall. If ``save_best`` is ``auto``, the first key of the returned
``OrderedDict`` result will be used. Default: None.
rule (str | None, optional): Comparison rule for best score. If set to
None, it will infer a reasonable rule. Keys such as 'acc', 'top',
etc. will be inferred by the 'greater' rule. Keys containing 'loss' will
be inferred by the 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
greater_keys (List[str] | None, optional): Metric keys that will be
inferred by 'greater' comparison rule. If ``None``,
_default_greater_keys will be used. (default: ``None``)
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
`New in version 1.3.16.`
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details. Default: None.
`New in version 1.3.16.`
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
Note:
If new arguments are added for EvalHook, tools/test.py,
tools/eval_metric.py may be affected.
"""
# Since the key for determining greater or less is related to the downstream
# tasks, downstream repos may need to overwrite the following inner
# variables accordingly.
rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
_default_less_keys = ['loss']
def __init__(self,
dataloader: DataLoader,
start: Optional[int] = None,
interval: int = 1,
by_epoch: bool = True,
save_best: Optional[str] = None,
rule: Optional[str] = None,
test_fn: Optional[Callable] = None,
greater_keys: Optional[List[str]] = None,
less_keys: Optional[List[str]] = None,
out_dir: Optional[str] = None,
file_client_args: Optional[dict] = None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
f'but got {type(dataloader)}')
if interval <= 0:
raise ValueError(f'interval must be a positive number, '
f'but got {interval}')
assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
if start is not None and start < 0:
raise ValueError(f'The evaluation start epoch {start} is smaller '
f'than 0')
self.dataloader = dataloader
self.interval = interval
self.start = start
self.by_epoch = by_epoch
assert isinstance(save_best, str) or save_best is None, \
'""save_best"" should be a str or None ' \
f'rather than {type(save_best)}'
self.save_best = save_best
self.eval_kwargs = eval_kwargs
self.initial_flag = True
if test_fn is None:
from mmcv.engine import single_gpu_test
self.test_fn = single_gpu_test
else:
self.test_fn = test_fn
if greater_keys is None:
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
assert isinstance(greater_keys, str)
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys
if less_keys is None:
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
assert isinstance(less_keys, str)
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys
if self.save_best is not None:
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)
self.out_dir = out_dir
self.file_client_args = file_client_args
def _init_rule(self, rule: Optional[str], key_indicator: str):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for the key indicator
when the rule is not specified (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
specified as 'less'.
3. Or if any one item in ``self.greater_keys`` is a substring of
key_indicator, the rule will be specified as 'greater'.
4. Or if any one item in ``self.less_keys`` is a substring of
key_indicator, the rule will be specified as 'less'.
Args:
rule (str | None): Comparison rule for best score.
key_indicator (str | None): Key indicator to determine the
comparison rule.
"""
if rule not in self.rule_map and rule is not None:
raise KeyError(f'rule must be greater, less or None, '
f'but got {rule}.')
if rule is None:
if key_indicator != 'auto':
# `_lc` here means we use the lower case of keys for
# case-insensitive matching
assert isinstance(key_indicator, str)
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]
if key_indicator_lc in greater_keys:
rule = 'greater'
elif key_indicator_lc in less_keys:
rule = 'less'
elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
f'{key_indicator}, thus a specific rule '
f'must be specified.')
self.rule = rule
self.key_indicator = key_indicator
if self.rule is not None:
self.compare_func = self.rule_map[self.rule]
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
f'The best checkpoint will be saved to {self.out_dir} by '
f'{self.file_client.name}')
if self.save_best is not None:
if runner.meta is None:
warnings.warn('runner.meta is None. Creating an empty one.')
runner.meta = dict()
runner.meta.setdefault('hook_msgs', dict())
self.best_ckpt_path = runner.meta['hook_msgs'].get(
'best_ckpt', None)
def before_train_iter(self, runner):
"""Evaluate the model only at the start of training by iteration."""
if self.by_epoch or not self.initial_flag:
return
if self.start is not None and runner.iter >= self.start:
self.after_train_iter(runner)
self.initial_flag = False
def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training by epoch."""
if not (self.by_epoch and self.initial_flag):
return
if self.start is not None and runner.epoch >= self.start:
self.after_train_epoch(runner)
self.initial_flag = False
def after_train_iter(self, runner):
"""Called after every training iter to evaluate the results."""
if not self.by_epoch and self._should_evaluate(runner):
# Because the priority of EvalHook is higher than LoggerHook, the
# training log and the evaluating log are mixed. Therefore,
# we need to dump the training log and clear it before evaluating
# log is generated. In addition, this problem will only appear in
# `IterBasedRunner` whose `self.by_epoch` is False, because
# `EpochBasedRunner` whose `self.by_epoch` is True calls
# `_do_evaluate` in `after_train_epoch` stage, and at this stage
# the training log has been printed, so it will not cause any
# problem. more details at
# https://github.com/open-mmlab/mmsegmentation/issues/694
for hook in runner._hooks:
if isinstance(hook, LoggerHook):
hook.after_train_iter(runner)
runner.log_buffer.clear()
self._do_evaluate(runner)
def after_train_epoch(self, runner):
"""Called after every training epoch to evaluate the results."""
if self.by_epoch and self._should_evaluate(runner):
self._do_evaluate(runner)
def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
# the key_score may be `None` so it needs to skip the action to save
# the best checkpoint
if self.save_best and key_score:
self._save_ckpt(runner, key_score)
def _should_evaluate(self, runner):
"""Judge whether to perform evaluation.
Here is the rule to judge whether to perform evaluation:
1. It will not perform evaluation during the epoch/iteration interval,
which is determined by ``self.interval``.
2. It will not perform evaluation if the start time is larger than the
current time.
3. It will not perform evaluation when the current time is larger than
the start time but is still within an epoch/iteration interval.
Returns:
bool: The flag indicating whether to perform evaluation.
"""
if self.by_epoch:
current = runner.epoch
check_time = self.every_n_epochs
else:
current = runner.iter
check_time = self.every_n_iters
if self.start is None:
if not check_time(runner, self.interval):
# No evaluation during the interval.
return False
elif (current + 1) < self.start:
# No evaluation if start is larger than the current time.
return False
else:
# Evaluation only at epochs/iters 3, 5, 7...
# if start==3 and interval==2
if (current + 1 - self.start) % self.interval:
return False
return True

    def _save_ckpt(self, runner, key_score):
        """Save the best checkpoint.

        It compares the score according to the compare function, writes the
        related information (best score, best checkpoint path) to
        ``runner.meta['hook_msgs']`` and saves the best checkpoint into
        ``self.out_dir`` (``runner.work_dir`` by default).
        """
if self.by_epoch:
current = f'epoch_{runner.epoch + 1}'
cur_type, cur_time = 'epoch', runner.epoch + 1
else:
current = f'iter_{runner.iter + 1}'
cur_type, cur_time = 'iter', runner.iter + 1
best_score = runner.meta['hook_msgs'].get(
'best_score', self.init_value_map[self.rule])
if self.compare_func(key_score, best_score):
best_score = key_score
runner.meta['hook_msgs']['best_score'] = best_score
if self.best_ckpt_path and self.file_client.isfile(
self.best_ckpt_path):
self.file_client.remove(self.best_ckpt_path)
runner.logger.info(
f'The previous best checkpoint {self.best_ckpt_path} was '
'removed')
best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
self.best_ckpt_path = self.file_client.join_path(
self.out_dir, best_ckpt_name)
runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
runner.save_checkpoint(
self.out_dir,
filename_tmpl=best_ckpt_name,
create_symlink=False)
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(
f'Best {self.key_indicator} is {best_score:0.4f} '
f'at {cur_time} {cur_type}.')
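
    # Note on ``self.compare_func`` and ``self.init_value_map`` used above:
    # a minimal sketch of what they are assumed to look like (their actual
    # definitions live elsewhere in the hook and may differ in detail):
    #
    #   rule_map = {'greater': lambda x, y: x > y,
    #               'less': lambda x, y: x < y}
    #   init_value_map = {'greater': -float('inf'), 'less': float('inf')}
    #
    # i.e. with rule='greater' any score beats -inf on the first evaluation,
    # and higher scores replace the stored best afterwards.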

    def evaluate(self, runner, results):
        """Evaluate the results.

        Args:
            runner (:obj:`mmcv.Runner`): The underlying training runner.
            results (list): Output results.
        """
eval_res = self.dataloader.dataset.evaluate(
results, logger=runner.logger, **self.eval_kwargs)
for name, val in eval_res.items():
runner.log_buffer.output[name] = val
runner.log_buffer.ready = True
if self.save_best is not None:
            # If the performance of the model is poor, `eval_res` may be an
            # empty dict, which would raise an exception when `self.save_best`
            # is not None. More details at
            # https://github.com/open-mmlab/mmdetection/issues/6265.
            if not eval_res:
                warnings.warn(
                    'Since `eval_res` is an empty dict, saving the best '
                    'checkpoint will be skipped for this evaluation.')
return None
if self.key_indicator == 'auto':
# infer from eval_results
self._init_rule(self.rule, list(eval_res.keys())[0])
return eval_res[self.key_indicator]
return None
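
# Usage sketch (illustrative only): registering the hook with an epoch-based
# runner. ``build_runner``, ``model``, ``optimizer``, ``work_dir``, ``logger``
# and ``val_loader`` are placeholders, not definitions from this file.
#
#   runner = build_runner(
#       dict(type='EpochBasedRunner', max_epochs=12),
#       default_args=dict(model=model, optimizer=optimizer,
#                         work_dir=work_dir, logger=logger))
#   runner.register_hook(
#       EvalHook(val_loader, interval=1, save_best='bbox_mAP'),
#       priority='LOW')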


class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    This hook will regularly perform evaluation at a given interval when
    running in a distributed environment.

Args:
dataloader (DataLoader): A PyTorch dataloader, whose dataset has
implemented ``evaluate`` function.
start (int | None, optional): Evaluation starting epoch. It enables
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
``interval``. Default: None.
interval (int): Evaluation interval. Default: 1.
        by_epoch (bool): Whether to perform evaluation by epoch or by
            iteration. If set to True, evaluation is performed by epoch;
            otherwise, by iteration. Default: True.
        save_best (str, optional): If a metric is specified, the best
            checkpoint during evaluation will be tracked by this metric. The
            information about the best checkpoint (best score and best
            checkpoint path) is saved in ``runner.meta['hook_msgs']`` and
            will also be loaded when resuming from a checkpoint. Options are
            the evaluation metrics on the test dataset, e.g., ``bbox_mAP``
            and ``segm_mAP`` for bbox detection and instance segmentation, or
            ``AR@100`` for proposal recall. If ``save_best`` is ``auto``, the
            first key of the returned ``OrderedDict`` result will be used.
            Default: None.
        rule (str | None, optional): Comparison rule for the best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc'
            and 'top' will be inferred by the 'greater' rule, while keys
            containing 'loss' will be inferred by the 'less' rule. Options
            are 'greater', 'less', None. Default: None.
        test_fn (callable, optional): Tests the model with samples from a
            dataloader in a multi-gpu manner, and returns the test results.
            If ``None``, the default test function
            ``mmcv.engine.multi_gpu_test`` will be used. Default: None.
        tmpdir (str | None): Temporary directory to save the results of all
            processes. Default: None.
        gpu_collect (bool): Whether to use GPU or CPU to collect results.
            Default: False.
        broadcast_bn_buffer (bool): Whether to broadcast the buffers
            (running_mean and running_var) of rank 0 to other ranks before
            evaluation. Default: True.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, `runner.work_dir` will be used by default. If specified,
the `out_dir` will be the concatenation of `out_dir` and the last
level directory of `runner.work_dir`.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details. Default: None.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""

    def __init__(self,
dataloader: DataLoader,
start: Optional[int] = None,
interval: int = 1,
by_epoch: bool = True,
save_best: Optional[str] = None,
rule: Optional[str] = None,
test_fn: Optional[Callable] = None,
greater_keys: Optional[List[str]] = None,
less_keys: Optional[List[str]] = None,
broadcast_bn_buffer: bool = True,
tmpdir: Optional[str] = None,
gpu_collect: bool = False,
out_dir: Optional[str] = None,
file_client_args: Optional[dict] = None,
**eval_kwargs):
if test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test
super().__init__(
dataloader,
start=start,
interval=interval,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
out_dir=out_dir,
file_client_args=file_client_args,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect

    def _do_evaluate(self, runner):
        """Perform evaluation and save the checkpoint."""
        # Synchronization of BatchNorm's buffers (running_mean and
        # running_var) is not supported by PyTorch's DDP, which may cause
        # inconsistent performance of models on different ranks, so we
        # broadcast the BatchNorm buffers of rank 0 to the other ranks to
        # avoid this.
if self.broadcast_bn_buffer:
model = runner.model
for name, module in model.named_modules():
if isinstance(module,
_BatchNorm) and module.track_running_stats:
dist.broadcast(module.running_var, 0)
dist.broadcast(module.running_mean, 0)
tmpdir = self.tmpdir
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')
results = self.test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
gpu_collect=self.gpu_collect)
if runner.rank == 0:
print('\n')
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
            # `key_score` may be `None`, in which case saving the best
            # checkpoint is skipped.
if self.save_best and key_score:
self._save_ckpt(runner, key_score)
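
# Usage sketch (illustrative only): DistEvalHook in a distributed run.
# ``runner`` and ``val_loader`` are placeholders, and ``multi_gpu_test`` is
# the default ``test_fn`` imported in ``__init__`` above.
#
#   eval_hook = DistEvalHook(
#       val_loader,
#       interval=1,
#       gpu_collect=True,     # collect results on GPU instead of via tmpdir
#       save_best='auto')     # track the first metric returned by evaluate()
#   runner.register_hook(eval_hook, priority='LOW')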