import re

import torch

from mmdet.utils import build_from_cfg, get_root_logger

from .registry import OPTIMIZERS


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 4 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult,
                  dwconv_decay_mult). `bias_lr_mult` and `bias_decay_mult`
                  are multiplied with the lr and weight decay respectively
                  for all bias parameters (except those of normalization
                  layers), `norm_decay_mult` is multiplied with the weight
                  decay for all weight and bias parameters of normalization
                  layers, and `dwconv_decay_mult` is multiplied with the
                  weight decay for all weight and bias parameters of
                  depthwise conv layers.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> import torch
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
    """
    if hasattr(model, 'module'):
        model = model.module

    optimizer_cfg = optimizer_cfg.copy()
    if isinstance(optimizer_cfg, list):
        # assume paramwise_options is None if optimizer_cfg is a list;
        # each entry builds one optimizer and all of them are wrapped
        # into a single MixedOptimizer
        from .mix_optimizer import MixedOptimizer
        logger = get_root_logger()
        keys = [optimizer.pop('key') for optimizer in optimizer_cfg]
        keys_params = {key: [] for key in keys}
        keys_params_name = {key: [] for key in keys}
        keys_optimizer = []
        # assign every parameter to the first optimizer whose key is a
        # substring of the parameter name
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            find_flag = False
            for key in keys:
                if key in name:
                    keys_params[key].append(param_group)
                    keys_params_name[key].append(name)
                    find_flag = True
                    break
            assert find_flag, \
                'param {} is not matched to any optimizer key'.format(name)
        step_intervals = []
        for key, single_cfg in zip(keys, optimizer_cfg):
            optimizer_cls = getattr(torch.optim, single_cfg.pop('type'))
            step_intervals.append(single_cfg.pop('step_interval', 1))
            single_optim = optimizer_cls(keys_params[key], **single_cfg)
            keys_optimizer.append(single_optim)
            logger.info('{} optimizes key:\n {}\n'.format(
                optimizer_cls.__name__, keys_params_name[key]))
        mix_optimizer = MixedOptimizer(keys_optimizer, step_intervals)
        return mix_optimizer
    else:
        paramwise_options = optimizer_cfg.pop('paramwise_options', None)
        # if no paramwise option is specified, just use the global setting
        if paramwise_options is None:
            params = model.parameters()
        else:
            assert isinstance(paramwise_options, dict)
            # get base lr and weight decay
            base_lr = optimizer_cfg['lr']
            base_wd = optimizer_cfg.get('weight_decay', None)
            # weight_decay must be explicitly specified if a mult is specified
            if ('bias_decay_mult' in paramwise_options
                    or 'norm_decay_mult' in paramwise_options
                    or 'dwconv_decay_mult' in paramwise_options):
                assert base_wd is not None
            # get param-wise options
            bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
            bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
            norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
            dwconv_decay_mult = paramwise_options.get('dwconv_decay_mult', 1.)
            named_modules = dict(model.named_modules())
            # set param-wise lr and weight decay
            params = []
            for name, param in model.named_parameters():
                param_group = {'params': [param]}
                if not param.requires_grad:
                    # FP16 training needs to copy gradients/weights between
                    # the master weight copy and the model weights, so it is
                    # convenient to keep all parameters here to align with
                    # model.parameters()
                    params.append(param_group)
                    continue

                # for norm layers, overwrite the weight decay of weight and
                # bias
                # TODO: obtain the norm layer prefixes dynamically
                if re.search(r'(bn|gn)(\d+)?\.(weight|bias)', name):
                    if base_wd is not None:
                        param_group['weight_decay'] = base_wd * norm_decay_mult
                # for other layers, overwrite both lr and weight decay of bias
                elif name.endswith('.bias'):
                    param_group['lr'] = base_lr * bias_lr_mult
                    if base_wd is not None:
                        param_group['weight_decay'] = base_wd * bias_decay_mult

                module_name = name.replace('.weight', '').replace('.bias', '')
                if module_name in named_modules and base_wd is not None:
                    module = named_modules[module_name]
                    # if this Conv2d is a depthwise Conv2d, scale its
                    # weight decay
                    if isinstance(module, torch.nn.Conv2d) and \
                            module.in_channels == module.groups:
                        param_group['weight_decay'] = \
                            base_wd * dwconv_decay_mult
                # otherwise use the global settings
                params.append(param_group)

        optimizer_cfg['params'] = params
        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
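

# Minimal usage sketch (illustrative only, run as a script). It assumes the
# OPTIMIZERS registry exposes the torch.optim classes under their class names
# (e.g. 'SGD'), and the toy model, keys and hyperparameters below are
# hypothetical; adjust them to your own setup.
if __name__ == '__main__':
    demo_model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

    # 1) Single optimizer with param-wise options: biases get 2x lr and no
    #    weight decay. Note that norm_decay_mult only takes effect when the
    #    parameter name contains a 'bn'/'gn' prefix (as in ResNet-style
    #    modules), so it is a no-op on this toy Sequential.
    sgd_cfg = dict(
        type='SGD',
        lr=0.02,
        momentum=0.9,
        weight_decay=0.0001,
        paramwise_options=dict(
            bias_lr_mult=2., bias_decay_mult=0., norm_decay_mult=0.))
    sgd_optimizer = build_optimizer(demo_model, sgd_cfg)

    # 2) List config -> MixedOptimizer: every entry needs a 'key' that is
    #    matched as a substring of parameter names and may carry an optional
    #    'step_interval'. The keys 'weight'/'bias' match all parameters of
    #    the toy model; real configs typically use module prefixes such as
    #    'backbone' or 'head'.
    mixed_cfg = [
        dict(key='weight', type='SGD', lr=0.02, momentum=0.9),
        dict(key='bias', type='Adam', lr=0.001, step_interval=2),
    ]
    mixed_optimizer = build_optimizer(demo_model, mixed_cfg)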