"llm/vscode:/vscode.git/clone" did not exist on "dc18eee39d8db35e6cbbc416a39ecbbda68fa962"
Commit fdeee889 authored by limm's avatar limm
Browse files

release v1.6.1 of mmcv

parent df465820
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from torch import Tensor
from torch.nn.parallel._functions import _get_stream
def scatter(input, devices, streams=None):
def scatter(input: Union[List, Tensor],
devices: List,
streams: Optional[List] = None) -> Union[List, Tensor]:
"""Scatters tensor across multiple GPUs."""
if streams is None:
streams = [None] * len(devices)
......@@ -15,30 +20,28 @@ def scatter(input, devices, streams=None):
[streams[i // chunk_size]]) for i in range(len(input))
]
return outputs
elif isinstance(input, torch.Tensor):
elif isinstance(input, Tensor):
output = input.contiguous()
# TODO: copy to a pinned buffer first (if copying from CPU)
stream = streams[0] if output.numel() > 0 else None
if devices != [-1]:
with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
output = output.cuda(devices[0], non_blocking=True)
else:
# unsqueeze the first dimension so that the tensor's shape is the
# same as those scattered to GPU.
output = output.unsqueeze(0)
return output
else:
raise Exception(f'Unknown type {type(input)}.')
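# A minimal usage sketch of `scatter`, assuming mmcv is installed; the tensor
# below is illustrative only:
#
#     >>> import torch
#     >>> from mmcv.parallel._functions import scatter
#     >>> x = torch.randn(4, 3)
#     >>> scatter(x, [-1]).shape    # CPU path: the tensor is unsqueezed
#     torch.Size([1, 4, 3])
#     >>> if torch.cuda.is_available():
#     ...     assert scatter(x, [0]).device.type == 'cuda'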
def synchronize_stream(output, devices, streams):
def synchronize_stream(output: Union[List, Tensor], devices: List,
streams: List) -> None:
if isinstance(output, list):
chunk_size = len(output) // len(devices)
for i in range(len(devices)):
for j in range(chunk_size):
synchronize_stream(output[i * chunk_size + j], [devices[i]],
[streams[i]])
elif isinstance(output, torch.Tensor):
elif isinstance(output, Tensor):
if output.numel() != 0:
with torch.cuda.device(devices[0]):
main_stream = torch.cuda.current_stream()
......@@ -48,14 +51,14 @@ def synchronize_stream(output, devices, streams):
raise Exception(f'Unknown type {type(output)}.')
def get_input_device(input):
def get_input_device(input: Union[List, Tensor]) -> int:
if isinstance(input, list):
for item in input:
input_device = get_input_device(item)
if input_device != -1:
return input_device
return -1
elif isinstance(input, torch.Tensor):
elif isinstance(input, Tensor):
return input.get_device() if input.is_cuda else -1
else:
raise Exception(f'Unknown type {type(input)}.')
......@@ -64,7 +67,7 @@ def get_input_device(input):
class Scatter:
@staticmethod
def forward(target_gpus, input):
def forward(target_gpus: List[int], input: Union[List, Tensor]) -> tuple:
input_device = get_input_device(input)
streams = None
if input_device == -1 and target_gpus != [-1]:
......@@ -76,4 +79,4 @@ class Scatter:
if streams is not None:
synchronize_stream(outputs, target_gpus, streams)
return tuple(outputs)
return tuple(outputs) if isinstance(outputs, list) else (outputs, )
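# A small sketch of the updated return contract, assuming the elided lines
# above still build `outputs` via `scatter` as in earlier releases: a plain
# tensor scattered to the CPU pseudo-device now comes back wrapped in a
# 1-tuple instead of being returned bare.
#
#     >>> import torch
#     >>> from mmcv.parallel._functions import Scatter
#     >>> outputs = Scatter.forward([-1], torch.randn(2, 2))
#     >>> isinstance(outputs, tuple) and len(outputs) == 1
#     True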
......@@ -8,7 +8,7 @@ from torch.utils.data.dataloader import default_collate
from .data_container import DataContainer
def collate(batch, samples_per_gpu=1):
def collate(batch: Sequence, samples_per_gpu: int = 1):
"""Puts each data field into a tensor/DataContainer with outer dimension
batch size.
......
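# A hedged usage sketch (the body of `collate` is elided above): it is normally
# passed to a DataLoader as its collate_fn, with `samples_per_gpu` matching the
# per-GPU batch size so that DataContainer fields are grouped per GPU. The
# `dataset` below is a hypothetical torch Dataset yielding dicts of tensors or
# DataContainer objects.
#
#     >>> from functools import partial
#     >>> from torch.utils.data import DataLoader
#     >>> from mmcv.parallel import collate
#     >>> loader = DataLoader(dataset, batch_size=2,
#     ...                     collate_fn=partial(collate, samples_per_gpu=2))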
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Type, Union
import numpy as np
import torch
def assert_tensor_type(func):
def assert_tensor_type(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
......@@ -35,11 +37,11 @@ class DataContainer:
"""
def __init__(self,
data,
stack=False,
padding_value=0,
cpu_only=False,
pad_dims=2):
data: Union[torch.Tensor, np.ndarray],
stack: bool = False,
padding_value: int = 0,
cpu_only: bool = False,
pad_dims: int = 2):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
......@@ -47,43 +49,43 @@ class DataContainer:
assert pad_dims in [None, 1, 2, 3]
self._pad_dims = pad_dims
def __repr__(self):
def __repr__(self) -> str:
return f'{self.__class__.__name__}({repr(self.data)})'
def __len__(self):
def __len__(self) -> int:
return len(self._data)
@property
def data(self):
def data(self) -> Union[torch.Tensor, np.ndarray]:
return self._data
@property
def datatype(self):
def datatype(self) -> Union[Type, str]:
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self):
def cpu_only(self) -> bool:
return self._cpu_only
@property
def stack(self):
def stack(self) -> bool:
return self._stack
@property
def padding_value(self):
def padding_value(self) -> int:
return self._padding_value
@property
def pad_dims(self):
def pad_dims(self) -> int:
return self._pad_dims
@assert_tensor_type
def size(self, *args, **kwargs):
def size(self, *args, **kwargs) -> torch.Size:
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self):
def dim(self) -> int:
return self.data.dim()
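# A short usage sketch of the typed accessors, assuming mmcv is installed:
#
#     >>> import torch
#     >>> from mmcv.parallel import DataContainer
#     >>> dc = DataContainer(torch.zeros(2, 3), stack=True)
#     >>> dc.stack, dc.cpu_only, dc.pad_dims
#     (True, False, 2)
#     >>> dc.size()    # valid only because the wrapped data is a Tensor
#     torch.Size([2, 3])
#     >>> dc.dim()
#     2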
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain
from typing import List, Tuple
from torch.nn.parallel import DataParallel
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDataParallel(DataParallel):
......@@ -13,7 +14,7 @@ class MMDataParallel(DataParallel):
- It supports a custom type :class:`DataContainer` which allows more
flexible control of input data during both GPU and CPU inference.
- It implement two more APIs ``train_step()`` and ``val_step()``.
- It implements two more APIs ``train_step()`` and ``val_step()``.
.. warning::
MMDataParallel only supports single GPU training. If you need to
......@@ -31,8 +32,8 @@ class MMDataParallel(DataParallel):
dim (int): Dimension used to scatter the data. Defaults to 0.
"""
def __init__(self, *args, dim=0, **kwargs):
super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
def __init__(self, *args, dim: int = 0, **kwargs):
super().__init__(*args, dim=dim, **kwargs)
self.dim = dim
def forward(self, *inputs, **kwargs):
......@@ -49,7 +50,8 @@ class MMDataParallel(DataParallel):
else:
return super().forward(*inputs, **kwargs)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
......
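# A minimal single-GPU usage sketch, assuming mmcv is installed; `model`,
# `data_batch` and `optimizer` below are hypothetical:
#
#     >>> from mmcv.parallel import MMDataParallel
#     >>> model = MMDataParallel(model.cuda(), device_ids=[0])
#     >>> results = model(**data_batch)                       # forward/inference
#     >>> outputs = model.train_step(data_batch, optimizer)   # training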
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
class MMDistributedDataParallel(DistributedDataParallel):
......@@ -18,12 +20,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
- It implements two APIs ``train_step()`` and ``val_step()``.
"""
def to_kwargs(self, inputs, kwargs, device_id):
def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_id: int) -> Tuple[tuple, tuple]:
# Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
# to move all tensors to device_id
return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def train_step(self, *inputs, **kwargs):
......@@ -44,8 +48,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
......@@ -57,8 +68,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
else:
output = self.module.train_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
......@@ -86,8 +103,15 @@ class MMDistributedDataParallel(DistributedDataParallel):
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_pre_fwd():
self._sync_buffers()
else:
if (getattr(self, 'require_forward_param_sync', False)
and self.require_forward_param_sync):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
......@@ -99,8 +123,14 @@ class MMDistributedDataParallel(DistributedDataParallel):
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')):
if self._check_sync_bufs_post_fwd():
self._sync_buffers()
if (torch.is_grad_enabled()
and getattr(self, 'require_backward_grad_sync', False)
and self.require_backward_grad_sync):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
......@@ -110,3 +140,28 @@ class MMDistributedDataParallel(DistributedDataParallel):
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
"""Processes inputs and runs ``self.module.forward``.
Pytorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
and deprecates using ``DistributedDataParallel.to_kwargs`` to
process inputs, which means inputs can no longer be processed by
:meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
``MMDistributedDataParallel`` overrides this method to call
:meth:`to_kwargs` explicitly.
See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_. # noqa: E501
Returns:
Any: Forward result of :attr:`module`.
"""
module_to_run = self._replicated_tensor_module if \
self._use_replicated_tensor_module else self.module
if self.device_ids:
inputs, kwargs = self.to_kwargs( # type: ignore
inputs, kwargs, self.device_ids[0])
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore
else:
return module_to_run(*inputs, **kwargs)
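# The version gate used in `train_step`/`val_step` above can be evaluated on
# its own; a small sketch assuming mmcv is installed:
#
#     >>> from mmcv.utils import TORCH_VERSION, digit_version
#     >>> use_new_sync = ('parrots' not in TORCH_VERSION and
#     ...                 digit_version(TORCH_VERSION) >= digit_version('1.11.0a0'))
#     >>> # True on PyTorch >= 1.11, where buffer syncing goes through
#     >>> # _check_sync_bufs_pre_fwd()/_sync_buffers() instead of _sync_params().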
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
......@@ -7,18 +9,18 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs
from .scatter_gather import ScatterInputs, scatter_kwargs
@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
def __init__(self,
module,
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25):
super(MMDistributedDataParallel, self).__init__()
module: nn.Module,
dim: int = 0,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25):
super().__init__()
self.module = module
self.dim = dim
self.broadcast_buffers = broadcast_buffers
......@@ -26,7 +28,8 @@ class MMDistributedDataParallel(nn.Module):
self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
self._sync_params()
def _dist_broadcast_coalesced(self, tensors, buffer_size):
def _dist_broadcast_coalesced(self, tensors: Sequence[torch.Tensor],
buffer_size: int) -> None:
for tensors in _take_tensors(tensors, buffer_size):
flat_tensors = _flatten_dense_tensors(tensors)
dist.broadcast(flat_tensors, 0)
......@@ -34,7 +37,7 @@ class MMDistributedDataParallel(nn.Module):
tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
tensor.copy_(synced)
def _sync_params(self):
def _sync_params(self) -> None:
module_states = list(self.module.state_dict().values())
if len(module_states) > 0:
self._dist_broadcast_coalesced(module_states,
......@@ -49,7 +52,8 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(buffers,
self.broadcast_bucket_size)
def scatter(self, inputs, kwargs, device_ids):
def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs,
device_ids: List[int]) -> Tuple[tuple, tuple]:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def forward(self, *inputs, **kwargs):
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from typing import List, Tuple, Union
from torch import Tensor
from torch.nn.parallel._functions import Scatter as OrigScatter
from ._functions import Scatter
from .data_container import DataContainer
ScatterInputs = Union[Tensor, DataContainer, tuple, list, dict]
def scatter(inputs, target_gpus, dim=0):
def scatter(inputs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> list:
"""Scatter inputs to target gpus.
The only difference from original :func:`scatter` is to add support for
......@@ -14,7 +20,7 @@ def scatter(inputs, target_gpus, dim=0):
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
if isinstance(obj, Tensor):
if target_gpus != [-1]:
return OrigScatter.apply(target_gpus, None, dim, obj)
else:
......@@ -33,7 +39,7 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, dict) and len(obj) > 0:
out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return out
return [obj for targets in target_gpus]
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
......@@ -43,17 +49,22 @@ def scatter(inputs, target_gpus, dim=0):
try:
return scatter_map(inputs)
finally:
scatter_map = None
scatter_map = None # type: ignore
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
def scatter_kwargs(inputs: ScatterInputs,
kwargs: ScatterInputs,
target_gpus: List[int],
dim: int = 0) -> Tuple[tuple, tuple]:
"""Scatter with support for kwargs dictionary."""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
length = len(kwargs) - len(inputs)
inputs.extend([() for _ in range(length)]) # type: ignore
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
length = len(inputs) - len(kwargs)
kwargs.extend([{} for _ in range(length)]) # type: ignore
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs
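# A small sketch of the padding behaviour, assuming the elided branches of
# `scatter_map` behave as in earlier mmcv releases: when only kwargs are given,
# `inputs` is padded with one empty tuple per target (here the CPU
# pseudo-device -1) so both returned tuples have the same length.
#
#     >>> import torch
#     >>> from mmcv.parallel.scatter_gather import scatter_kwargs
#     >>> inputs, kwargs = scatter_kwargs((), dict(x=torch.ones(2)), [-1])
#     >>> inputs
#     ((),)
#     >>> list(kwargs[0])
#     ['x']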
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn
from .registry import MODULE_WRAPPERS
def is_module_wrapper(module):
def is_module_wrapper(module: nn.Module) -> bool:
"""Check if a module is a module wrapper.
The following 3 modules in MMCV (and their subclasses) are regarded as
module wrappers: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version). You may add your own
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS or
its child registries.
Args:
module (nn.Module): The module to be checked.
......@@ -16,5 +19,14 @@ def is_module_wrapper(module):
Returns:
bool: True if the input module is a module wrapper.
"""
module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
return isinstance(module, module_wrappers)
def is_module_in_wrapper(module, module_wrapper):
module_wrappers = tuple(module_wrapper.module_dict.values())
if isinstance(module, module_wrappers):
return True
for child in module_wrapper.children.values():
if is_module_in_wrapper(module, child):
return True
return False
return is_module_in_wrapper(module, MODULE_WRAPPERS)
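# A hedged sketch of registering a custom wrapper so that it is recognised,
# assuming mmcv is installed; `MyWrapper` is hypothetical:
#
#     >>> import torch.nn as nn
#     >>> from mmcv.parallel import MODULE_WRAPPERS, is_module_wrapper
#     >>> @MODULE_WRAPPERS.register_module()
#     ... class MyWrapper(nn.Module):
#     ...     def __init__(self, module):
#     ...         super().__init__()
#     ...         self.module = module
#     >>> is_module_wrapper(MyWrapper(nn.Linear(2, 2)))
#     True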
# Copyright (c) OpenMMLab. All rights reserved.
from .base_module import BaseModule, ModuleList, Sequential
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .base_runner import BaseRunner
from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
......@@ -10,14 +10,29 @@ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
from .hooks import (HOOKS, CheckpointHook, ClearMLLoggerHook, ClosureHook,
DistEvalHook, DistSamplerSeedHook, DvcliveLoggerHook,
EMAHook, EvalHook, Fp16OptimizerHook,
GradientCumulativeFp16OptimizerHook,
GradientCumulativeOptimizerHook, Hook, IterTimerHook,
LoggerHook, LrUpdaterHook, MlflowLoggerHook,
NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
LoggerHook, MlflowLoggerHook, NeptuneLoggerHook,
OptimizerHook, PaviLoggerHook, SegmindLoggerHook,
SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
WandbLoggerHook)
from .hooks.lr_updater import StepLrUpdaterHook # noqa
from .hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook, CyclicLrUpdaterHook,
ExpLrUpdaterHook, FixedLrUpdaterHook,
FlatCosineAnnealingLrUpdaterHook,
InvLrUpdaterHook, LinearAnnealingLrUpdaterHook,
LrUpdaterHook, OneCycleLrUpdaterHook,
PolyLrUpdaterHook)
from .hooks.momentum_updater import (CosineAnnealingMomentumUpdaterHook,
CyclicMomentumUpdaterHook,
LinearAnnealingMomentumUpdaterHook,
MomentumUpdaterHook,
OneCycleMomentumUpdaterHook,
StepMomentumUpdaterHook)
from .iter_based_runner import IterBasedRunner, IterLoader
from .log_buffer import LogBuffer
from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
......@@ -26,9 +41,18 @@ from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
from .priority import Priority, get_priority
from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
# initialize ipu to register the ipu runner to RUNNERS
from mmcv.device import ipu # isort:skip # noqa
__all__ = [
'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
'FixedLrUpdaterHook', 'StepLrUpdaterHook', 'ExpLrUpdaterHook',
'PolyLrUpdaterHook', 'InvLrUpdaterHook', 'CosineAnnealingLrUpdaterHook',
'FlatCosineAnnealingLrUpdaterHook', 'CosineRestartLrUpdaterHook',
'CyclicLrUpdaterHook', 'OneCycleLrUpdaterHook', 'MomentumUpdaterHook',
'StepMomentumUpdaterHook', 'CosineAnnealingMomentumUpdaterHook',
'CyclicMomentumUpdaterHook', 'OneCycleMomentumUpdaterHook',
'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
......@@ -42,6 +66,8 @@ __all__ = [
'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
'ModuleDict', 'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor',
'SegmindLoggerHook', 'LinearAnnealingMomentumUpdaterHook',
'LinearAnnealingLrUpdaterHook', 'ClearMLLoggerHook'
]
......@@ -4,6 +4,7 @@ import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional
import torch.nn as nn
......@@ -18,25 +19,24 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
functionality of parameter initialization. Compared with
``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
- ``init_cfg``: the config to control the initialization.
- ``init_weights``: The function of parameter
initialization and recording initialization
information.
- ``_params_init_info``: Used to track the parameter
initialization information. This attribute only
exists during executing the ``init_weights``.
- ``init_cfg``: the config to control the initialization.
- ``init_weights``: The function of parameter initialization and recording
initialization information.
- ``_params_init_info``: Used to track the parameter initialization
information. This attribute only exists during executing the
``init_weights``.
Args:
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg=None):
def __init__(self, init_cfg: Optional[dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
# in low levels has a higher priority.
super(BaseModule, self).__init__()
super().__init__()
# define the default value of init_cfg instead of hard-coding it
# in the init_weights() function
self._is_init = False
......@@ -50,10 +50,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@property
def is_init(self):
def is_init(self) -> bool:
return self._is_init
def init_weights(self):
def init_weights(self) -> None:
"""Initialize the weights."""
is_top_level_module = False
......@@ -68,7 +68,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# which indicates whether the parameter has been modified.
# this attribute will be deleted after all parameters
# are initialized.
self._params_init_info = defaultdict(dict)
self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True
# Initialize the `_params_init_info`,
......@@ -134,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name):
def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
......@@ -177,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, *args, init_cfg=None):
def __init__(self, *args, init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)
......@@ -190,6 +190,24 @@ class ModuleList(BaseModule, nn.ModuleList):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)
class ModuleDict(BaseModule, nn.ModuleDict):
"""ModuleDict in openmmlab.
Args:
modules (dict, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module).
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self,
modules: Optional[dict] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
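# A short sketch of the shared init_cfg mechanism, assuming the updated
# `mmcv.runner` exports (BaseModule, ModuleDict); `TinyNet` is hypothetical:
#
#     >>> import torch.nn as nn
#     >>> from mmcv.runner import BaseModule, ModuleDict
#     >>> class TinyNet(BaseModule):
#     ...     def __init__(self):
#     ...         super().__init__(init_cfg=dict(type='Xavier', layer='Linear'))
#     ...         self.fc = nn.Linear(4, 2)
#     >>> net = TinyNet()
#     >>> net.init_weights()    # applies the Xavier init and records it
#     >>> heads = ModuleDict(dict(cls_head=nn.Linear(2, 3)),
#     ...                    init_cfg=dict(type='Constant', val=0, layer='Linear'))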
......@@ -4,9 +4,13 @@ import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
no_type_check)
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import mmcv
from ..parallel import is_module_wrapper
......@@ -49,20 +53,22 @@ class BaseRunner(metaclass=ABCMeta):
"""
def __init__(self,
model,
batch_processor=None,
optimizer=None,
work_dir=None,
logger=None,
meta=None,
max_iters=None,
max_epochs=None):
model: torch.nn.Module,
batch_processor: Optional[Callable] = None,
optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
work_dir: Optional[str] = None,
logger: Optional[logging.Logger] = None,
meta: Optional[Dict] = None,
max_iters: Optional[int] = None,
max_epochs: Optional[int] = None) -> None:
if batch_processor is not None:
if not callable(batch_processor):
raise TypeError('batch_processor must be callable, '
f'but got {type(batch_processor)}')
warnings.warn('batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.')
warnings.warn(
'batch_processor is deprecated, please implement '
'train_step() and val_step() in the model instead.',
DeprecationWarning)
# raise an error if `batch_processor` is not None and
# `model.train_step()` exists.
if is_module_wrapper(model):
......@@ -104,8 +110,8 @@ class BaseRunner(metaclass=ABCMeta):
self.logger = logger
self.meta = meta
# create work_dir
if mmcv.is_str(work_dir):
self.work_dir = osp.abspath(work_dir)
if isinstance(work_dir, str):
self.work_dir: Optional[str] = osp.abspath(work_dir)
mmcv.mkdir_or_exist(self.work_dir)
elif work_dir is None:
self.work_dir = None
......@@ -120,8 +126,8 @@ class BaseRunner(metaclass=ABCMeta):
self._rank, self._world_size = get_dist_info()
self.timestamp = get_time_str()
self.mode = None
self._hooks = []
self.mode: Optional[str] = None
self._hooks: List[Hook] = []
self._epoch = 0
self._iter = 0
self._inner_iter = 0
......@@ -136,38 +142,38 @@ class BaseRunner(metaclass=ABCMeta):
self.log_buffer = LogBuffer()
@property
def model_name(self):
def model_name(self) -> str:
"""str: Name of the model, usually the module class name."""
return self._model_name
@property
def rank(self):
def rank(self) -> int:
"""int: Rank of current process. (distributed training)"""
return self._rank
@property
def world_size(self):
def world_size(self) -> int:
"""int: Number of processes participating in the job.
(distributed training)"""
return self._world_size
@property
def hooks(self):
def hooks(self) -> List[Hook]:
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
def epoch(self):
def epoch(self) -> int:
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
def iter(self) -> int:
"""int: Current iteration."""
return self._iter
@property
def inner_iter(self):
def inner_iter(self) -> int:
"""int: Iteration in an epoch."""
return self._inner_iter
......@@ -190,26 +196,28 @@ class BaseRunner(metaclass=ABCMeta):
pass
@abstractmethod
def run(self, data_loaders, workflow, **kwargs):
def run(self, data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]], **kwargs) -> Any:
pass
@abstractmethod
def save_checkpoint(self,
out_dir,
filename_tmpl,
save_optimizer=True,
meta=None,
create_symlink=True):
out_dir: str,
filename_tmpl: str,
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
pass
def current_lr(self):
def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current learning rates.
Returns:
list[float] | dict[str, list[float]]: Current learning rates of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
lr: Union[List[float], Dict[str, List[float]]]
if isinstance(self.optimizer, torch.optim.Optimizer):
lr = [group['lr'] for group in self.optimizer.param_groups]
elif isinstance(self.optimizer, dict):
......@@ -221,13 +229,13 @@ class BaseRunner(metaclass=ABCMeta):
'lr is not applicable because optimizer does not exist.')
return lr
def current_momentum(self):
def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
"""Get current momentums.
Returns:
list[float] | dict[str, list[float]]: Current momentums of all
param groups. If the runner has a dict of optimizers, this
method will return a dict.
param groups. If the runner has a dict of optimizers, this method
will return a dict.
"""
def _get_momentum(optimizer):
......@@ -252,7 +260,9 @@ class BaseRunner(metaclass=ABCMeta):
momentums[name] = _get_momentum(optim)
return momentums
def register_hook(self, hook, priority='NORMAL'):
def register_hook(self,
hook: Hook,
priority: Union[int, str, Priority] = 'NORMAL') -> None:
"""Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified
......@@ -269,25 +279,25 @@ class BaseRunner(metaclass=ABCMeta):
if hasattr(hook, 'priority'):
raise ValueError('"priority" is a reserved attribute for hooks')
priority = get_priority(priority)
hook.priority = priority
hook.priority = priority # type: ignore
# insert the hook to a sorted list
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
if priority >= self._hooks[i].priority: # type: ignore
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
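# A small sketch of the resulting ordering, assuming mmcv is installed;
# `runner` is a hypothetical, already-constructed runner instance:
#
#     >>> from mmcv.runner import Hook
#     >>> class A(Hook):
#     ...     pass
#     >>> class B(Hook):
#     ...     pass
#     >>> runner.register_hook(A(), priority='LOW')        # numeric value 70
#     >>> runner.register_hook(B(), priority='VERY_HIGH')  # numeric value 10
#     >>> [h.__class__.__name__ for h in runner.hooks]     # smaller value first
#     ['B', 'A']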
def register_hook_from_cfg(self, hook_cfg):
def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
"""Register a hook from its cfg.
Args:
hook_cfg (dict): Hook config. It should have at least keys 'type'
and 'priority' indicating its type and priority.
Notes:
Note:
The specific hook class to register should not use 'type' and
'priority' arguments during initialization.
"""
......@@ -296,7 +306,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
self.register_hook(hook, priority=priority)
def call_hook(self, fn_name):
def call_hook(self, fn_name: str) -> None:
"""Call all hooks.
Args:
......@@ -306,14 +316,14 @@ class BaseRunner(metaclass=ABCMeta):
for hook in self._hooks:
getattr(hook, fn_name)(self)
def get_hook_info(self):
def get_hook_info(self) -> str:
# Get hooks info in each stage
stage_hook_map = {stage: [] for stage in Hook.stages}
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
......@@ -329,11 +339,13 @@ class BaseRunner(metaclass=ABCMeta):
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
def load_checkpoint(self,
filename,
map_location='cpu',
strict=False,
revise_keys=[(r'^module.', '')]):
def load_checkpoint(
self,
filename: str,
map_location: Union[str, Callable] = 'cpu',
strict: bool = False,
revise_keys: List = [(r'^module.', '')],
) -> Union[Dict, OrderedDict]:
return load_checkpoint(
self.model,
filename,
......@@ -342,10 +354,11 @@ class BaseRunner(metaclass=ABCMeta):
self.logger,
revise_keys=revise_keys)
@no_type_check
def resume(self,
checkpoint,
resume_optimizer=True,
map_location='default'):
checkpoint: str,
resume_optimizer: bool = True,
map_location: Union[str, Callable] = 'default') -> None:
if map_location == 'default':
if torch.cuda.is_available():
device_id = torch.cuda.current_device()
......@@ -396,7 +409,7 @@ class BaseRunner(metaclass=ABCMeta):
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
def register_lr_hook(self, lr_config):
def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
if lr_config is None:
return
elif isinstance(lr_config, dict):
......@@ -417,7 +430,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = lr_config
self.register_hook(hook, priority='VERY_HIGH')
def register_momentum_hook(self, momentum_config):
def register_momentum_hook(
self, momentum_config: Union[Dict, Hook, None]) -> None:
if momentum_config is None:
return
if isinstance(momentum_config, dict):
......@@ -438,7 +452,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = momentum_config
self.register_hook(hook, priority='HIGH')
def register_optimizer_hook(self, optimizer_config):
def register_optimizer_hook(
self, optimizer_config: Union[Dict, Hook, None]) -> None:
if optimizer_config is None:
return
if isinstance(optimizer_config, dict):
......@@ -448,7 +463,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = optimizer_config
self.register_hook(hook, priority='ABOVE_NORMAL')
def register_checkpoint_hook(self, checkpoint_config):
def register_checkpoint_hook(
self, checkpoint_config: Union[Dict, Hook, None]) -> None:
if checkpoint_config is None:
return
if isinstance(checkpoint_config, dict):
......@@ -458,7 +474,7 @@ class BaseRunner(metaclass=ABCMeta):
hook = checkpoint_config
self.register_hook(hook, priority='NORMAL')
def register_logger_hooks(self, log_config):
def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
if log_config is None:
return
log_interval = log_config['interval']
......@@ -467,7 +483,10 @@ class BaseRunner(metaclass=ABCMeta):
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
def register_timer_hook(
self,
timer_config: Union[Dict, Hook, None],
) -> None:
if timer_config is None:
return
if isinstance(timer_config, dict):
......@@ -477,7 +496,8 @@ class BaseRunner(metaclass=ABCMeta):
hook = timer_config
self.register_hook(hook, priority='LOW')
def register_custom_hooks(self, custom_config):
def register_custom_hooks(
self, custom_config: Union[List, Dict, Hook, None]) -> None:
if custom_config is None:
return
......@@ -490,7 +510,10 @@ class BaseRunner(metaclass=ABCMeta):
else:
self.register_hook(item, priority='NORMAL')
def register_profiler_hook(self, profiler_config):
def register_profiler_hook(
self,
profiler_config: Union[Dict, Hook, None],
) -> None:
if profiler_config is None:
return
if isinstance(profiler_config, dict):
......@@ -500,14 +523,15 @@ class BaseRunner(metaclass=ABCMeta):
hook = profiler_config
self.register_hook(hook)
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
def register_training_hooks(
self,
lr_config: Union[Dict, Hook, None],
optimizer_config: Union[Dict, Hook, None] = None,
checkpoint_config: Union[Dict, Hook, None] = None,
log_config: Optional[Dict] = None,
momentum_config: Union[Dict, Hook, None] = None,
timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
"""Register default and custom hooks for training.
Default and custom hooks include:
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional
from ..utils import Registry
......@@ -7,11 +8,11 @@ RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')
def build_runner_constructor(cfg):
def build_runner_constructor(cfg: dict):
return RUNNER_BUILDERS.build(cfg)
def build_runner(cfg, default_args=None):
def build_runner(cfg: dict, default_args: Optional[dict] = None):
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
......
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
......@@ -9,8 +10,10 @@ import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchvision
from torch.optim import Optimizer
......@@ -18,7 +21,7 @@ import mmcv
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
from ..utils import load_url, mkdir_or_exist
from ..utils import digit_version, load_url, mkdir_or_exist
from .dist_utils import get_dist_info
ENV_MMCV_HOME = 'MMCV_HOME'
......@@ -26,7 +29,7 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home():
def _get_mmcv_home() -> str:
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
......@@ -37,7 +40,10 @@ def _get_mmcv_home():
return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
def load_state_dict(module: nn.Module,
state_dict: Union[dict, OrderedDict],
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
......@@ -46,21 +52,21 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
Args:
module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights.
state_dict (dict or OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
unexpected_keys: List[str] = []
all_missing_keys: List[str] = []
err_msg: List[str] = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
state_dict = state_dict.copy() # type: ignore
if metadata is not None:
state_dict._metadata = metadata
state_dict._metadata = metadata # type: ignore
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
......@@ -78,7 +84,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
load(child, prefix + name + '.')
load(module)
load = None # break load->load reference cycle
# break load->load reference cycle
load = None # type: ignore
# ignore "num_batches_tracked" of BN layers
missing_keys = [
......@@ -96,7 +103,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
err_msg = '\n'.join(err_msg) # type: ignore
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
......@@ -106,14 +113,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.models.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
else:
# Since torchvision bumped to v0.13, the weight loading logic,
# model keys and model urls have changed. Here the URLs of the old
# version are loaded to avoid breaking backward compatibility. If the
# torchvision version >= 0.13.0, new URLs will be added as well. Users can
# get the resnet50 checkpoint by setting 'resnet50.imagenet1k_v1',
# 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json')
model_urls = mmcv.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items():
# The names of torchvision model weight classes end with
# `_Weights`, such as `ResNet18_Weights`. However, some weight
# classes, such as `MNASNet0_75_Weights`, do not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check the
# `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights')
or not hasattr(cls, 'DEFAULT')):
continue
# Since `cls.DEFAULT` cannot be accessed by iterating over `cls`, we set
# the default urls explicitly.
cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
cls_key = cls_name.replace('_Weights', '').lower()
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
model_urls[cls_key] = weight_enum.url
return model_urls
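# A hedged sketch of the resulting key scheme; the exact keys depend on the
# installed torchvision version:
#
#     >>> from mmcv.runner.checkpoint import get_torchvision_models
#     >>> urls = get_torchvision_models()
#     >>> 'resnet50' in urls    # old-style key, available in both branches
#     True
#     >>> # With torchvision >= 0.13, enum-derived keys such as
#     >>> # 'resnet50.imagenet1k_v1' and 'resnet50.default' are also present.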
......@@ -147,7 +188,7 @@ def get_deprecated_model_names():
return deprecate_urls
def _process_mmcls_checkpoint(checkpoint):
def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
......@@ -166,10 +207,13 @@ def _process_mmcls_checkpoint(checkpoint):
class CheckpointLoader:
"""A general checkpoint loader to manage all schemes."""
_schemes = {}
_schemes: dict = {}
@classmethod
def _register_scheme(cls, prefixes, loader, force=False):
def _register_scheme(cls,
prefixes: Union[str, List, Tuple],
loader: Callable,
force: bool = False) -> None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
......@@ -186,13 +230,16 @@ class CheckpointLoader:
sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
@classmethod
def register_scheme(cls, prefixes, loader=None, force=False):
def register_scheme(cls,
prefixes: Union[str, List[str], Tuple[str, ...]],
loader: Optional[Callable] = None,
force: bool = False) -> Callable:
"""Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator.
Args:
prefixes (str or list[str] or tuple[str]):
prefixes (str or Sequence[str]):
The prefix of the registered loader.
loader (function, optional): The loader function to be registered.
When this method is used as a decorator, loader is None.
......@@ -203,7 +250,7 @@ class CheckpointLoader:
if loader is not None:
cls._register_scheme(prefixes, loader, force=force)
return
return # type: ignore
def _register(loader_cls):
cls._register_scheme(prefixes, loader_cls, force=force)
......@@ -212,7 +259,7 @@ class CheckpointLoader:
return _register
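# A hedged sketch of the decorator form, mirroring the built-in loaders below;
# the 'example://' prefix and `load_from_example` are hypothetical:
#
#     >>> from mmcv.runner.checkpoint import CheckpointLoader, load_from_local
#     >>> @CheckpointLoader.register_scheme(prefixes='example://')
#     ... def load_from_example(filename, map_location=None):
#     ...     # strip the 10-character 'example://' prefix, then load locally
#     ...     return load_from_local(filename[10:], map_location)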
@classmethod
def _get_checkpoint_loader(cls, path):
def _get_checkpoint_loader(cls, path: str):
"""Finds a loader that supports the given path. Falls back to the local
loader if no other loader is found.
......@@ -220,15 +267,22 @@ class CheckpointLoader:
path (str): checkpoint path
Returns:
loader (function): checkpoint loader
callable: checkpoint loader
"""
for p in cls._schemes:
if path.startswith(p):
# use regular-expression matching to handle cases where the path carries
# an extra prefix before the loader's prefix. For example, both 's3://path'
# and 'open-mmlab:s3://path' should return `load_from_ceph`
if re.match(p, path) is not None:
return cls._schemes[p]
@classmethod
def load_checkpoint(cls, filename, map_location=None, logger=None):
def load_checkpoint(
cls,
filename: str,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None
) -> Union[dict, OrderedDict]:
"""load checkpoint through URL scheme path.
Args:
......@@ -243,14 +297,17 @@ class CheckpointLoader:
"""
checkpoint_loader = cls._get_checkpoint_loader(filename)
class_name = checkpoint_loader.__name__
class_name = checkpoint_loader.__name__ # type: ignore
mmcv.print_log(
f'load checkpoint from {class_name[10:]} path: {filename}', logger)
return checkpoint_loader(filename, map_location)
return checkpoint_loader(filename, map_location) # type: ignore
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
def load_from_local(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint by local file path.
Args:
......@@ -260,15 +317,18 @@ def load_from_local(filename, map_location):
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
filename = osp.expanduser(filename)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
def load_from_http(
filename: str,
map_location: Union[str, Callable, None] = None,
model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
......@@ -276,7 +336,7 @@ def load_from_http(filename, map_location=None, model_dir=None):
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
model_dir (string, optional): directory in which to save the object,
model_dir (str, optional): directory in which to save the object,
Default: None
Returns:
......@@ -295,7 +355,10 @@ def load_from_http(filename, map_location=None, model_dir=None):
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
def load_from_pavi(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
......@@ -326,16 +389,23 @@ def load_from_pavi(filename, map_location=None):
return checkpoint
@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename: str,
map_location: Union[str, Callable, None] = None,
backend: str = 'petrel') -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
Note:
Since v1.4.1, the registered scheme prefixes have been enhanced to
support bucket names in the path prefix, e.g. 's3://xx.xx/xx.path',
'bucket1:s3://xx.xx/xx.path'.
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str, optional): The storage backend type. Options are 'ceph',
backend (str): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.
.. warning::
......@@ -351,7 +421,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
if backend == 'ceph':
warnings.warn(
'CephBackend will be deprecated, please use PetrelBackend instead')
'CephBackend will be deprecated, please use PetrelBackend instead',
DeprecationWarning)
# CephClient and PetrelBackend have the same prefix 's3://' and the latter
# will be chosen as default. If PetrelBackend can not be instantiated
......@@ -368,7 +439,10 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
def load_from_torchvision(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.
......@@ -382,16 +456,25 @@ def load_from_torchvision(filename, map_location=None):
"""
model_urls = get_torchvision_models()
if filename.startswith('modelzoo://'):
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead')
warnings.warn(
'The URL scheme of "modelzoo://" is deprecated, please '
'use "torchvision://" instead', DeprecationWarning)
model_name = filename[11:]
else:
model_name = filename[14:]
# Support getting model urls in the same way as torchvision
# `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
# resnet50.imagenet1k_v1.
model_name = model_name.lower().replace('_weights', '')
return load_from_http(model_urls[model_name], map_location=map_location)
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
def load_from_openmmlab(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
......@@ -415,8 +498,10 @@ def load_from_openmmlab(filename, map_location=None):
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}')
warnings.warn(
f'{prefix_str}{model_name} is deprecated in favor '
f'of {prefix_str}{deprecated_urls[model_name]}',
DeprecationWarning)
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
......@@ -425,13 +510,16 @@ def load_from_openmmlab(filename, map_location=None):
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
raise FileNotFoundError(f'{filename} can not be found.')
checkpoint = torch.load(filename, map_location=map_location)
return checkpoint
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
def load_from_mmcls(
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with mmcls.
Args:
......@@ -450,7 +538,10 @@ def load_from_mmcls(filename, map_location=None):
return checkpoint
def _load_checkpoint(filename, map_location=None, logger=None):
def _load_checkpoint(
filename: str,
map_location: Union[str, Callable, None] = None,
logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
......@@ -470,7 +561,11 @@ def _load_checkpoint(filename, map_location=None, logger=None):
return CheckpointLoader.load_checkpoint(filename, map_location, logger)
def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
def _load_checkpoint_with_prefix(
prefix: str,
filename: str,
map_location: Union[str, Callable, None] = None,
) -> Union[dict, OrderedDict]:
"""Load partial pretrained model with specific prefix.
Args:
......@@ -503,12 +598,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
return state_dict
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None,
revise_keys=[(r'^module\.', '')]):
def load_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Union[str, Callable, None] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
"""Load checkpoint from a file or URI.
Args:
......@@ -553,7 +649,7 @@ def load_checkpoint(model,
return checkpoint
def weights_to_cpu(state_dict):
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
"""Copy a model state_dict to cpu.
Args:
......@@ -566,11 +662,13 @@ def weights_to_cpu(state_dict):
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
state_dict_cpu._metadata = getattr( # type: ignore
state_dict, '_metadata', OrderedDict())
return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
def _save_to_state_dict(module: torch.nn.Module, destination: dict,
prefix: str, keep_vars: bool) -> None:
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
......@@ -590,7 +688,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
def get_state_dict(module: torch.nn.Module,
destination: Optional[OrderedDict] = None,
prefix: str = '',
keep_vars: bool = False) -> OrderedDict:
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
......@@ -619,10 +720,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(
destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
_save_to_state_dict(module, destination, prefix, keep_vars) # type: ignore
for name, child in module._modules.items():
if child is not None:
get_state_dict(
......@@ -631,14 +732,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
return destination # type: ignore
def save_checkpoint(model,
filename,
optimizer=None,
meta=None,
file_client_args=None):
def save_checkpoint(model: torch.nn.Module,
filename: str,
optimizer: Optional[Optimizer] = None,
meta: Optional[dict] = None,
file_client_args: Optional[dict] = None) -> None:
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
......@@ -669,7 +770,7 @@ def save_checkpoint(model,
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model))
'state_dict': weights_to_cpu(get_state_dict(model)) # type: ignore
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
......@@ -685,8 +786,7 @@ def save_checkpoint(model,
'file_client_args should be "None" if filename starts with '
f'"pavi://", but got {file_client_args}')
try:
from pavi import modelcloud
from pavi import exception
from pavi import exception, modelcloud
except ImportError:
raise ImportError(
'Please install pavi to load checkpoint from modelcloud.')
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from .builder import RUNNER_BUILDERS, RUNNERS
......@@ -33,7 +36,7 @@ class DefaultRunnerConstructor:
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg, default_args=None):
def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
......
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
......@@ -10,8 +13,28 @@ from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import IS_MLU_AVAILABLE
def init_dist(launcher, backend='nccl', **kwargs):
def _find_free_port() -> int:
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(('', 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def _is_free_port(port: int) -> bool:
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
......@@ -24,23 +47,37 @@ def init_dist(launcher, backend='nccl', **kwargs):
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
def _init_dist_pytorch(backend: str, **kwargs) -> None:
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
if IS_MLU_AVAILABLE:
import torch_mlu # noqa: F401
torch.mlu.set_device(rank)
dist.init_process_group(
backend='cncl',
rank=rank,
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
def _init_dist_mpi(backend: str, **kwargs) -> None:
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
if 'MASTER_ADDR' not in os.environ:
raise KeyError('The environment variable MASTER_ADDR is not set')
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
dist.init_process_group(backend=backend, **kwargs)
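# Illustrative sketch (not part of the diff): the MPI branch expects Open MPI's
# OMPI_COMM_WORLD_* variables plus a user-provided MASTER_ADDR; MASTER_PORT is
# optional and defaults to 29500. An assumed launch command (adjust the script
# name and flags to your own setup) would be:
#
#   mpirun -np 8 -x MASTER_ADDR=10.1.0.1 python train.py --launcher mpi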
def _init_dist_slurm(backend, port=None):
def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
......@@ -64,8 +101,12 @@ def _init_dist_slurm(backend, port=None):
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
# if torch.distributed default port(29500) is available
# then use it, else find a free port
if _is_free_port(29500):
os.environ['MASTER_PORT'] = '29500'
else:
os.environ['MASTER_PORT'] = str(_find_free_port())
# use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
......@@ -75,7 +116,7 @@ def _init_dist_slurm(backend, port=None):
dist.init_process_group(backend=backend)
def get_dist_info():
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
......@@ -85,7 +126,7 @@ def get_dist_info():
return rank, world_size
def master_only(func):
def master_only(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
......@@ -96,12 +137,14 @@ def master_only(func):
return wrapper
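# Illustrative sketch (not part of the diff): `master_only` runs the decorated
# function on rank 0 only and returns None on every other rank, e.g.:
@master_only
def _print_on_master(msg: str) -> None:  # hypothetical helper, example only
    print(f'[rank 0] {msg}')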
def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
def allreduce_params(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce parameters.
Args:
params (list[torch.Parameters]): List of parameters or buffers of a
model.
params (list[torch.nn.Parameter]): List of parameters or buffers
of a model.
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
......@@ -118,11 +161,13 @@ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
dist.all_reduce(tensor.div_(world_size))
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
def allreduce_grads(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce gradients.
Args:
params (list[torch.Parameters]): List of parameters of a model
params (list[torch.nn.Parameter]): List of parameters of a model.
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
......@@ -142,7 +187,9 @@ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
dist.all_reduce(tensor.div_(world_size))
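# Illustrative sketch (not part of the diff): a typical use of the helpers
# above is to synchronise gradients and BatchNorm buffers across ranks after a
# backward pass; `model` is assumed to be an nn.Module living on this rank.
#
#   allreduce_grads(list(model.parameters()), coalesce=True, bucket_size_mb=-1)
#   allreduce_params(list(model.buffers()))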
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
def _allreduce_coalesced(tensors: torch.Tensor,
world_size: int,
bucket_size_mb: int = -1) -> None:
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
......
......@@ -4,8 +4,10 @@ import platform
import shutil
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import mmcv
from .base_runner import BaseRunner
......@@ -21,7 +23,7 @@ class EpochBasedRunner(BaseRunner):
This runner train models epoch by epoch.
"""
def run_iter(self, data_batch, train_mode, **kwargs):
def run_iter(self, data_batch: Any, train_mode: bool, **kwargs) -> None:
if self.batch_processor is not None:
outputs = self.batch_processor(
self.model, data_batch, train_mode=train_mode, **kwargs)
......@@ -45,10 +47,12 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
del self.data_batch
self._iter += 1
self.call_hook('after_train_epoch')
......@@ -62,14 +66,19 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self.data_batch = data_batch
self._inner_iter = i
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
del self.data_batch
self.call_hook('after_val_epoch')
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
def run(self,
data_loaders: List[DataLoader],
workflow: List[Tuple[str, int]],
max_epochs: Optional[int] = None,
**kwargs) -> None:
"""Start running.
Args:
......@@ -130,11 +139,11 @@ class EpochBasedRunner(BaseRunner):
self.call_hook('after_run')
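# Illustrative sketch (not part of the diff): `workflow` is a list of
# (phase, epochs) pairs and must line up with `data_loaders`; the call below
# trains for one epoch, then validates for one epoch, until 12 epochs are done
# (`runner`, `train_loader` and `val_loader` are assumed to exist already):
#
#   runner.run([train_loader, val_loader],
#              workflow=[('train', 1), ('val', 1)],
#              max_epochs=12)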
def save_checkpoint(self,
out_dir,
filename_tmpl='epoch_{}.pth',
save_optimizer=True,
meta=None,
create_symlink=True):
out_dir: str,
filename_tmpl: str = 'epoch_{}.pth',
save_optimizer: bool = True,
meta: Optional[Dict] = None,
create_symlink: bool = True) -> None:
"""Save the checkpoint.
Args:
......@@ -183,5 +192,6 @@ class Runner(EpochBasedRunner):
def __init__(self, *args, **kwargs):
warnings.warn(
'Runner was deprecated, please use EpochBasedRunner instead')
'Runner was deprecated, please use EpochBasedRunner instead',
DeprecationWarning)
super().__init__(*args, **kwargs)
......@@ -3,10 +3,12 @@ import functools
import warnings
from collections import abc
from inspect import getfullargspec
from typing import Callable, Iterable, List, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads
......@@ -21,9 +23,18 @@ except ImportError:
pass
def cast_tensor_type(inputs, src_type, dst_type):
def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
"""Recursively convert Tensor in inputs from src_type to dst_type.
Note:
In v1.4.4 and later, ``cast_tensor_type`` will only convert the
torch.Tensor whose dtype matches ``src_type`` to the ``dst_type``.
Before v1.4.4, it ignored the ``src_type`` argument, leading to
potential problems. For example,
``cast_tensor_type(inputs, torch.float, torch.half)`` would convert all
tensors in inputs to ``torch.half``, including those originally in
``torch.int`` or other types, which is not expected.
Args:
inputs: Inputs to be cast.
src_type (torch.dtype): Source type.
......@@ -35,24 +46,30 @@ def cast_tensor_type(inputs, src_type, dst_type):
if isinstance(inputs, nn.Module):
return inputs
elif isinstance(inputs, torch.Tensor):
return inputs.to(dst_type)
# we need to ensure that the type of inputs to be casted are the same
# as the argument `src_type`.
return inputs.to(dst_type) if inputs.dtype == src_type else inputs
elif isinstance(inputs, str):
return inputs
elif isinstance(inputs, np.ndarray):
return inputs
elif isinstance(inputs, abc.Mapping):
return type(inputs)({
return type(inputs)({ # type: ignore
k: cast_tensor_type(v, src_type, dst_type)
for k, v in inputs.items()
})
elif isinstance(inputs, abc.Iterable):
return type(inputs)(
return type(inputs)( # type: ignore
cast_tensor_type(item, src_type, dst_type) for item in inputs)
else:
return inputs
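# Illustrative sketch (not part of the diff): with the dtype check above, only
# tensors whose dtype matches ``src_type`` are converted; other tensors pass
# through unchanged. `_cast_example` is a hypothetical helper, for
# demonstration only.
def _cast_example() -> None:
    data = {'feat': torch.zeros(2, dtype=torch.float32),
            'label': torch.zeros(2, dtype=torch.int64)}
    out = cast_tensor_type(data, torch.float32, torch.half)
    assert out['feat'].dtype == torch.float16   # converted
    assert out['label'].dtype == torch.int64    # untouched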
def auto_fp16(apply_to=None, out_fp32=False):
def auto_fp16(
apply_to: Optional[Iterable] = None,
out_fp32: bool = False,
supported_types: tuple = (nn.Module, ),
) -> Callable:
"""Decorator to enable fp16 training automatically.
This decorator is useful when you write custom modules and want to support
......@@ -65,7 +82,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
apply_to (Iterable, optional): The argument names to be converted.
`None` indicates all arguments.
out_fp32 (bool): Whether to convert the output back to fp32.
supported_types (tuple): Classes can be decorated by ``auto_fp16``.
`New in version 1.5.0.`
Example:
>>> import torch.nn as nn
......@@ -85,15 +103,15 @@ def auto_fp16(apply_to=None, out_fp32=False):
>>> pass
"""
def auto_fp16_wrapper(old_func):
def auto_fp16_wrapper(old_func: Callable) -> Callable:
@functools.wraps(old_func)
def new_func(*args, **kwargs):
def new_func(*args, **kwargs) -> Callable:
# Check if the module has set the attribute `fp16_enabled`; if not,
# just fall back to the original method.
if not isinstance(args[0], torch.nn.Module):
if not isinstance(args[0], supported_types):
raise TypeError('@auto_fp16 can only be used to decorate the '
'method of nn.Module')
f'method of those classes {supported_types}')
if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
return old_func(*args, **kwargs)
......@@ -138,7 +156,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
return auto_fp16_wrapper
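# Illustrative sketch (not part of the diff): `supported_types` lets the
# decorator be applied outside plain nn.Module subclasses. `_FP16Friendly` and
# `PseudoHead` are hypothetical classes used only for this example.
class _FP16Friendly:
    fp16_enabled = True  # the decorator is a no-op unless this flag is truthy


class PseudoHead(_FP16Friendly):

    @auto_fp16(apply_to=('x', ), supported_types=(nn.Module, _FP16Friendly))
    def forward(self, x):
        # `x` arrives as fp16 when fp16 is enabled on this instance
        return x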
def force_fp32(apply_to=None, out_fp16=False):
def force_fp32(apply_to: Optional[Iterable] = None,
out_fp16: bool = False) -> Callable:
"""Decorator to convert input arguments to fp32 in force.
This decorator is useful when you write custom modules and want to support
......@@ -176,7 +195,7 @@ def force_fp32(apply_to=None, out_fp16=False):
def force_fp32_wrapper(old_func):
@functools.wraps(old_func)
def new_func(*args, **kwargs):
def new_func(*args, **kwargs) -> Callable:
# Check if the module has set the attribute `fp16_enabled`; if not,
# just fall back to the original method.
if not isinstance(args[0], torch.nn.Module):
......@@ -224,14 +243,17 @@ def force_fp32(apply_to=None, out_fp16=False):
return force_fp32_wrapper
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
warnings.warning(
def allreduce_grads(params: List[Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
warnings.warn(
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads"',
DeprecationWarning)
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
def wrap_fp16_model(model):
def wrap_fp16_model(model: nn.Module) -> None:
"""Wrap the FP32 model to FP16.
If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
......@@ -260,7 +282,7 @@ def wrap_fp16_model(model):
m.fp16_enabled = True
def patch_norm_fp32(module):
def patch_norm_fp32(module: nn.Module) -> nn.Module:
"""Recursively convert normalization layers from FP16 to FP32.
Args:
......@@ -280,7 +302,10 @@ def patch_norm_fp32(module):
return module
def patch_forward_method(func, src_type, dst_type, convert_output=True):
def patch_forward_method(func: Callable,
src_type: torch.dtype,
dst_type: torch.dtype,
convert_output: bool = True) -> Callable:
"""Patch the forward method of a module.
Args:
......@@ -333,10 +358,10 @@ class LossScaler:
"""
def __init__(self,
init_scale=2**32,
mode='dynamic',
scale_factor=2.,
scale_window=1000):
init_scale: float = 2**32,
mode: str = 'dynamic',
scale_factor: float = 2.,
scale_window: int = 1000):
self.cur_scale = init_scale
self.cur_iter = 0
assert mode in ('dynamic',
......@@ -346,7 +371,7 @@ class LossScaler:
self.scale_factor = scale_factor
self.scale_window = scale_window
def has_overflow(self, params):
def has_overflow(self, params: List[Parameter]) -> bool:
"""Check if params contain overflow."""
if self.mode != 'dynamic':
return False
......@@ -355,7 +380,7 @@ class LossScaler:
return True
return False
def _has_inf_or_nan(x):
def _has_inf_or_nan(x: torch.Tensor) -> bool:
"""Check if params contain NaN."""
try:
cpu_sum = float(x.float().sum())
......@@ -369,7 +394,7 @@ class LossScaler:
return True
return False
def update_scale(self, overflow):
def update_scale(self, overflow: bool) -> None:
"""update the current loss scale value when overflow happens."""
if self.mode != 'dynamic':
return
......@@ -382,7 +407,7 @@ class LossScaler:
self.cur_scale *= self.scale_factor
self.cur_iter += 1
def state_dict(self):
def state_dict(self) -> dict:
"""Returns the state of the scaler as a :class:`dict`."""
return dict(
cur_scale=self.cur_scale,
......@@ -392,7 +417,7 @@ class LossScaler:
scale_factor=self.scale_factor,
scale_window=self.scale_window)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: dict) -> None:
"""Loads the loss_scaler state dict.
Args:
......@@ -406,5 +431,5 @@ class LossScaler:
self.scale_window = state_dict['scale_window']
@property
def loss_scale(self):
def loss_scale(self) -> float:
return self.cur_scale
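# Illustrative sketch (not part of the diff): how the dynamic mode of
# LossScaler fits into a manual fp16 training step; `model`, `optimizer` and
# `loss` are assumed to come from the surrounding training loop.
#
#   scaler = LossScaler(mode='dynamic')
#   scaled_loss = loss * scaler.loss_scale
#   scaled_loss.backward()
#   overflow = scaler.has_overflow(model.parameters())
#   if not overflow:
#       for p in model.parameters():
#           if p.grad is not None:
#               p.grad.div_(scaler.loss_scale)   # unscale before the update
#       optimizer.step()
#   scaler.update_scale(overflow)                # grow or shrink the scale
#   optimizer.zero_grad()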