Commit b1e2cc56 authored by Zhicheng Yan, committed by Facebook GitHub Bot

remove redundant build_optimizer()

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/96

In `DETRRunner`, the `build_optimizer` method customized the following logic, which is redundant with the parent-class implementation and can be removed.
- Discounting the LR for certain modules, such as those named `reference_points`, `backbone`, and `sampling_offsets`.
  - This can be achieved via `SOLVER.LR_MULTIPLIER_OVERWRITE` after updating `get_default_optimizer_params` in `mobile-vision/d2go/d2go/optimizer/build.py` (see the sketch after this list).
- Full-model gradient clipping.
  - This is also implemented in `mobile-vision/d2go/d2go/optimizer/build.py`.
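
A minimal config sketch of the two points above (the keys and values mirror the removed override and are illustrative only, not part of this diff):

```python
# Sketch only: discount the LR of the backbone / deformable-attention modules
# through the generic override instead of a custom build_optimizer().
# SOLVER.LR_MULTIPLIER_OVERWRITE is a list of dicts (merged by _merge_dict());
# each key is matched as a substring of the module name.
cfg.SOLVER.LR_MULTIPLIER_OVERWRITE = [
    {"backbone": 0.1, "reference_points": 0.1, "sampling_offsets": 0.1}
]

# Full-model gradient clipping is likewise handled by the shared builder.
cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.1  # example value
```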

It also has a minor issue:
- It ignores `SOLVER.WEIGHT_DECAY_NORM`, which can set a different weight decay for the affine parameters of norm modules (see the sketch below).
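
For example (a sketch; the values are illustrative only):

```python
# With the shared get_default_optimizer_params(), norm-layer affine parameters
# can be given their own weight decay, which the removed override ignored.
cfg.SOLVER.WEIGHT_DECAY = 1e-4
cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0  # e.g. no decay on BN/GN/LN weight and bias
```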

Reviewed By: zhanghang1989

Differential Revision: D29420642

fbshipit-source-id: deeb9348c9d282231c540dde6161acedd8e3a119
parent 4f3f3401
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-import torch
 import itertools
 from typing import Any, Dict, List, Optional, Set
+import torch
+from detectron2.solver.build import (
+    maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
+)
 from detectron2.utils.registry import Registry
-from detectron2.solver.build import maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping

 D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")

 def get_default_optimizer_params(
     model: torch.nn.Module,
     base_lr,
@@ -51,7 +53,7 @@ def get_default_optimizer_params(
     )
     params: List[Dict[str, Any]] = []
     memo: Set[torch.nn.parameter.Parameter] = set()
-    for module in model.modules():
+    for module_name, module in model.named_modules():
         for module_param_name, value in module.named_parameters(recurse=False):
             if not value.requires_grad:
                 continue
@@ -77,9 +79,9 @@ def get_default_optimizer_params(
                 schedule_params.update(overrides[module_param_name])
             if lr_multipliers_overwrite is not None:
                 for kname, mult in lr_multipliers_overwrite.items():
-                    if kname in module_param_name:
+                    if kname in module_name:
                         # apply multiplier for the params containing kname, e.g. backbone
-                        schedule_params['lr'] = schedule_params['lr'] * mult
+                        schedule_params["lr"] = schedule_params["lr"] * mult
             params += [
                 {
                     "params": [value],
@@ -110,6 +112,7 @@ def maybe_add_gradient_clipping(cfg, optim):  # optim: the optimizer class
         return FullModelGradientClippingOptimizer
     return d2_maybe_add_gradient_clipping(cfg, optim)

 def _merge_dict(in_dict):
     ret_dict = {}
     assert all(isinstance(x, dict) for x in in_dict)
@@ -117,6 +120,7 @@ def _merge_dict(in_dict):
         ret_dict.update(dic)
     return ret_dict

 @D2GO_OPTIM_MAPPER_REGISTRY.register()
 def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
@@ -132,7 +136,10 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
     )
     return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
-        params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV
+        params,
+        cfg.SOLVER.BASE_LR,
+        momentum=cfg.SOLVER.MOMENTUM,
+        nesterov=cfg.SOLVER.NESTEROV,
     )
@@ -151,7 +158,8 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
     )
     return maybe_add_gradient_clipping(cfg, torch.optim.AdamW)(
-        params, cfg.SOLVER.BASE_LR)
+        params, cfg.SOLVER.BASE_LR
+    )

 def build_optimizer_mapper(cfg, model):
......
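
To illustrate the `named_modules()` change above (the names below are hypothetical examples): an `LR_MULTIPLIER_OVERWRITE` key such as `backbone` is meant to match the module path, not the leaf parameter name that `named_parameters(recurse=False)` yields.

```python
# Hypothetical names, for illustration of the matching change only.
module_name = "backbone.bottom_up.res2.conv1"  # from model.named_modules()
module_param_name = "weight"                   # from named_parameters(recurse=False)

key = "backbone"
assert key in module_name            # new code: multiplier applies to backbone params
assert key not in module_param_name  # old code: the key never matched a leaf name
```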
 #!/usr/bin/env python3
-from detr.d2 import DetrDatasetMapper, add_detr_config
-from detectron2.solver.build import maybe_add_gradient_clipping
 from d2go.config import CfgNode as CN
-from d2go.runner import GeneralizedRCNNRunner
 from d2go.data.dataset_mappers.build import D2GO_DATA_MAPPER_REGISTRY
 from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
+from d2go.runner import GeneralizedRCNNRunner
+from detr.d2 import DetrDatasetMapper, add_detr_config

 @D2GO_DATA_MAPPER_REGISTRY.register()
@@ -27,63 +25,10 @@ class DETRDatasetMapper(DetrDatasetMapper, D2GoDatasetMapper):
     def __call__(self, dataset_dict):
         return D2GoDatasetMapper.__call__(self, dataset_dict)

 class DETRRunner(GeneralizedRCNNRunner):
     def get_default_cfg(self):
         _C = super().get_default_cfg()
         add_detr_config(_C)
         _C.MODEL.DETR = CN(_C.MODEL.DETR)
         return _C
-
-    # TODO rm this after update optimizer
-    @classmethod
-    def build_optimizer(cls, cfg, model):
-        import torch
-        import itertools
-        from typing import Any, Dict, List, Set
-        from detectron2.solver.build import maybe_add_gradient_clipping
-
-        params: List[Dict[str, Any]] = []
-        memo: Set[torch.nn.parameter.Parameter] = set()
-        for key, value in model.named_parameters(recurse=True):
-            if not value.requires_grad:
-                continue
-            # Avoid duplicating parameters
-            if value in memo:
-                continue
-            memo.add(value)
-            lr = cfg.SOLVER.BASE_LR
-            weight_decay = cfg.SOLVER.WEIGHT_DECAY
-            if "backbone.0" in key or "reference_points" in key or "sampling_offsets" in key:
-                lr = lr * 0.1
-            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
-
-        def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
-            # detectron2 doesn't have full model gradient clipping now
-            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
-            enable = (
-                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
-                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
-                and clip_norm_val > 0.0
-            )
-
-            class FullModelGradientClippingOptimizer(optim):
-                def step(self, closure=None):
-                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
-                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
-                    super().step(closure=closure)
-
-            return FullModelGradientClippingOptimizer if enable else optim
-
-        optimizer_type = cfg.SOLVER.OPTIMIZER
-        if optimizer_type == "SGD":
-            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
-                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
-            )
-        elif optimizer_type == "ADAMW":
-            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
-                params, cfg.SOLVER.BASE_LR
-            )
-        else:
-            raise NotImplementedError(f"no optimizer type {optimizer_type}")
-        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
-            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
-        return optimizer