Unverified commit 9185eee8, authored by Zaida Zhou and committed by GitHub

Remove runner, parallel, engine and device (#2216)

* Remove runner, parallel, engine and device

* fix format

* remove outdated docs
parent 19a02415
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Type, Union
import numpy as np
import torch
def assert_tensor_type(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not isinstance(args[0].data, torch.Tensor):
raise AttributeError(
f'{args[0].__class__.__name__} has no attribute '
f'{func.__name__} for type {args[0].datatype}')
return func(*args, **kwargs)
return wrapper
class DataContainer:
"""A container for any type of objects.
Typically tensors will be stacked in the collate function and sliced along
some dimension in the scatter function. This behavior has some limitations.
1. All tensors have to be the same size.
2. Types are limited (numpy array or Tensor).
We design `DataContainer` and `MMDataParallel` to overcome these
limitations. The behavior can be either of the following.
- copy to GPU, pad all tensors to the same size and stack them
- copy to GPU without stacking
- leave the objects as-is and pass them to the model
- pad_dims specifies the number of last few dimensions to do padding
"""
def __init__(self,
data: Union[torch.Tensor, np.ndarray],
stack: bool = False,
padding_value: int = 0,
cpu_only: bool = False,
pad_dims: int = 2):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
self._padding_value = padding_value
assert pad_dims in [None, 1, 2, 3]
self._pad_dims = pad_dims
def __repr__(self) -> str:
return f'{self.__class__.__name__}({repr(self.data)})'
def __len__(self) -> int:
return len(self._data)
@property
def data(self) -> Union[torch.Tensor, np.ndarray]:
return self._data
@property
def datatype(self) -> Union[Type, str]:
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self) -> bool:
return self._cpu_only
@property
def stack(self) -> bool:
return self._stack
@property
def padding_value(self) -> int:
return self._padding_value
@property
def pad_dims(self) -> int:
return self._pad_dims
@assert_tensor_type
def size(self, *args, **kwargs) -> torch.Size:
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self) -> int:
return self.data.dim()
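For context, here is a minimal, hypothetical sketch of how ``DataContainer`` is typically constructed in a dataset pipeline; the ``img`` and ``img_metas`` names are illustrative and not part of this file.

import torch

img = torch.rand(3, 224, 224)
img_metas = dict(filename='demo.jpg', ori_shape=(224, 224, 3))

# Tensors meant to be padded and stacked by the collate function.
img_dc = DataContainer(img, stack=True, padding_value=0, pad_dims=2)
# Arbitrary Python objects that stay on CPU and are passed through as-is.
meta_dc = DataContainer(img_metas, cpu_only=True)

print(img_dc.stack, img_dc.pad_dims)       # True 2
print(meta_dc.cpu_only, meta_dc.datatype)  # True <class 'dict'>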
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain
from typing import List, Tuple
from torch.nn.parallel import DataParallel
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDataParallel(DataParallel):
"""The DataParallel module that supports DataContainer.
MMDataParallel has two main differences from PyTorch DataParallel:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data during both GPU and CPU inference.
- It implements two more APIs ``train_step()`` and ``val_step()``.
.. warning::
MMDataParallel only supports single GPU training; if you need to
train with multiple GPUs, please use MMDistributedDataParallel
instead. If you have multiple GPUs and you just want to use
MMDataParallel, you can set the environment variable
``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
``device_ids=[0]``.
Args:
module (:class:`nn.Module`): Module to be encapsulated.
device_ids (list[int]): Device IDs of modules to be scattered to.
Defaults to None when GPU is not available.
output_device (str | int): Device ID for output. Defaults to None.
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim: int = 0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
def forward(self, *inputs, **kwargs):
"""Override the original forward function.
The main difference lies in CPU inference, where the data in
:class:`DataContainer` will still be gathered.
"""
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module(*inputs[0], **kwargs[0])
else:
return super().forward(*inputs, **kwargs)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.train_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training; if you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.train_step(*inputs[0], **kwargs[0])
def val_step(self, *inputs, **kwargs):
if not self.device_ids:
# We add the following line so that the module can gather and
# convert data containers, as is done in GPU inference
inputs, kwargs = self.scatter(inputs, kwargs, [-1])
return self.module.val_step(*inputs[0], **kwargs[0])
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single GPU training; if you need to'
' train with multiple GPUs, please use MMDistributedDataParallel'
' instead.')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
return self.module.val_step(*inputs[0], **kwargs[0])
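As a rough usage sketch (not part of this file), a model exposing ``train_step``/``val_step`` would be wrapped as below; ``build_my_model`` is a hypothetical factory standing in for whatever constructs the model.

import torch

model = build_my_model()  # hypothetical: any nn.Module with train_step/val_step
if torch.cuda.is_available():
    # Single-GPU training or inference.
    model = MMDataParallel(model.cuda(), device_ids=[0])
else:
    # CPU inference: device_ids is empty, so forward()/train_step() scatter
    # DataContainers with device_ids=[-1] before calling the inner module.
    model = MMDataParallel(model)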
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
"""The DDP module that supports DataContainer.
MMDDP has two main differences from PyTorch DDP:
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data.
- It implements two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_id: int) -> Tuple[tuple, tuple]:
# Use `self.to_kwargs` instead of `self.scatter` in PyTorch 1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
"""Processes inputs and runs ``self.module.forward``.
PyTorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
and deprecates using ``DistributedDataParallel.to_kwargs`` to
process inputs, so inputs can no longer be processed by
:meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
``MMDistributedDataParallel`` overrides this method to call
:meth:`to_kwargs` explicitly.
See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_. # noqa: E501
Returns:
Any: Forward result of :attr:`module`.
"""
module_to_run = self._replicated_tensor_module if \
self._use_replicated_tensor_module else self.module
if self.device_ids:
inputs, kwargs = self.to_kwargs( # type: ignore
inputs, kwargs, self.device_ids[0])
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore
else:
return module_to_run(*inputs, **kwargs)
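A hedged sketch of the usual wrapping, assuming the process group has already been initialized (e.g. via ``init_dist``) and that ``model``, ``data_batch`` and ``optimizer`` come from the surrounding training code; none of these names appear in this file.

import torch

ddp_model = MMDistributedDataParallel(
    model.cuda(),
    device_ids=[torch.cuda.current_device()],
    broadcast_buffers=False)
# data_batch and optimizer are assumed to come from the training loop.
outputs = ddp_model.train_step(data_batch, optimizer)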
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import ScatterInputs, scatter_kwargs
@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
def __init__(self,
module: nn.Module,
dim: int = 0,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25):
super().__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers
self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
self._sync_params()
def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
buffer_size: int) -> None:
for tensors in _take_tensors(tensors, buffer_size):
flat_tensors = _flatten_dense_tensors(tensors)
dist.broadcast(flat_tensors, 0)
for tensor, synced in zip(
tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
tensor.copy_(synced)
def _sync_params(self) -> None:
module_states = list(self.module.state_dict().values())
if len(module_states) > 0:
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
if len(buffers) > 0:
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
return self.module(*inputs[0], **kwargs[0])
def train_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.train_step(*inputs[0], **kwargs[0])
return output
def val_step(self, *inputs, **kwargs):
inputs, kwargs = self.scatter(inputs, kwargs,
[torch.cuda.current_device()])
output = self.module.val_step(*inputs[0], **kwargs[0])
return output
# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmcv.utils import Registry
MODULE_WRAPPERS = Registry('module wrapper')
MODULE_WRAPPERS.register_module(module=DataParallel)
MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
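To make ``is_module_wrapper`` recognize a custom wrapper, it can be registered in the same registry; ``MyWrapper`` below is a purely illustrative name, not part of this commit.

import torch.nn as nn

@MODULE_WRAPPERS.register_module()
class MyWrapper(nn.Module):
    """Hypothetical wrapper that simply delegates to the inner module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)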
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union
from torch import Tensor
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
def scatter(inputs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> list:
"""Scatter inputs to target gpus.
The only difference from the original :func:`scatter` is to add support
for :type:`~mmcv.parallel.DataContainer`.
"""
def scatter_map(obj):
if isinstance(obj, Tensor):
if target_gpus != [-1]:
return OrigScatter.apply(target_gpus, None, dim, obj)
else:
# for CPU inference we use self-implemented scatter
return Scatter.forward(target_gpus, obj)
if isinstance(obj, DataContainer):
if obj.cpu_only:
return obj.data
else:
return Scatter.forward(target_gpus, obj.data)
if isinstance(obj, tuple) and len(obj) > 0:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
out = list(map(list, zip(*map(scatter_map, obj))))
return out
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None # type: ignore
def scatter_kwargs(inputs: ScatterInputs,
kwargs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> Tuple[tuple, tuple]:
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
length = len(kwargs) - len(inputs)
inputs.extend([() for _ in range(length)]) # type: ignore
elif len(kwargs) < len(inputs):
length = len(inputs) - len(kwargs)
kwargs.extend([{} for _ in range(length)]) # type: ignore
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
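A small sketch of what ``scatter_kwargs`` returns, assuming CPU-only execution (target ``[-1]``); the ``return_loss`` key is only an illustrative kwarg.

import torch

inputs = (torch.rand(4, 3),)
kwargs = dict(return_loss=True)
scattered_inputs, scattered_kwargs = scatter_kwargs(inputs, kwargs, [-1])
# One entry per target device: scattered_inputs[0] holds the positional
# args and scattered_kwargs[0] the keyword args for the single CPU target.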
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from .registry import MODULE_WRAPPERS
def is_module_wrapper(module: nn.Module) -> bool:
"""Check if a module is a module wrapper.
The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add your own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its child registries.
Args:
module (nn.Module): The module to be checked.
Returns:
bool: True if the input module is a module wrapper.
"""
def is_module_in_wrapper(module, module_wrapper):
module_wrappers = tuple(module_wrapper.module_dict.values())
if isinstance(module, module_wrappers):
return True
for child in module_wrapper.children.values():
if is_module_in_wrapper(module, child):
return True
return False
return is_module_in_wrapper(module, MODULE_WRAPPERS)
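A common pattern, sketched here under the assumption that ``model`` exists in the calling code, is to unwrap the module before saving its weights.

import torch

def get_bare_model(model):
    # Strip (MM)DataParallel / (MM)DistributedDataParallel wrappers.
    return model.module if is_module_wrapper(model) else model

torch.save(get_bare_model(model).state_dict(), 'model.pth')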
# Copyright (c) OpenMMLab. All rights reserved.
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu)
from .default_constructor import DefaultRunnerConstructor
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
EMAHook, EvalHook, Fp16OptimizerHook,
GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, Hook, IterTimerHook,
LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .hooks.lr_updater import StepLrUpdaterHook # noqa
from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook,
InvLrUpdaterHook, LinearAnnealingLrUpdaterHook,
LrUpdaterHook, OneCycleLrUpdaterHook,
PolyLrUpdaterHook)
from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook,
LinearAnnealingMomentumUpdaterHook,
MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
DefaultOptimizerConstructor, build_optimizer,
build_optimizer_constructor)
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
# initialize ipu to register the ipu runner to RUNNERS
from mmcv.device import ipu # isort:skip # noqa
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook',
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'DefaultRunnerConstructor', 'SegmindLoggerHook',
'LinearAnnealingMomentumUpdaterHook', 'LinearAnnealingLrUpdaterHook',
'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from ..utils import Registry
RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')
def build_runner_constructor(cfg: dict):
return RUNNER_BUILDERS.build(cfg)
def build_runner(cfg: dict, default_args: Optional[dict] = None):
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
runner_constructor = build_runner_constructor(
dict(
type=constructor_type,
runner_cfg=runner_cfg,
default_args=default_args))
runner = runner_constructor()
return runner
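A minimal sketch of building a runner from a config dict; ``model``, ``optimizer`` and ``logger`` are assumed to be constructed elsewhere and are not part of this file.

runner = build_runner(
    dict(type='EpochBasedRunner', max_epochs=12),
    default_args=dict(
        model=model,
        optimizer=optimizer,
        work_dir='./work_dir',
        logger=logger))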
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from .builder import RUNNER_BUILDERS, RUNNERS
@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
"""Default constructor for runners.
Customize an existing `Runner` such as `EpochBasedRunner` through a
`RunnerConstructor`. For example, we can inject new properties and
functions into `Runner`.
Example:
>>> from mmcv.runner import RUNNER_BUILDERS, build_runner
>>> # Define a new RunnerConstructor
>>> @RUNNER_BUILDERS.register_module()
>>> class MyRunnerConstructor:
... def __init__(self, runner_cfg, default_args=None):
... if not isinstance(runner_cfg, dict):
... raise TypeError('runner_cfg should be a dict',
... f'but got {type(runner_cfg)}')
... self.runner_cfg = runner_cfg
... self.default_args = default_args
...
... def __call__(self):
... runner = RUNNERS.build(self.runner_cfg,
... default_args=self.default_args)
... # Add new properties for existing runner
... runner.my_name = 'my_runner'
... runner.my_function = lambda self: print(self.my_name)
... ...
>>> # build your runner
>>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
... constructor='MyRunnerConstructor')
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
self.runner_cfg = runner_cfg
self.default_args = default_args
def __call__(self):
return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import IS_MLU_AVAILABLE
def _find_free_port() -> int:
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(('', 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def _is_free_port(port: int) -> bool:
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend: str, **kwargs) -> None:
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
if IS_MLU_AVAILABLE:
import torch_mlu # noqa: F401
torch.mlu.set_device(rank)
dist.init_process_group(
backend='cncl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs) -> None:
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
if 'MASTER_ADDR' not in os.environ:
raise KeyError('The environment variable MASTER_ADDR is not set')
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
"""Initialize slurm distributed training environment.
If the ``port`` argument is not specified, the master port will be read
from the environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not
set, a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# if the torch.distributed default port (29500) is available,
# use it; otherwise find a free port
if _is_free_port(29500):
os.environ['MASTER_PORT'] = '29500'
else:
os.environ['MASTER_PORT'] = str(_find_free_port())
# use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
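A hedged usage sketch, assuming the script is started with a launcher such as ``torchrun`` that sets ``RANK`` and ``WORLD_SIZE`` in the environment.

init_dist('pytorch', backend='nccl')
rank, world_size = get_dist_info()
print(f'process {rank} of {world_size} initialized')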
def master_only(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
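``master_only`` is typically used to guard side effects such as file writes; a small sketch follows (the function name and arguments are illustrative).

import mmcv

@master_only
def dump_metrics(metrics, path):
    # Only rank 0 executes the body; all other ranks return None.
    mmcv.dump(metrics, path)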
def allreduce_params(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce parameters.
Args:
params (list[torch.nn.Parameter]): List of parameters or buffers
of a model.
coalesce (bool, optional): Whether to allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of the bucket, in MB.
Defaults to -1.
"""
_, world_size = get_dist_info()
if world_size == 1:
return
params = [param.data for param in params]
if coalesce:
_allreduce_coalesced(params, world_size, bucket_size_mb)
else:
for tensor in params:
dist.all_reduce(tensor.div_(world_size))
def allreduce_grads(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce gradients.
Args:
params (list[torch.nn.Parameter]): List of parameters of a model.
coalesce (bool, optional): Whether to allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of the bucket, in MB.
Defaults to -1.
"""
grads = [
param.grad.data for param in params
if param.requires_grad and param.grad is not None
]
_, world_size = get_dist_info()
if world_size == 1:
return
if coalesce:
_allreduce_coalesced(grads, world_size, bucket_size_mb)
else:
for tensor in grads:
dist.all_reduce(tensor.div_(world_size))
def _allreduce_coalesced(tensors: List[torch.Tensor],
world_size: int,
bucket_size_mb: int = -1) -> None:
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
else:
buckets = OrderedDict()
for tensor in tensors:
tp = tensor.type()
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(tensor)
buckets = buckets.values()
for bucket in buckets:
flat_tensors = _flatten_dense_tensors(bucket)
dist.all_reduce(flat_tensors)
flat_tensors.div_(world_size)
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
tensor.copy_(synced)
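For reference, a sketch of averaging gradients manually after ``backward()``, assuming ``loss``, ``model`` and ``optimizer`` come from the surrounding training loop (they are not defined in this file).

loss.backward()
# Sum gradients across ranks and divide by world_size, bucketing tensors
# into 32 MB chunks before the all-reduce.
allreduce_grads(list(model.parameters()), coalesce=True, bucket_size_mb=32)
optimizer.step()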
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .utils import get_host_info
@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
"""Epoch-based Runner.
This runner trains models epoch by epoch.
"""
def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer,
**kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
raise TypeError('"batch_processor()" or "model.train_step()"'
'and "model.val_step()" must return a dict')
if 'log_vars' in outputs:
self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
self.outputs = outputs
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
del self.data_batch
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
del self.data_batch
self.call_hook('after_val_epoch')
def run(self,
data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]],
max_epochs: Optional[int] = None,
**kwargs) -> None:
"""Start running.
Args:
data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
and validation.
workflow (list[tuple]): A list of (phase, epochs) to specify the
running order and epochs. E.g., [('train', 2), ('val', 1)] means
running 2 epochs for training and 1 epoch for validation,
iteratively.
"""
assert isinstance(data_loaders, list)
assert mmcv.is_list_of(workflow, tuple)
assert len(data_loaders) == len(workflow)
if max_epochs is not None:
warnings.warn(
'setting max_epochs in run is deprecated, '
'please set max_epochs in runner_config', DeprecationWarning)
self._max_epochs = max_epochs
assert self._max_epochs is not None, (
'max_epochs must be specified during instantiation')
for i, flow in enumerate(workflow):
mode, epochs = flow
if mode == 'train':
self._max_iters = self._max_epochs * len(data_loaders[i])
break
work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir)
self.logger.info('Hooks will be executed in the following order:\n%s',
self.get_hook_info())
self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run')
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
if isinstance(mode, str): # self.train()
if not hasattr(self, mode):
raise ValueError(
f'runner has no method named "{mode}" to run an '
'epoch')
epoch_runner = getattr(self, mode)
else:
raise TypeError(
'mode in workflow must be a str, but got {}'.format(
type(mode)))
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
epoch_runner(data_loaders[i], **kwargs)
time.sleep(1) # wait for some hooks like loggers to finish
self.call_hook('after_run')
def save_checkpoint(self,
out_dir: str,
filename_tmpl: str = 'epoch_{}.pth',
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
"""Save the checkpoint.
Args:
out_dir (str): The directory that checkpoints are saved.
filename_tmpl (str, optional): The checkpoint filename template,
which contains a placeholder for the epoch number.
Defaults to 'epoch_{}.pth'.
save_optimizer (bool, optional): Whether to save the optimizer to
the checkpoint. Defaults to True.
meta (dict, optional): The meta information to be saved in the
checkpoint. Defaults to None.
create_symlink (bool, optional): Whether to create a symlink
"latest.pth" to point to the latest checkpoint.
Defaults to True.
"""
if meta is None:
meta = {}
elif not isinstance(meta, dict):
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
# Note: meta.update(self.meta) should be done before
# meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
# there will be problems with resumed checkpoints.
# More details in https://github.com/open-mmlab/mmcv/pull/1108
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# in some environments, `os.symlink` is not supported, you may need to
# set `create_symlink` to False
if create_symlink:
dst_file = osp.join(out_dir, 'latest.pth')
if platform.system() != 'Windows':
mmcv.symlink(filename, dst_file)
else:
shutil.copy(filepath, dst_file)
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
"""Deprecated name of EpochBasedRunner."""
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead',
DeprecationWarning)
super().__init__(*args, **kwargs)
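A hedged end-to-end sketch of driving ``EpochBasedRunner``: the hook configs below are illustrative values, and ``runner``, ``train_loader`` and ``val_loader`` are assumed to exist in the calling code.

runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))
# Two training epochs followed by one validation epoch, repeated until
# max_epochs is reached.
runner.run([train_loader, val_loader], [('train', 2), ('val', 1)])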
# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (ClearMLLoggerHook, DvcliveLoggerHook, LoggerHook,
MlflowLoggerHook, NeptuneLoggerHook, PaviLoggerHook,
SegmindLoggerHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook, InvLrUpdaterHook,
LinearAnnealingLrUpdaterHook, LrUpdaterHook,
OneCycleLrUpdaterHook, PolyLrUpdaterHook,
StepLrUpdaterHook)
from .memory import EmptyCacheHook
from .momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook,
LinearAnnealingMomentumUpdaterHook,
MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook
__all__ = [
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'OptimizerHook',
'Fp16OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
'TextLoggerHook', 'TensorboardLoggerHook', 'NeptuneLoggerHook',
'WandbLoggerHook', 'DvcliveLoggerHook', 'MomentumUpdaterHook',
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'SyncBuffersHook', 'EMAHook', 'EvalHook', 'DistEvalHook', 'ProfilerHook',
'GradientCumulativeOptimizerHook', 'GradientCumulativeFp16OptimizerHook',
'SegmindLoggerHook', 'LinearAnnealingLrUpdaterHook',
'LinearAnnealingMomentumUpdaterHook', 'ClearMLLoggerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Optional
from mmengine.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The root directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``.
`Changed in version 1.3.16.`
max_keep_ckpts (int, optional): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save disk space.
Default: -1, which means unlimited.
save_last (bool, optional): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool, optional): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Default: None.
`New in version 1.3.16.`
.. warning::
Before v1.3.16, the ``out_dir`` argument indicates the path where the
checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
root directory and the final path to save checkpoint is the
concatenation of ``out_dir`` and the last level directory of
``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
and the value of ``runner.work_dir`` is "/path/of/B", then the final
path will be "/path/of/A/B".
"""
def __init__(self,
interval: int = -1,
by_epoch: bool = True,
save_optimizer: bool = True,
out_dir: Optional[str] = None,
max_keep_ckpts: int = -1,
save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None,
**kwargs):
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner):
if not self.out_dir:
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.')
# disable the create_symlink option because some file backends do not
# allow creating a symlink
if 'create_symlink' in self.args:
if self.args[
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
'create_symlink is set as True by the user but is changed '
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}')
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner):
if not self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
@master_only
def _save_checkpoint(self, runner):
"""Save the current checkpoint and delete unwanted checkpoint."""
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None:
if self.by_epoch:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else:
cur_ckpt_filename = self.args.get(
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename)
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1
else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = self.file_client.join_path(
self.out_dir, filename_tmpl.format(_step))
if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path)
else:
break
def after_train_iter(self, runner):
if self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
allreduce_params(runner.model.buffers())
self._save_checkpoint(runner)
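A small configuration sketch, assuming a ``runner`` is available in the calling code; the values are illustrative.

checkpoint_hook = CheckpointHook(
    interval=1,          # save every epoch (by_epoch=True)
    by_epoch=True,
    max_keep_ckpts=3,    # keep only the 3 most recent checkpoints
    save_last=True)      # out_dir=None falls back to runner.work_dir
runner.register_hook(checkpoint_hook)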
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable
from .hook import HOOKS, Hook
@HOOKS.register_module()
class ClosureHook(Hook):
def __init__(self, fn_name: str, fn: Callable):
assert hasattr(self, fn_name)
assert callable(fn)
setattr(self, fn_name, fn)
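A sketch of attaching an ad-hoc callback through ``ClosureHook``; the callback name is illustrative and ``runner`` is assumed to exist.

def print_epoch(runner):
    print(f'finished epoch {runner.epoch + 1}')

runner.register_hook(ClosureHook('after_train_epoch', print_epoch))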