Commit b1e2cc56 authored by Zhicheng Yan, committed by Facebook GitHub Bot

remove redundant build_optimizer()

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/96

In `DETRRunner`, the `build_optimizer` method customizes the following logic, which is redundant with the parent class implementation and can be removed.
- Discount the LR for certain modules, such as those whose names contain `reference_points`, `backbone`, or `sampling_offsets`.
  - This can be achieved with `SOLVER.LR_MULTIPLIER_OVERWRITE` after updating `get_default_optimizer_params` in `mobile-vision/d2go/d2go/optimizer/build.py` (see the config sketch after this list).
- Full-model gradient clipping
  - This is also implemented in `mobile-vision/d2go/d2go/optimizer/build.py`.
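
For illustration, a minimal sketch of how the same two behaviors can be expressed through config after this change. The substring keys and numeric values below are assumptions that mirror the removed code, not part of this diff:

```python
# Illustrative sketch only; assumes the runner's default config is available.
from d2go.runner import GeneralizedRCNNRunner

cfg = GeneralizedRCNNRunner().get_default_cfg()

# LR discount: parameters whose enclosing module name contains one of these
# substrings get base LR * multiplier (the list of dicts is merged by
# _merge_dict() and applied inside get_default_optimizer_params()).
cfg.SOLVER.LR_MULTIPLIER_OVERWRITE = [
    {"backbone": 0.1, "reference_points": 0.1, "sampling_offsets": 0.1}
]

# Full-model gradient clipping, wrapped around the optimizer class by
# maybe_add_gradient_clipping() in d2go/optimizer/build.py.
cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"
cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.01  # assumed value
```

The multipliers match on module names (see the `kname in module_name` change below), so the old `backbone.0` / `reference_points` / `sampling_offsets` special-casing maps onto plain substring keys.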

The removed implementation also has a minor issue:
- It ignores `SOLVER.WEIGHT_DECAY_NORM`, which sets a different weight decay for the affine parameters of normalization modules.
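
For context, a rough sketch of what respecting `SOLVER.WEIGHT_DECAY_NORM` looks like. This is not the actual `get_default_optimizer_params` implementation (which also handles bias hyperparameters, per-parameter overrides, and LR multipliers); it only assumes the standard torch normalization layers:

```python
# Rough sketch only; the real get_default_optimizer_params() in
# d2go/optimizer/build.py covers more cases.
from typing import Any, Dict, List

import torch


def params_with_norm_weight_decay(
    model: torch.nn.Module,
    base_lr: float,
    weight_decay: float,
    weight_decay_norm: float,  # e.g. cfg.SOLVER.WEIGHT_DECAY_NORM
) -> List[Dict[str, Any]]:
    # Normalization modules whose affine parameters get weight_decay_norm.
    norm_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.LayerNorm,
    )
    params: List[Dict[str, Any]] = []
    memo = set()  # avoid duplicating shared parameters
    for module in model.modules():
        for value in module.parameters(recurse=False):
            if not value.requires_grad or value in memo:
                continue
            memo.add(value)
            wd = weight_decay_norm if isinstance(module, norm_types) else weight_decay
            params.append({"params": [value], "lr": base_lr, "weight_decay": wd})
    return params
```

Since the parent `build_optimizer` path goes through `get_default_optimizer_params`, DETR picks this behavior up once the override is gone.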

Reviewed By: zhanghang1989

Differential Revision: D29420642

fbshipit-source-id: deeb9348c9d282231c540dde6161acedd8e3a119
parent 4f3f3401
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-import torch
 import itertools
 from typing import Any, Dict, List, Optional, Set

+import torch
+from detectron2.solver.build import (
+    maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
+)
 from detectron2.utils.registry import Registry
-from detectron2.solver.build import maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping

 D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")

 def get_default_optimizer_params(
     model: torch.nn.Module,
     base_lr,
@@ -51,7 +53,7 @@ def get_default_optimizer_params(
     )
     params: List[Dict[str, Any]] = []
     memo: Set[torch.nn.parameter.Parameter] = set()
-    for module in model.modules():
+    for module_name, module in model.named_modules():
         for module_param_name, value in module.named_parameters(recurse=False):
             if not value.requires_grad:
                 continue
@@ -77,9 +79,9 @@
                 schedule_params.update(overrides[module_param_name])
             if lr_multipliers_overwrite is not None:
                 for kname, mult in lr_multipliers_overwrite.items():
-                    if kname in module_param_name:
+                    if kname in module_name:
                         # apply multiplier for the params containing kname, e.g. backbone
-                        schedule_params['lr'] = schedule_params['lr'] * mult
+                        schedule_params["lr"] = schedule_params["lr"] * mult
             params += [
                 {
                     "params": [value],
@@ -110,6 +112,7 @@ def maybe_add_gradient_clipping(cfg, optim):  # optim: the optimizer class
         return FullModelGradientClippingOptimizer
     return d2_maybe_add_gradient_clipping(cfg, optim)

 def _merge_dict(in_dict):
     ret_dict = {}
     assert all(isinstance(x, dict) for x in in_dict)
@@ -117,6 +120,7 @@ def _merge_dict(in_dict):
         ret_dict.update(dic)
     return ret_dict

 @D2GO_OPTIM_MAPPER_REGISTRY.register()
 def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
     """
@@ -132,7 +136,10 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
     )
     return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
-        params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV
+        params,
+        cfg.SOLVER.BASE_LR,
+        momentum=cfg.SOLVER.MOMENTUM,
+        nesterov=cfg.SOLVER.NESTEROV,
     )
@@ -151,7 +158,8 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
     )
     return maybe_add_gradient_clipping(cfg, torch.optim.AdamW)(
-        params, cfg.SOLVER.BASE_LR)
+        params, cfg.SOLVER.BASE_LR
+    )

 def build_optimizer_mapper(cfg, model):
     ...
 #!/usr/bin/env python3

-from detr.d2 import DetrDatasetMapper, add_detr_config
-from detectron2.solver.build import maybe_add_gradient_clipping
 from d2go.config import CfgNode as CN
-from d2go.runner import GeneralizedRCNNRunner
 from d2go.data.dataset_mappers.build import D2GO_DATA_MAPPER_REGISTRY
 from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
+from d2go.runner import GeneralizedRCNNRunner
+from detr.d2 import DetrDatasetMapper, add_detr_config

 @D2GO_DATA_MAPPER_REGISTRY.register()
@@ -27,63 +25,10 @@ class DETRDatasetMapper(DetrDatasetMapper, D2GoDatasetMapper):
     def __call__(self, dataset_dict):
         return D2GoDatasetMapper.__call__(self, dataset_dict)

 class DETRRunner(GeneralizedRCNNRunner):
     def get_default_cfg(self):
         _C = super().get_default_cfg()
         add_detr_config(_C)
         _C.MODEL.DETR = CN(_C.MODEL.DETR)
         return _C

-    # TODO rm this after update optimizer
-    @classmethod
-    def build_optimizer(cls, cfg, model):
-        import torch
-        import itertools
-        from typing import Any, Dict, List, Set
-        from detectron2.solver.build import maybe_add_gradient_clipping
-
-        params: List[Dict[str, Any]] = []
-        memo: Set[torch.nn.parameter.Parameter] = set()
-        for key, value in model.named_parameters(recurse=True):
-            if not value.requires_grad:
-                continue
-            # Avoid duplicating parameters
-            if value in memo:
-                continue
-            memo.add(value)
-            lr = cfg.SOLVER.BASE_LR
-            weight_decay = cfg.SOLVER.WEIGHT_DECAY
-            if "backbone.0" in key or "reference_points" in key or "sampling_offsets" in key:
-                lr = lr * 0.1
-            params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
-
-        def maybe_add_full_model_gradient_clipping(optim):  # optim: the optimizer class
-            # detectron2 doesn't have full model gradient clipping now
-            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
-            enable = (
-                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
-                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
-                and clip_norm_val > 0.0
-            )
-
-            class FullModelGradientClippingOptimizer(optim):
-                def step(self, closure=None):
-                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
-                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
-                    super().step(closure=closure)
-
-            return FullModelGradientClippingOptimizer if enable else optim
-
-        optimizer_type = cfg.SOLVER.OPTIMIZER
-        if optimizer_type == "SGD":
-            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
-                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
-            )
-        elif optimizer_type == "ADAMW":
-            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
-                params, cfg.SOLVER.BASE_LR
-            )
-        else:
-            raise NotImplementedError(f"no optimizer type {optimizer_type}")
-        if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
-            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
-        return optimizer