"tests/nn/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "e83da060759e37020d59f28f7b3648ec7da6ff24"
Unverified Commit 966b7428 authored by tripleMu, committed by GitHub

Add type hints for mmcv/runner/optimizer (#2001)

parent 1577f407
mmcv/runner/optimizer/builder.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
+from typing import Dict, List
 
 import torch
@@ -10,7 +11,7 @@ OPTIMIZERS = Registry('optimizer')
 OPTIMIZER_BUILDERS = Registry('optimizer builder')
 
 
-def register_torch_optimizers():
+def register_torch_optimizers() -> List:
     torch_optimizers = []
     for module_name in dir(torch.optim):
         if module_name.startswith('__'):
@@ -26,11 +27,11 @@ def register_torch_optimizers():
 TORCH_OPTIMIZERS = register_torch_optimizers()
 
 
-def build_optimizer_constructor(cfg):
+def build_optimizer_constructor(cfg: Dict):
     return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
 
 
-def build_optimizer(model, cfg):
+def build_optimizer(model, cfg: Dict):
     optimizer_cfg = copy.deepcopy(cfg)
     constructor_type = optimizer_cfg.pop('constructor',
                                          'DefaultOptimizerConstructor')
...
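For orientation, a minimal usage sketch of the build_optimizer entry point annotated above. It is an illustration, not part of this commit: the toy model and the SGD settings are assumptions, and 'type' must name an optimizer registered in OPTIMIZERS (register_torch_optimizers() registers every torch.optim class).

import torch.nn as nn
from mmcv.runner import build_optimizer

# Toy model; any nn.Module works (illustrative, not from this commit).
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# The plain dict consumed by build_optimizer(model, cfg: Dict).
cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, cfg)  # returns a torch.optim.SGD instance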
mmcv/runner/optimizer/default_constructor.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, List, Optional, Union
 
 import torch
+import torch.nn as nn
 from torch.nn import GroupNorm, LayerNorm
 
 from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
@@ -93,7 +95,9 @@ class DefaultOptimizerConstructor:
         >>> # model.cls_head is (0.01, 0.95).
     """
 
-    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+    def __init__(self,
+                 optimizer_cfg: Dict,
+                 paramwise_cfg: Optional[Dict] = None):
         if not isinstance(optimizer_cfg, dict):
             raise TypeError('optimizer_cfg should be a dict',
                             f'but got {type(optimizer_cfg)}')
@@ -103,7 +107,7 @@ class DefaultOptimizerConstructor:
         self.base_wd = optimizer_cfg.get('weight_decay', None)
         self._validate_cfg()
 
-    def _validate_cfg(self):
+    def _validate_cfg(self) -> None:
         if not isinstance(self.paramwise_cfg, dict):
             raise TypeError('paramwise_cfg should be None or a dict, '
                             f'but got {type(self.paramwise_cfg)}')
@@ -126,7 +130,7 @@ class DefaultOptimizerConstructor:
             if self.base_wd is None:
                 raise ValueError('base_wd should not be None')
 
-    def _is_in(self, param_group, param_group_list):
+    def _is_in(self, param_group: Dict, param_group_list: List) -> bool:
         assert is_list_of(param_group_list, dict)
         param = set(param_group['params'])
         param_set = set()
@@ -135,7 +139,11 @@ class DefaultOptimizerConstructor:
         return not param.isdisjoint(param_set)
 
-    def add_params(self, params, module, prefix='', is_dcn_module=None):
+    def add_params(self,
+                   params: List[Dict],
+                   module: nn.Module,
+                   prefix: str = '',
+                   is_dcn_module: Optional[Union[int, float]] = None) -> None:
         """Add all parameters of module to the params list.
 
         The parameters of the given module will be added to the list of param
@@ -232,7 +240,7 @@ class DefaultOptimizerConstructor:
                 prefix=child_prefix,
                 is_dcn_module=is_dcn_module)
 
-    def __call__(self, model):
+    def __call__(self, model: nn.Module):
         if hasattr(model, 'module'):
             model = model.module
@@ -243,7 +251,7 @@ class DefaultOptimizerConstructor:
             return build_from_cfg(optimizer_cfg, OPTIMIZERS)
 
         # set param-wise lr and weight decay recursively
-        params = []
+        params: List[Dict] = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
...
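The same machinery can also be driven through DefaultOptimizerConstructor directly, which is what build_optimizer does internally once a paramwise_cfg is present. A hedged sketch; the model and the norm_decay_mult value are illustrative assumptions, not part of this commit:

import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.0001)
# Scale weight decay to zero for normalization layers (BN/GN/LN/IN).
paramwise_cfg = dict(norm_decay_mult=0.)

constructor = DefaultOptimizerConstructor(optimizer_cfg, paramwise_cfg)
optimizer = constructor(model)  # __call__(model: nn.Module) assembles the param groups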