Unverified commit ea0e8cdb, authored by WINDSKY45 and committed by GitHub

[Enhancement] Add type hints in several files (#2020)



* [Enhance] Add type hints in these files:
'base_module.py', 'checkpoint.py',
'default_constructor.py',
'dist_utils.py', 'fp16_utils.py',
'log_buffer.py', 'priority.py'.

* Fixed all the inappropriate type hints.
Removed the return type of __init__ funcs.

* Fixed type hint.

* Added type hints in `dist_utils.py` and `log_buffer.py`.

* Remove unnecessary import.

* minor fix
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 88ae2a4e
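Every file below follows the same pattern: arguments gain PEP 484 annotations, `None` defaults become `Optional[...]`, and `__init__` methods deliberately carry no return annotation. A minimal sketch of the convention (illustrative only, not a line from the diff):

from typing import Optional


class Example:

    # Arguments are annotated; the `-> None` on __init__ is omitted,
    # per the second bullet of the commit message above.
    def __init__(self, init_cfg: Optional[dict] = None):
        self.init_cfg = init_cfg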
mmcv/runner/base_module.py

@@ -4,6 +4,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn
@@ -29,7 +30,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Optional[dict] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""

         # NOTE init_cfg can be defined in different levels, but init_cfg
@@ -133,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                 del sub_module._params_init_info

     @master_only
-    def _dump_init_info(self, logger_name):
+    def _dump_init_info(self, logger_name: str):
         """Dump the initialization information to a file named
         `initialization.log.json` in workdir.
@@ -176,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, *args, init_cfg=None):
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.Sequential.__init__(self, *args)
@@ -189,7 +190,9 @@ class ModuleList(BaseModule, nn.ModuleList):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleList.__init__(self, modules)
@@ -203,6 +206,8 @@ class ModuleDict(BaseModule, nn.ModuleDict):
         init_cfg (dict, optional): Initialization config dict.
     """

-    def __init__(self, modules=None, init_cfg=None):
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
         BaseModule.__init__(self, init_cfg)
         nn.ModuleDict.__init__(self, modules)
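A quick sketch of how the annotated containers read at a call site (a toy module; assumes the usual `mmcv.runner` exports):

from typing import Optional

import torch.nn as nn
from mmcv.runner import BaseModule, ModuleList


class ToyHead(BaseModule):

    def __init__(self, init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)
        # under the new hint, `modules` may be any Iterable of nn.Module
        self.convs = ModuleList(nn.Conv2d(16, 16, 1) for _ in range(2))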
mmcv/runner/checkpoint.py

 # Copyright (c) OpenMMLab. All rights reserved.
 import io
+import logging
 import os
 import os.path as osp
 import pkgutil
@@ -9,6 +10,7 @@ import warnings
 from collections import OrderedDict
 from importlib import import_module
 from tempfile import TemporaryDirectory
+from typing import Callable, List, Optional, Sequence, Union

 import torch
 import torchvision
@@ -37,7 +39,10 @@ def _get_mmcv_home():
     return mmcv_home


-def load_state_dict(module, state_dict, strict=False, logger=None):
+def load_state_dict(module: torch.nn.Module,
+                    state_dict: OrderedDict,
+                    strict: bool = False,
+                    logger: Optional[logging.Logger] = None) -> None:
     """Load state_dict to a module.

     This method is modified from :meth:`torch.nn.Module.load_state_dict`.
@@ -53,14 +58,14 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
         logger (:obj:`logging.Logger`, optional): Logger to log the error
             message. If not specified, print function will be used.
     """
-    unexpected_keys = []
-    all_missing_keys = []
-    err_msg = []
+    unexpected_keys: List = []
+    all_missing_keys: List = []
+    err_msg: List = []

     metadata = getattr(state_dict, '_metadata', None)
     state_dict = state_dict.copy()
     if metadata is not None:
-        state_dict._metadata = metadata
+        state_dict._metadata = metadata  # type: ignore

     # use _load_from_state_dict to enable checkpoint version control
     def load(module, prefix=''):
@@ -78,7 +83,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
             load(child, prefix + name + '.')

     load(module)
-    load = None  # break load->load reference cycle
+    # break load->load reference cycle
+    load = None  # type: ignore

     # ignore "num_batches_tracked" of BN layers
     missing_keys = [
@@ -96,7 +102,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     if len(err_msg) > 0 and rank == 0:
         err_msg.insert(
             0, 'The model and loaded state dict do not match exactly\n')
-        err_msg = '\n'.join(err_msg)
+        err_msg = '\n'.join(err_msg)  # type: ignore
         if strict:
             raise RuntimeError(err_msg)
         elif logger is not None:
@@ -220,13 +226,16 @@ class CheckpointLoader:
             sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

     @classmethod
-    def register_scheme(cls, prefixes, loader=None, force=False):
+    def register_scheme(cls,
+                        prefixes: Union[str, Sequence[str]],
+                        loader: Optional[Callable] = None,
+                        force: bool = False):
         """Register a loader to CheckpointLoader.

         This method can be used as a normal class method or a decorator.

         Args:
-            prefixes (str or list[str] or tuple[str]):
+            prefixes (str or Sequence[str]):
                 The prefix of the registered loader.
             loader (function, optional): The loader function to be registered.
                 When this method is used as a decorator, loader is None.
@@ -264,7 +273,12 @@ class CheckpointLoader:
         return cls._schemes[p]

     @classmethod
-    def load_checkpoint(cls, filename, map_location=None, logger=None):
+    def load_checkpoint(
+            cls,
+            filename: str,
+            map_location: Optional[str] = None,
+            logger: Optional[logging.Logger] = None
+    ) -> Union[dict, OrderedDict]:
         """load checkpoint through URL scheme path.

         Args:
@@ -286,7 +300,9 @@ class CheckpointLoader:
 @CheckpointLoader.register_scheme(prefixes='')
-def load_from_local(filename, map_location):
+def load_from_local(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint by local file path.

     Args:
@@ -304,7 +320,10 @@ def load_from_local(filename, map_location):
 @CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
-def load_from_http(filename, map_location=None, model_dir=None):
+def load_from_http(
+        filename: str,
+        map_location: Optional[str] = None,
+        model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through HTTP or HTTPS scheme path. In distributed
     setting, this function only download checkpoint at local rank 0.
@@ -312,7 +331,7 @@ def load_from_http(filename, map_location=None, model_dir=None):
         filename (str): checkpoint file path with modelzoo or
             torchvision prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        model_dir (string, optional): directory in which to save the object,
+        model_dir (str, optional): directory in which to save the object,
             Default: None

     Returns:
@@ -331,7 +350,9 @@ def load_from_http(filename, map_location=None, model_dir=None):
 @CheckpointLoader.register_scheme(prefixes='pavi://')
-def load_from_pavi(filename, map_location=None):
+def load_from_pavi(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with pavi. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -363,7 +384,9 @@ def load_from_pavi(filename, map_location=None):
 @CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
-def load_from_ceph(filename, map_location=None, backend='petrel'):
+def load_from_ceph(filename: str,
+                   map_location: Optional[str] = None,
+                   backend: str = 'petrel') -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with s3. In distributed
     setting, this function download ckpt at all ranks to different temporary
     directories.
@@ -376,7 +399,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
     Args:
         filename (str): checkpoint file path with s3 prefix
         map_location (str, optional): Same as :func:`torch.load`.
-        backend (str, optional): The storage backend type. Options are 'ceph',
+        backend (str): The storage backend type. Options are 'ceph',
             'petrel'. Default: 'petrel'.

     .. warning::
@@ -410,7 +433,9 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
 @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
-def load_from_torchvision(filename, map_location=None):
+def load_from_torchvision(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with modelzoo or
     torchvision.
@@ -439,7 +464,9 @@ def load_from_torchvision(filename, map_location=None):
 @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
-def load_from_openmmlab(filename, map_location=None):
+def load_from_openmmlab(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with open-mmlab or
     openmmlab.
@@ -481,7 +508,9 @@ def load_from_openmmlab(filename, map_location=None):
 @CheckpointLoader.register_scheme(prefixes='mmcls://')
-def load_from_mmcls(filename, map_location=None):
+def load_from_mmcls(
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """load checkpoint through the file path prefixed with mmcls.

     Args:
@@ -500,7 +529,10 @@ def load_from_mmcls(filename, map_location=None):
     return checkpoint


-def _load_checkpoint(filename, map_location=None, logger=None):
+def _load_checkpoint(
+        filename: str,
+        map_location: Optional[str] = None,
+        logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
     """Load checkpoint from somewhere (modelzoo, file, url).

     Args:
@@ -520,7 +552,10 @@ def _load_checkpoint(filename, map_location=None, logger=None):
     return CheckpointLoader.load_checkpoint(filename, map_location, logger)


-def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+def _load_checkpoint_with_prefix(
+        prefix: str,
+        filename: str,
+        map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
     """Load partial pretrained model with specific prefix.

     Args:
@@ -553,12 +588,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
     return state_dict


-def load_checkpoint(model,
-                    filename,
-                    map_location=None,
-                    strict=False,
-                    logger=None,
-                    revise_keys=[(r'^module\.', '')]):
+def load_checkpoint(
+        model: torch.nn.Module,
+        filename: str,
+        map_location: Optional[str] = None,
+        strict: bool = False,
+        logger: Optional[logging.Logger] = None,
+        revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
     """Load checkpoint from a file or URI.

     Args:
@@ -603,7 +639,7 @@ def load_checkpoint(model,
     return checkpoint


-def weights_to_cpu(state_dict):
+def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
     """Copy a model state_dict to cpu.

     Args:
@@ -616,11 +652,13 @@ def weights_to_cpu(state_dict):
     for key, val in state_dict.items():
         state_dict_cpu[key] = val.cpu()
     # Keep metadata in state_dict
-    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    state_dict_cpu._metadata = getattr(  # type: ignore
+        state_dict, '_metadata', OrderedDict())
     return state_dict_cpu


-def _save_to_state_dict(module, destination, prefix, keep_vars):
+def _save_to_state_dict(module: torch.nn.Module, destination: dict,
+                        prefix: str, keep_vars: bool) -> None:
     """Saves module state to `destination` dictionary.

     This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
@@ -640,7 +678,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
         destination[prefix + name] = buf if keep_vars else buf.detach()


-def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+def get_state_dict(module: torch.nn.Module,
+                   destination: Optional[OrderedDict] = None,
+                   prefix: str = '',
+                   keep_vars: bool = False) -> OrderedDict:
     """Returns a dictionary containing a whole state of the module.

     Both parameters and persistent buffers (e.g. running averages) are
@@ -669,8 +710,8 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
     # below is the same as torch.nn.Module.state_dict()
     if destination is None:
         destination = OrderedDict()
-        destination._metadata = OrderedDict()
-    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        destination._metadata = OrderedDict()  # type: ignore
+    destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
         version=module._version)
     _save_to_state_dict(module, destination, prefix, keep_vars)
     for name, child in module._modules.items():
@@ -681,14 +722,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
         hook_result = hook(module, destination, prefix, local_metadata)
         if hook_result is not None:
             destination = hook_result
-    return destination
+    return destination  # type: ignore


-def save_checkpoint(model,
-                    filename,
-                    optimizer=None,
-                    meta=None,
-                    file_client_args=None):
+def save_checkpoint(model: torch.nn.Module,
+                    filename: str,
+                    optimizer: Optional[Optimizer] = None,
+                    meta: Optional[dict] = None,
+                    file_client_args: Optional[dict] = None) -> None:
     """Save checkpoint to file.

     The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
...
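A usage sketch for the annotated `load_checkpoint` (the checkpoint path is a placeholder; `revise_keys` defaults to stripping a leading `module.` left by `DataParallel`):

import torchvision
from mmcv.runner import load_checkpoint

model = torchvision.models.resnet18()
# returns the loaded dict/OrderedDict in addition to mutating `model`
checkpoint = load_checkpoint(
    model, '/path/to/checkpoint.pth', map_location='cpu', strict=False)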
mmcv/runner/default_constructor.py

 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 from .builder import RUNNER_BUILDERS, RUNNERS
@@ -34,7 +36,7 @@ class DefaultRunnerConstructor:
     >>> runner = build_runner(runner_cfg)
     """

-    def __init__(self, runner_cfg, default_args=None):
+    def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
         if not isinstance(runner_cfg, dict):
             raise TypeError('runner_cfg should be a dict',
                             f'but got {type(runner_cfg)}')
...
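The constructor is exercised through `build_runner`, as in the class docstring above; a condensed restatement of that example (assuming `EpochBasedRunner` is registered):

from mmcv.runner import build_runner

# runner_cfg must be a dict, otherwise __init__ raises TypeError
runner_cfg = dict(type='EpochBasedRunner', max_epochs=40)
runner = build_runner(runner_cfg)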
mmcv/runner/dist_utils.py

@@ -5,6 +5,7 @@ import os
 import socket
 import subprocess
 from collections import OrderedDict
+from typing import Callable, List, Optional, Tuple

 import torch
 import torch.multiprocessing as mp
@@ -33,7 +34,7 @@ def _is_free_port(port):
     return all(s.connect_ex((ip, port)) != 0 for ip in ips)


-def init_dist(launcher, backend='nccl', **kwargs):
+def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
     if launcher == 'pytorch':
@@ -46,7 +47,7 @@ def init_dist(launcher, backend='nccl', **kwargs):
         raise ValueError(f'Invalid launcher type: {launcher}')


-def _init_dist_pytorch(backend, **kwargs):
+def _init_dist_pytorch(backend: str, **kwargs):
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
     if IS_MLU_AVAILABLE:
@@ -63,7 +64,7 @@ def _init_dist_pytorch(backend, **kwargs):
         dist.init_process_group(backend=backend, **kwargs)


-def _init_dist_mpi(backend, **kwargs):
+def _init_dist_mpi(backend: str, **kwargs):
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     if 'MASTER_PORT' not in os.environ:
@@ -76,7 +77,7 @@ def _init_dist_mpi(backend, **kwargs):
     dist.init_process_group(backend=backend, **kwargs)


-def _init_dist_slurm(backend, port=None):
+def _init_dist_slurm(backend: str, port: Optional[int] = None):
     """Initialize slurm distributed training environment.

     If argument ``port`` is not specified, then the master port will be system
@@ -115,7 +116,7 @@ def _init_dist_slurm(backend, port=None):
     dist.init_process_group(backend=backend)


-def get_dist_info():
+def get_dist_info() -> Tuple[int, int]:
     if dist.is_available() and dist.is_initialized():
         rank = dist.get_rank()
         world_size = dist.get_world_size()
@@ -125,7 +126,7 @@ def get_dist_info():
     return rank, world_size


-def master_only(func):
+def master_only(func: Callable) -> Callable:

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
@@ -136,12 +137,14 @@ def master_only(func):
     return wrapper


-def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_params(params: List[torch.nn.Parameter],
+                     coalesce: bool = True,
+                     bucket_size_mb: int = -1) -> None:
     """Allreduce parameters.

     Args:
-        params (list[torch.Parameters]): List of parameters or buffers of a
-            model.
+        params (list[torch.nn.Parameter]): List of parameters or buffers
+            of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
@@ -158,11 +161,13 @@ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
         dist.all_reduce(tensor.div_(world_size))


-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+def allreduce_grads(params: List[torch.nn.Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
     """Allreduce gradients.

     Args:
-        params (list[torch.Parameters]): List of parameters of a model
+        params (list[torch.nn.Parameter]): List of parameters of a model.
         coalesce (bool, optional): Whether allreduce parameters as a whole.
             Defaults to True.
         bucket_size_mb (int, optional): Size of bucket, the unit is MB.
...
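A sketch of the two most commonly used helpers here (assumes the usual fallback of rank 0, world size 1 when no process group is initialized):

from mmcv.runner import get_dist_info, master_only

rank, world_size = get_dist_info()  # e.g. (0, 1) outside distributed runs


@master_only
def save_visualization() -> None:
    # the decorated function runs on rank 0 only; other ranks no-op
    print(f'saving from rank {rank} of {world_size}')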
mmcv/runner/fp16_utils.py

@@ -3,10 +3,12 @@ import functools
 import warnings
 from collections import abc
 from inspect import getfullargspec
+from typing import Callable, Iterable, List, Optional

 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.parameter import Parameter

 from mmcv.utils import TORCH_VERSION, digit_version
 from .dist_utils import allreduce_grads as _allreduce_grads
@@ -21,7 +23,7 @@ except ImportError:
     pass


-def cast_tensor_type(inputs, src_type, dst_type):
+def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
     """Recursively convert Tensor in inputs from src_type to dst_type.

     Note:
@@ -52,18 +54,22 @@ def cast_tensor_type(inputs, src_type, dst_type):
     elif isinstance(inputs, np.ndarray):
         return inputs
     elif isinstance(inputs, abc.Mapping):
-        return type(inputs)({
+        return type(inputs)({  # type: ignore
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
     elif isinstance(inputs, abc.Iterable):
-        return type(inputs)(
+        return type(inputs)(  # type: ignore
             cast_tensor_type(item, src_type, dst_type) for item in inputs)
     else:
         return inputs


-def auto_fp16(apply_to=None, out_fp32=False, supported_types=(nn.Module, )):
+def auto_fp16(
+        apply_to: Optional[Iterable] = None,
+        out_fp32: bool = False,
+        supported_types: tuple = (nn.Module, ),
+) -> Callable:
     """Decorator to enable fp16 training automatically.

     This decorator is useful when you write custom modules and want to support
@@ -150,7 +156,8 @@ def auto_fp16(apply_to=None, out_fp32=False, supported_types=(nn.Module, )):
     return auto_fp16_wrapper


-def force_fp32(apply_to=None, out_fp16=False):
+def force_fp32(apply_to: Optional[Iterable] = None,
+               out_fp16: bool = False) -> Callable:
     """Decorator to convert input arguments to fp32 in force.

     This decorator is useful when you write custom modules and want to support
@@ -236,15 +243,17 @@ def force_fp32(apply_to=None, out_fp16=False):
     return force_fp32_wrapper


-def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
-    warnings.warning(
+def allreduce_grads(params: List[Parameter],
+                    coalesce: bool = True,
+                    bucket_size_mb: int = -1) -> None:
+    warnings.warn(
         '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
         'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
         DeprecationWarning)
     _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)


-def wrap_fp16_model(model):
+def wrap_fp16_model(model: nn.Module) -> None:
     """Wrap the FP32 model to FP16.

     If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
@@ -273,7 +282,7 @@ def wrap_fp16_model(model):
             m.fp16_enabled = True


-def patch_norm_fp32(module):
+def patch_norm_fp32(module: nn.Module) -> nn.Module:
     """Recursively convert normalization layers from FP16 to FP32.

     Args:
@@ -293,7 +302,10 @@ def patch_norm_fp32(module):
     return module


-def patch_forward_method(func, src_type, dst_type, convert_output=True):
+def patch_forward_method(func: Callable,
+                         src_type: torch.dtype,
+                         dst_type: torch.dtype,
+                         convert_output: bool = True) -> Callable:
     """Patch the forward method of a module.

     Args:
@@ -346,10 +358,10 @@ class LossScaler:
     """

     def __init__(self,
-                 init_scale=2**32,
-                 mode='dynamic',
-                 scale_factor=2.,
-                 scale_window=1000):
+                 init_scale: float = 2**32,
+                 mode: str = 'dynamic',
+                 scale_factor: float = 2.,
+                 scale_window: int = 1000):
         self.cur_scale = init_scale
         self.cur_iter = 0
         assert mode in ('dynamic',
@@ -359,7 +371,7 @@ class LossScaler:
         self.scale_factor = scale_factor
         self.scale_window = scale_window

-    def has_overflow(self, params):
+    def has_overflow(self, params: List[Parameter]) -> bool:
         """Check if params contain overflow."""
         if self.mode != 'dynamic':
             return False
@@ -382,7 +394,7 @@ class LossScaler:
                 return True
         return False

-    def update_scale(self, overflow):
+    def update_scale(self, overflow: bool) -> None:
         """update the current loss scale value when overflow happens."""
         if self.mode != 'dynamic':
             return
@@ -405,7 +417,7 @@ class LossScaler:
             scale_factor=self.scale_factor,
             scale_window=self.scale_window)

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: dict) -> None:
         """Loads the loss_scaler state dict.

         Args:
...
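A sketch of the `auto_fp16` decorator on a custom module, mirroring the usage its docstring describes (the module and tensor math are illustrative):

import torch
import torch.nn as nn
from mmcv.runner import auto_fp16


class ToyNet(nn.Module):

    @auto_fp16(apply_to=('img', ))
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # `img` is cast to fp16 when `self.fp16_enabled` is True
        return img * 2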
mmcv/runner/log_buffer.py

@@ -12,16 +12,16 @@ class LogBuffer:
         self.output = OrderedDict()
         self.ready = False

-    def clear(self):
+    def clear(self) -> None:
         self.val_history.clear()
         self.n_history.clear()
         self.clear_output()

-    def clear_output(self):
+    def clear_output(self) -> None:
         self.output.clear()
         self.ready = False

-    def update(self, vars, count=1):
+    def update(self, vars: dict, count: int = 1) -> None:
         assert isinstance(vars, dict)
         for key, var in vars.items():
             if key not in self.val_history:
@@ -30,7 +30,7 @@ class LogBuffer:
             self.val_history[key].append(var)
             self.n_history[key].append(count)

-    def average(self, n=0):
+    def average(self, n: int = 0) -> None:
         """Average latest n values or all values."""
         assert n >= 0
         for key in self.val_history:
...
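A sketch of the buffer's round trip (assumes the documented semantics of `update`/`average`):

from mmcv.runner import LogBuffer

buf = LogBuffer()
buf.update({'loss': 0.9}, count=1)
buf.update({'loss': 0.7}, count=1)
buf.average(n=2)           # count-weighted mean of the latest 2 values
print(buf.output['loss'])  # ~0.8
print(buf.ready)           # True once averages are computed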
mmcv/runner/priority.py

 # Copyright (c) OpenMMLab. All rights reserved.
 from enum import Enum
+from typing import Union


 class Priority(Enum):
@@ -39,7 +40,7 @@ class Priority(Enum):
     LOWEST = 100


-def get_priority(priority):
+def get_priority(priority: Union[int, str, Priority]) -> int:
     """Get priority value.

     Args:
...
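The three accepted input kinds of the new `Union` hint, in one sketch (`LOWEST = 100` is visible in the hunk above):

from mmcv.runner import Priority, get_priority

assert get_priority(50) == 50               # ints pass through
assert get_priority('LOWEST') == 100        # names resolve via the enum
assert get_priority(Priority.LOWEST) == 100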