Unverified commit ea0e8cdb authored by WINDSKY45, committed by GitHub

[Enhancement] Add type hints in several files (#2020)



* [Enhance] Add type hints in these files:
'base_module.py', 'checkpoint.py',
'default_constructor.py',
'dist_utils.py', 'fp16_utils.py',
'log_buffer.py', 'priority.py'.

* Fixed all the inappropriate type hints.
Removed the return type annotations from `__init__` functions.

* Fixed type hint.

* Added type hints in `dist_utils.py` and `log_buffer.py`.

* Remove unnecessary import.

* minor fix
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 88ae2a4e
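Before the per-file diffs, here is a minimal sketch of the convention being applied (a hypothetical class, not taken from the diff): optional arguments gain Optional[...] annotations, and __init__ keeps no return type annotation.

from typing import Optional

class ExampleModule:
    # Parameters are annotated, but __init__ deliberately has no return type,
    # matching the convention adopted in this change.
    def __init__(self, init_cfg: Optional[dict] = None):
        self.init_cfg = init_cfg

    def describe(self, logger_name: str) -> None:
        print(f'[{logger_name}] init_cfg={self.init_cfg}')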
......@@ -4,6 +4,7 @@ import warnings
from abc import ABCMeta
from collections import defaultdict
from logging import FileHandler
from typing import Iterable, Optional
import torch.nn as nn
......@@ -29,7 +30,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, init_cfg=None):
def __init__(self, init_cfg: Optional[dict] = None):
"""Initialize BaseModule, inherited from `torch.nn.Module`"""
# NOTE init_cfg can be defined in different levels, but init_cfg
......@@ -133,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
del sub_module._params_init_info
@master_only
def _dump_init_info(self, logger_name):
def _dump_init_info(self, logger_name: str):
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
......@@ -176,7 +177,7 @@ class Sequential(BaseModule, nn.Sequential):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, *args, init_cfg=None):
def __init__(self, *args, init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.Sequential.__init__(self, *args)
......@@ -189,7 +190,9 @@ class ModuleList(BaseModule, nn.ModuleList):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[Iterable] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleList.__init__(self, modules)
......@@ -203,6 +206,8 @@ class ModuleDict(BaseModule, nn.ModuleDict):
init_cfg (dict, optional): Initialization config dict.
"""
def __init__(self, modules=None, init_cfg=None):
def __init__(self,
modules: Optional[dict] = None,
init_cfg: Optional[dict] = None):
BaseModule.__init__(self, init_cfg)
nn.ModuleDict.__init__(self, modules)
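For context, a hedged usage sketch of the updated BaseModule signature (the ConvBlock subclass and its init_cfg are illustrative, not part of the diff):

from typing import Optional

import torch.nn as nn
from mmcv.runner.base_module import BaseModule

class ConvBlock(BaseModule):
    def __init__(self, init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(3, 16, 3)

block = ConvBlock(init_cfg=dict(type='Kaiming', layer='Conv2d'))
block.init_weights()  # applies the configured initializer to self.conv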
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
......@@ -9,6 +10,7 @@ import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Sequence, Union
import torch
import torchvision
......@@ -37,7 +39,10 @@ def _get_mmcv_home():
return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
def load_state_dict(module: torch.nn.Module,
state_dict: OrderedDict,
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
"""Load state_dict to a module.
This method is modified from :meth:`torch.nn.Module.load_state_dict`.
......@@ -53,14 +58,14 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
unexpected_keys = []
all_missing_keys = []
err_msg = []
unexpected_keys: List = []
all_missing_keys: List = []
err_msg: List = []
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
state_dict._metadata = metadata # type: ignore
# use _load_from_state_dict to enable checkpoint version control
def load(module, prefix=''):
......@@ -78,7 +83,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
load(child, prefix + name + '.')
load(module)
load = None # break load->load reference cycle
# break load->load reference cycle
load = None # type: ignore
# ignore "num_batches_tracked" of BN layers
missing_keys = [
......@@ -96,7 +102,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
if len(err_msg) > 0 and rank == 0:
err_msg.insert(
0, 'The model and loaded state dict do not match exactly\n')
err_msg = '\n'.join(err_msg)
err_msg = '\n'.join(err_msg) # type: ignore
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
......@@ -220,13 +226,16 @@ class CheckpointLoader:
sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
@classmethod
def register_scheme(cls, prefixes, loader=None, force=False):
def register_scheme(cls,
prefixes: Union[str, Sequence[str]],
loader: Optional[Callable] = None,
force: bool = False):
"""Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator.
Args:
prefixes (str or list[str] or tuple[str]):
prefixes (str or Sequence[str]):
The prefix of the registered loader.
loader (function, optional): The loader function to be registered.
When this method is used as a decorator, loader is None.
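As the docstring notes, register_scheme also works as a decorator; a minimal hedged sketch with a made-up 'example://' prefix:

from mmcv.runner.checkpoint import CheckpointLoader, load_from_local

@CheckpointLoader.register_scheme(prefixes='example://')
def load_from_example(filename: str, map_location=None):
    # Strip the custom prefix and delegate to the plain local-file loader.
    return load_from_local(filename[len('example://'):], map_location)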
......@@ -264,7 +273,12 @@ class CheckpointLoader:
return cls._schemes[p]
@classmethod
def load_checkpoint(cls, filename, map_location=None, logger=None):
def load_checkpoint(
cls,
filename: str,
map_location: Optional[str] = None,
logger: Optional[logging.Logger] = None
) -> Union[dict, OrderedDict]:
"""load checkpoint through URL scheme path.
Args:
......@@ -286,7 +300,9 @@ class CheckpointLoader:
@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
def load_from_local(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint by local file path.
Args:
......@@ -304,7 +320,10 @@ def load_from_local(filename, map_location):
@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
def load_from_http(
filename: str,
map_location: Optional[str] = None,
model_dir: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through HTTP or HTTPS scheme path. In distributed
setting, this function only download checkpoint at local rank 0.
......@@ -312,7 +331,7 @@ def load_from_http(filename, map_location=None, model_dir=None):
filename (str): checkpoint file path with modelzoo or
torchvision prefix
map_location (str, optional): Same as :func:`torch.load`.
model_dir (string, optional): directory in which to save the object,
model_dir (str, optional): directory in which to save the object,
Default: None
Returns:
......@@ -331,7 +350,9 @@ def load_from_http(filename, map_location=None, model_dir=None):
@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
def load_from_pavi(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with pavi. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
......@@ -363,7 +384,9 @@ def load_from_pavi(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=r'(\S+\:)?s3://')
def load_from_ceph(filename, map_location=None, backend='petrel'):
def load_from_ceph(filename: str,
map_location: Optional[str] = None,
backend: str = 'petrel') -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with s3. In distributed
setting, this function download ckpt at all ranks to different temporary
directories.
......@@ -376,7 +399,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
Args:
filename (str): checkpoint file path with s3 prefix
map_location (str, optional): Same as :func:`torch.load`.
backend (str, optional): The storage backend type. Options are 'ceph',
backend (str): The storage backend type. Options are 'ceph',
'petrel'. Default: 'petrel'.
.. warning::
......@@ -410,7 +433,9 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
def load_from_torchvision(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with modelzoo or
torchvision.
......@@ -439,7 +464,9 @@ def load_from_torchvision(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
def load_from_openmmlab(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with open-mmlab or
openmmlab.
......@@ -481,7 +508,9 @@ def load_from_openmmlab(filename, map_location=None):
@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
def load_from_mmcls(
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""load checkpoint through the file path prefixed with mmcls.
Args:
......@@ -500,7 +529,10 @@ def load_from_mmcls(filename, map_location=None):
return checkpoint
def _load_checkpoint(filename, map_location=None, logger=None):
def _load_checkpoint(
filename: str,
map_location: Optional[str] = None,
logger: Optional[logging.Logger] = None) -> Union[dict, OrderedDict]:
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
......@@ -520,7 +552,10 @@ def _load_checkpoint(filename, map_location=None, logger=None):
return CheckpointLoader.load_checkpoint(filename, map_location, logger)
def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
def _load_checkpoint_with_prefix(
prefix: str,
filename: str,
map_location: Optional[str] = None) -> Union[dict, OrderedDict]:
"""Load partial pretrained model with specific prefix.
Args:
......@@ -553,12 +588,13 @@ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
return state_dict
def load_checkpoint(model,
filename,
map_location=None,
strict=False,
logger=None,
revise_keys=[(r'^module\.', '')]):
def load_checkpoint(
model: torch.nn.Module,
filename: str,
map_location: Optional[str] = None,
strict: bool = False,
logger: Optional[logging.Logger] = None,
revise_keys: list = [(r'^module\.', '')]) -> Union[dict, OrderedDict]:
"""Load checkpoint from a file or URI.
Args:
......@@ -603,7 +639,7 @@ def load_checkpoint(model,
return checkpoint
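A typical call with the annotated signature, as a hedged sketch (the checkpoint path is hypothetical):

import torch.nn as nn
from mmcv.runner.checkpoint import load_checkpoint

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
checkpoint = load_checkpoint(
    model,
    'work_dir/latest.pth',             # hypothetical path
    map_location='cpu',
    strict=False,
    revise_keys=[(r'^module\.', '')])  # strip a leading 'module.' prefix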
def weights_to_cpu(state_dict):
def weights_to_cpu(state_dict: OrderedDict) -> OrderedDict:
"""Copy a model state_dict to cpu.
Args:
......@@ -616,11 +652,13 @@ def weights_to_cpu(state_dict):
for key, val in state_dict.items():
state_dict_cpu[key] = val.cpu()
# Keep metadata in state_dict
state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
state_dict_cpu._metadata = getattr( # type: ignore
state_dict, '_metadata', OrderedDict())
return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
def _save_to_state_dict(module: torch.nn.Module, destination: dict,
prefix: str, keep_vars: bool) -> None:
"""Saves module state to `destination` dictionary.
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
......@@ -640,7 +678,10 @@ def _save_to_state_dict(module, destination, prefix, keep_vars):
destination[prefix + name] = buf if keep_vars else buf.detach()
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
def get_state_dict(module: torch.nn.Module,
destination: Optional[OrderedDict] = None,
prefix: str = '',
keep_vars: bool = False) -> OrderedDict:
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
......@@ -669,8 +710,8 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
# below is the same as torch.nn.Module.state_dict()
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(
destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars)
for name, child in module._modules.items():
......@@ -681,14 +722,14 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
hook_result = hook(module, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
return destination # type: ignore
def save_checkpoint(model,
filename,
optimizer=None,
meta=None,
file_client_args=None):
def save_checkpoint(model: torch.nn.Module,
filename: str,
optimizer: Optional[Optimizer] = None,
meta: Optional[dict] = None,
file_client_args: Optional[dict] = None) -> None:
"""Save checkpoint to file.
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
......
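And the matching save side, sketched with placeholder objects:

import torch
import torch.nn as nn
from mmcv.runner.checkpoint import save_checkpoint

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
save_checkpoint(model, '/tmp/demo.pth', optimizer=optimizer,
                meta=dict(epoch=1, iter=100))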
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
from .builder import RUNNER_BUILDERS, RUNNERS
......@@ -34,7 +36,7 @@ class DefaultRunnerConstructor:
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg, default_args=None):
def __init__(self, runner_cfg: dict, default_args: Optional[dict] = None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
......
......@@ -5,6 +5,7 @@ import os
import socket
import subprocess
from collections import OrderedDict
from typing import Callable, List, Optional, Tuple
import torch
import torch.multiprocessing as mp
......@@ -33,7 +34,7 @@ def _is_free_port(port):
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
def init_dist(launcher, backend='nccl', **kwargs):
def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
......@@ -46,7 +47,7 @@ def init_dist(launcher, backend='nccl', **kwargs):
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
def _init_dist_pytorch(backend: str, **kwargs):
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
if IS_MLU_AVAILABLE:
......@@ -63,7 +64,7 @@ def _init_dist_pytorch(backend, **kwargs):
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
def _init_dist_mpi(backend: str, **kwargs):
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
......@@ -76,7 +77,7 @@ def _init_dist_mpi(backend, **kwargs):
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
def _init_dist_slurm(backend: str, port: Optional[int] = None):
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
......@@ -115,7 +116,7 @@ def _init_dist_slurm(backend, port=None):
dist.init_process_group(backend=backend)
def get_dist_info():
def get_dist_info() -> Tuple[int, int]:
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
......@@ -125,7 +126,7 @@ def get_dist_info():
return rank, world_size
def master_only(func):
def master_only(func: Callable) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs):
......@@ -136,12 +137,14 @@ def master_only(func):
return wrapper
def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
def allreduce_params(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce parameters.
Args:
params (list[torch.Parameters]): List of parameters or buffers of a
model.
params (list[torch.nn.Parameter]): List of parameters or buffers
of a model.
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
......@@ -158,11 +161,13 @@ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
dist.all_reduce(tensor.div_(world_size))
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
def allreduce_grads(params: List[torch.nn.Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
"""Allreduce gradients.
Args:
params (list[torch.Parameters]): List of parameters of a model
params (list[torch.nn.Parameter]): List of parameters of a model.
coalesce (bool, optional): Whether allreduce parameters as a whole.
Defaults to True.
bucket_size_mb (int, optional): Size of bucket, the unit is MB.
......
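For reference, the newly annotated helpers are used roughly like this (a hedged sketch; get_dist_info falls back to (0, 1) when no process group is initialized):

from mmcv.runner.dist_utils import get_dist_info, master_only

rank, world_size = get_dist_info()

@master_only
def log_status(msg: str) -> None:
    print(msg)  # body runs only on rank 0; other ranks return None

log_status(f'{world_size} process(es) running')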
......@@ -3,10 +3,12 @@ import functools
import warnings
from collections import abc
from inspect import getfullargspec
from typing import Callable, Iterable, List, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads
......@@ -21,7 +23,7 @@ except ImportError:
pass
def cast_tensor_type(inputs, src_type, dst_type):
def cast_tensor_type(inputs, src_type: torch.dtype, dst_type: torch.dtype):
"""Recursively convert Tensor in inputs from src_type to dst_type.
Note:
......@@ -52,18 +54,22 @@ def cast_tensor_type(inputs, src_type, dst_type):
elif isinstance(inputs, np.ndarray):
return inputs
elif isinstance(inputs, abc.Mapping):
return type(inputs)({
return type(inputs)({ # type: ignore
k: cast_tensor_type(v, src_type, dst_type)
for k, v in inputs.items()
})
elif isinstance(inputs, abc.Iterable):
return type(inputs)(
return type(inputs)( # type: ignore
cast_tensor_type(item, src_type, dst_type) for item in inputs)
else:
return inputs
def auto_fp16(apply_to=None, out_fp32=False, supported_types=(nn.Module, )):
def auto_fp16(
apply_to: Optional[Iterable] = None,
out_fp32: bool = False,
supported_types: tuple = (nn.Module, ),
) -> Callable:
"""Decorator to enable fp16 training automatically.
This decorator is useful when you write custom modules and want to support
......@@ -150,7 +156,8 @@ def auto_fp16(apply_to=None, out_fp32=False, supported_types=(nn.Module, )):
return auto_fp16_wrapper
def force_fp32(apply_to=None, out_fp16=False):
def force_fp32(apply_to: Optional[Iterable] = None,
out_fp16: bool = False) -> Callable:
"""Decorator to convert input arguments to fp32 in force.
This decorator is useful when you write custom modules and want to support
......@@ -236,15 +243,17 @@ def force_fp32(apply_to=None, out_fp16=False):
return force_fp32_wrapper
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
warnings.warning(
def allreduce_grads(params: List[Parameter],
coalesce: bool = True,
bucket_size_mb: int = -1) -> None:
warnings.warn(
'"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads',
DeprecationWarning)
_allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
def wrap_fp16_model(model):
def wrap_fp16_model(model: nn.Module) -> None:
"""Wrap the FP32 model to FP16.
If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
......@@ -273,7 +282,7 @@ def wrap_fp16_model(model):
m.fp16_enabled = True
def patch_norm_fp32(module):
def patch_norm_fp32(module: nn.Module) -> nn.Module:
"""Recursively convert normalization layers from FP16 to FP32.
Args:
......@@ -293,7 +302,10 @@ def patch_norm_fp32(module):
return module
def patch_forward_method(func, src_type, dst_type, convert_output=True):
def patch_forward_method(func: Callable,
src_type: torch.dtype,
dst_type: torch.dtype,
convert_output: bool = True) -> Callable:
"""Patch the forward method of a module.
Args:
......@@ -346,10 +358,10 @@ class LossScaler:
"""
def __init__(self,
init_scale=2**32,
mode='dynamic',
scale_factor=2.,
scale_window=1000):
init_scale: float = 2**32,
mode: str = 'dynamic',
scale_factor: float = 2.,
scale_window: int = 1000):
self.cur_scale = init_scale
self.cur_iter = 0
assert mode in ('dynamic',
......@@ -359,7 +371,7 @@ class LossScaler:
self.scale_factor = scale_factor
self.scale_window = scale_window
def has_overflow(self, params):
def has_overflow(self, params: List[Parameter]) -> bool:
"""Check if params contain overflow."""
if self.mode != 'dynamic':
return False
......@@ -382,7 +394,7 @@ class LossScaler:
return True
return False
def update_scale(self, overflow):
def update_scale(self, overflow: bool) -> None:
"""update the current loss scale value when overflow happens."""
if self.mode != 'dynamic':
return
......@@ -405,7 +417,7 @@ class LossScaler:
scale_factor=self.scale_factor,
scale_window=self.scale_window)
def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: dict) -> None:
"""Loads the loss_scaler state dict.
Args:
......
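A hedged usage sketch of the annotated decorators (the Head module is illustrative; no forward pass is run here, since actual mixed-precision execution still goes through the AMP / Fp16OptimizerHook machinery):

import torch
import torch.nn as nn
from mmcv.runner.fp16_utils import auto_fp16, wrap_fp16_model

class Head(nn.Module):
    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # flipped to True by wrap_fp16_model
        self.fc = nn.Linear(8, 2)

    @auto_fp16(apply_to=('feats', ))
    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.fc(feats)

model = Head()
wrap_fp16_model(model)  # patches norm layers (older torch) and sets fp16_enabled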
......@@ -12,16 +12,16 @@ class LogBuffer:
self.output = OrderedDict()
self.ready = False
def clear(self):
def clear(self) -> None:
self.val_history.clear()
self.n_history.clear()
self.clear_output()
def clear_output(self):
def clear_output(self) -> None:
self.output.clear()
self.ready = False
def update(self, vars, count=1):
def update(self, vars: dict, count: int = 1) -> None:
assert isinstance(vars, dict)
for key, var in vars.items():
if key not in self.val_history:
......@@ -30,7 +30,7 @@ class LogBuffer:
self.val_history[key].append(var)
self.n_history[key].append(count)
def average(self, n=0):
def average(self, n: int = 0) -> None:
"""Average latest n values or all values."""
assert n >= 0
for key in self.val_history:
......
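A quick hedged sketch of the annotated LogBuffer API:

from mmcv.runner.log_buffer import LogBuffer

buf = LogBuffer()
buf.update(dict(loss=0.9))
buf.update(dict(loss=0.7))
buf.average()              # n=0 averages over all stored values
print(buf.output['loss'])  # ~0.8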
# Copyright (c) OpenMMLab. All rights reserved.
from enum import Enum
from typing import Union
class Priority(Enum):
......@@ -39,7 +40,7 @@ class Priority(Enum):
LOWEST = 100
def get_priority(priority):
def get_priority(priority: Union[int, str, Priority]) -> int:
"""Get priority value.
Args:
......
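For completeness, a hedged sketch of the annotated get_priority helper:

from mmcv.runner.priority import Priority, get_priority

get_priority(50)               # ints in [0, 100] pass through unchanged
get_priority('LOWEST')         # 100, resolved via the enum name
get_priority(Priority.LOWEST)  # 100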