Unverified Commit 1f250010 authored by tripleMu's avatar tripleMu Committed by GitHub
Browse files

Add type hints for mmcv/runner (#2002)



* Fix

* Fix

* fix type hint

* minor fix

* remove some type hints of functions or methods

* minor fix

* Apply suggestions from code review

* minor fix

* minor refinement
Co-authored-by: default avatarHAOCHENYE <21724054@zju.edu.cn>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarzhouzaida <zhouzaida@163.com>
parent b9a96e56
...@@ -64,8 +64,8 @@ class BaseMergeCell(nn.Module): ...@@ -64,8 +64,8 @@ class BaseMergeCell(nn.Module):
if self.with_out_conv: if self.with_out_conv:
self.out_conv = ConvModule( self.out_conv = ConvModule(
fused_channels, fused_channels, # type: ignore
out_channels, out_channels, # type: ignore
**out_conv_cfg, **out_conv_cfg,
norm_cfg=out_norm_cfg, norm_cfg=out_norm_cfg,
order=out_conv_order) order=out_conv_order)
......
...@@ -50,10 +50,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -50,10 +50,10 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
@property @property
def is_init(self): def is_init(self) -> bool:
return self._is_init return self._is_init
def init_weights(self): def init_weights(self) -> None:
"""Initialize the weights.""" """Initialize the weights."""
is_top_level_module = False is_top_level_module = False
...@@ -68,7 +68,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -68,7 +68,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# which indicates whether the parameter has been modified. # which indicates whether the parameter has been modified.
# this attribute would be deleted after all parameters # this attribute would be deleted after all parameters
# is initialized. # is initialized.
self._params_init_info = defaultdict(dict) self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True is_top_level_module = True
# Initialize the `_params_init_info`, # Initialize the `_params_init_info`,
...@@ -134,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -134,7 +134,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
del sub_module._params_init_info del sub_module._params_init_info
@master_only @master_only
def _dump_init_info(self, logger_name: str): def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named """Dump the initialization information to a file named
`initialization.log.json` in workdir. `initialization.log.json` in workdir.
......
...@@ -10,9 +10,10 @@ import warnings ...@@ -10,9 +10,10 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Sequence, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn
import torchvision import torchvision
from torch.optim import Optimizer from torch.optim import Optimizer
...@@ -28,7 +29,7 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' ...@@ -28,7 +29,7 @@ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache' DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home(): def _get_mmcv_home() -> str:
mmcv_home = os.path.expanduser( mmcv_home = os.path.expanduser(
os.getenv( os.getenv(
ENV_MMCV_HOME, ENV_MMCV_HOME,
...@@ -39,7 +40,7 @@ def _get_mmcv_home(): ...@@ -39,7 +40,7 @@ def _get_mmcv_home():
return mmcv_home return mmcv_home
def load_state_dict(module: torch.nn.Module, def load_state_dict(module: nn.Module,
state_dict: Union[dict, OrderedDict], state_dict: Union[dict, OrderedDict],
strict: bool = False, strict: bool = False,
logger: Optional[logging.Logger] = None) -> None: logger: Optional[logging.Logger] = None) -> None:
...@@ -51,19 +52,19 @@ def load_state_dict(module: torch.nn.Module, ...@@ -51,19 +52,19 @@ def load_state_dict(module: torch.nn.Module,
Args: Args:
module (Module): Module that receives the state_dict. module (Module): Module that receives the state_dict.
state_dict (OrderedDict): Weights. state_dict (dict or OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``. :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used. message. If not specified, print function will be used.
""" """
unexpected_keys: List = [] unexpected_keys: List[str] = []
all_missing_keys: List = [] all_missing_keys: List[str] = []
err_msg: List = [] err_msg: List[str] = []
metadata = getattr(state_dict, '_metadata', None) metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy() state_dict = state_dict.copy() # type: ignore
if metadata is not None: if metadata is not None:
state_dict._metadata = metadata # type: ignore state_dict._metadata = metadata # type: ignore
...@@ -187,7 +188,7 @@ def get_deprecated_model_names(): ...@@ -187,7 +188,7 @@ def get_deprecated_model_names():
return deprecate_urls return deprecate_urls
def _process_mmcls_checkpoint(checkpoint): def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
if 'state_dict' in checkpoint: if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict'] state_dict = checkpoint['state_dict']
else: else:
...@@ -209,7 +210,10 @@ class CheckpointLoader: ...@@ -209,7 +210,10 @@ class CheckpointLoader:
_schemes: dict = {} _schemes: dict = {}
@classmethod @classmethod
def _register_scheme(cls, prefixes, loader, force=False): def _register_scheme(cls,
prefixes: Union[str, List, Tuple],
loader: Callable,
force: bool = False) -> None:
if isinstance(prefixes, str): if isinstance(prefixes, str):
prefixes = [prefixes] prefixes = [prefixes]
else: else:
...@@ -227,9 +231,9 @@ class CheckpointLoader: ...@@ -227,9 +231,9 @@ class CheckpointLoader:
@classmethod @classmethod
def register_scheme(cls, def register_scheme(cls,
prefixes: Union[str, Sequence[str]], prefixes: Union[str, List[str], Tuple[str, ...]],
loader: Optional[Callable] = None, loader: Optional[Callable] = None,
force: bool = False): force: bool = False) -> Callable:
"""Register a loader to CheckpointLoader. """Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator. This method can be used as a normal class method or a decorator.
...@@ -246,7 +250,7 @@ class CheckpointLoader: ...@@ -246,7 +250,7 @@ class CheckpointLoader:
if loader is not None: if loader is not None:
cls._register_scheme(prefixes, loader, force=force) cls._register_scheme(prefixes, loader, force=force)
return return # type: ignore
def _register(loader_cls): def _register(loader_cls):
cls._register_scheme(prefixes, loader_cls, force=force) cls._register_scheme(prefixes, loader_cls, force=force)
...@@ -255,7 +259,7 @@ class CheckpointLoader: ...@@ -255,7 +259,7 @@ class CheckpointLoader:
return _register return _register
@classmethod @classmethod
def _get_checkpoint_loader(cls, path): def _get_checkpoint_loader(cls, path: str):
"""Finds a loader that supports the given path. Falls back to the local """Finds a loader that supports the given path. Falls back to the local
loader if no other loader is found. loader if no other loader is found.
...@@ -293,10 +297,10 @@ class CheckpointLoader: ...@@ -293,10 +297,10 @@ class CheckpointLoader:
""" """
checkpoint_loader = cls._get_checkpoint_loader(filename) checkpoint_loader = cls._get_checkpoint_loader(filename)
class_name = checkpoint_loader.__name__ class_name = checkpoint_loader.__name__ # type: ignore
mmcv.print_log( mmcv.print_log(
f'load checkpoint from {class_name[10:]} path: {filename}', logger) f'load checkpoint from {class_name[10:]} path: {filename}', logger)
return checkpoint_loader(filename, map_location) return checkpoint_loader(filename, map_location) # type: ignore
@CheckpointLoader.register_scheme(prefixes='') @CheckpointLoader.register_scheme(prefixes='')
...@@ -719,7 +723,7 @@ def get_state_dict(module: torch.nn.Module, ...@@ -719,7 +723,7 @@ def get_state_dict(module: torch.nn.Module,
destination._metadata = OrderedDict() # type: ignore destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version) version=module._version)
_save_to_state_dict(module, destination, prefix, keep_vars) _save_to_state_dict(module, destination, prefix, keep_vars) # type: ignore
for name, child in module._modules.items(): for name, child in module._modules.items():
if child is not None: if child is not None:
get_state_dict( get_state_dict(
...@@ -766,7 +770,7 @@ def save_checkpoint(model: torch.nn.Module, ...@@ -766,7 +770,7 @@ def save_checkpoint(model: torch.nn.Module,
checkpoint = { checkpoint = {
'meta': meta, 'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)) 'state_dict': weights_to_cpu(get_state_dict(model)) # type: ignore
} }
# save optimizer state dict in the checkpoint # save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer): if isinstance(optimizer, Optimizer):
......
...@@ -16,7 +16,7 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, ...@@ -16,7 +16,7 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
from mmcv.utils import IS_MLU_AVAILABLE from mmcv.utils import IS_MLU_AVAILABLE
def _find_free_port(): def _find_free_port() -> str:
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us # Binding to port 0 will cause the OS to find an available port for us
...@@ -27,7 +27,7 @@ def _find_free_port(): ...@@ -27,7 +27,7 @@ def _find_free_port():
return port return port
def _is_free_port(port): def _is_free_port(port: int) -> bool:
ips = socket.gethostbyname_ex(socket.gethostname())[-1] ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost') ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
...@@ -47,7 +47,7 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: ...@@ -47,7 +47,7 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
raise ValueError(f'Invalid launcher type: {launcher}') raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend: str, **kwargs): def _init_dist_pytorch(backend: str, **kwargs) -> None:
# TODO: use local_rank instead of rank % num_gpus # TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK']) rank = int(os.environ['RANK'])
if IS_MLU_AVAILABLE: if IS_MLU_AVAILABLE:
...@@ -64,7 +64,7 @@ def _init_dist_pytorch(backend: str, **kwargs): ...@@ -64,7 +64,7 @@ def _init_dist_pytorch(backend: str, **kwargs):
dist.init_process_group(backend=backend, **kwargs) dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend: str, **kwargs): def _init_dist_mpi(backend: str, **kwargs) -> None:
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ: if 'MASTER_PORT' not in os.environ:
...@@ -77,7 +77,7 @@ def _init_dist_mpi(backend: str, **kwargs): ...@@ -77,7 +77,7 @@ def _init_dist_mpi(backend: str, **kwargs):
dist.init_process_group(backend=backend, **kwargs) dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend: str, port: Optional[int] = None): def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
"""Initialize slurm distributed training environment. """Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system If argument ``port`` is not specified, then the master port will be system
...@@ -187,7 +187,9 @@ def allreduce_grads(params: List[torch.nn.Parameter], ...@@ -187,7 +187,9 @@ def allreduce_grads(params: List[torch.nn.Parameter],
dist.all_reduce(tensor.div_(world_size)) dist.all_reduce(tensor.div_(world_size))
def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): def _allreduce_coalesced(tensors: torch.Tensor,
world_size: int,
bucket_size_mb: int = -1) -> None:
if bucket_size_mb > 0: if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024 bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes) buckets = _take_tensors(tensors, bucket_size_bytes)
......
...@@ -103,10 +103,10 @@ def auto_fp16( ...@@ -103,10 +103,10 @@ def auto_fp16(
>>> pass >>> pass
""" """
def auto_fp16_wrapper(old_func): def auto_fp16_wrapper(old_func: Callable) -> Callable:
@functools.wraps(old_func) @functools.wraps(old_func)
def new_func(*args, **kwargs): def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not, # check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method. # just fallback to the original method.
if not isinstance(args[0], supported_types): if not isinstance(args[0], supported_types):
...@@ -195,7 +195,7 @@ def force_fp32(apply_to: Optional[Iterable] = None, ...@@ -195,7 +195,7 @@ def force_fp32(apply_to: Optional[Iterable] = None,
def force_fp32_wrapper(old_func): def force_fp32_wrapper(old_func):
@functools.wraps(old_func) @functools.wraps(old_func)
def new_func(*args, **kwargs): def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not, # check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method. # just fallback to the original method.
if not isinstance(args[0], torch.nn.Module): if not isinstance(args[0], torch.nn.Module):
...@@ -380,7 +380,7 @@ class LossScaler: ...@@ -380,7 +380,7 @@ class LossScaler:
return True return True
return False return False
def _has_inf_or_nan(x): def _has_inf_or_nan(x: torch.Tensor) -> bool:
"""Check if params contain NaN.""" """Check if params contain NaN."""
try: try:
cpu_sum = float(x.float().sum()) cpu_sum = float(x.float().sum())
...@@ -407,7 +407,7 @@ class LossScaler: ...@@ -407,7 +407,7 @@ class LossScaler:
self.cur_scale *= self.scale_factor self.cur_scale *= self.scale_factor
self.cur_iter += 1 self.cur_iter += 1
def state_dict(self): def state_dict(self) -> dict:
"""Returns the state of the scaler as a :class:`dict`.""" """Returns the state of the scaler as a :class:`dict`."""
return dict( return dict(
cur_scale=self.cur_scale, cur_scale=self.cur_scale,
...@@ -431,5 +431,5 @@ class LossScaler: ...@@ -431,5 +431,5 @@ class LossScaler:
self.scale_window = state_dict['scale_window'] self.scale_window = state_dict['scale_window']
@property @property
def loss_scale(self): def loss_scale(self) -> float:
return self.cur_scale return self.cur_scale
...@@ -15,7 +15,7 @@ import torch ...@@ -15,7 +15,7 @@ import torch
import mmcv import mmcv
def get_host_info(): def get_host_info() -> str:
"""Get hostname and username. """Get hostname and username.
Return empty string if exception raised, e.g. ``getpass.getuser()`` will Return empty string if exception raised, e.g. ``getpass.getuser()`` will
...@@ -30,7 +30,7 @@ def get_host_info(): ...@@ -30,7 +30,7 @@ def get_host_info():
return host return host
def get_time_str(): def get_time_str() -> str:
return time.strftime('%Y%m%d_%H%M%S', time.localtime()) return time.strftime('%Y%m%d_%H%M%S', time.localtime())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment