Commit b9be57ed authored by Xiaoliang Dai's avatar Xiaoliang Dai Committed by Facebook GitHub Bot
Browse files

Add vit det to d2go

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/352

Reviewed By: newstzpz

Differential Revision: D37872639

fbshipit-source-id: 61acdaa669bc541dcb715af1172926efb53c0b2b
parent dba54f21
...@@ -49,6 +49,7 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg): ...@@ -49,6 +49,7 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED, weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
weight_decay_overwrite=_merge_dict(cfg.SOLVER.WEIGHT_DECAY_OVERWRITE),
) )
# parameter groups from model function `model.get_optimizer_param_groups(opts)` # parameter groups from model function `model.get_optimizer_param_groups(opts)`
...@@ -125,6 +126,7 @@ def get_optimizer_param_groups_weight_decay( ...@@ -125,6 +126,7 @@ def get_optimizer_param_groups_weight_decay(
weight_decay_norm: Optional[float] = None, weight_decay_norm: Optional[float] = None,
weight_decay_bias: Optional[float] = None, weight_decay_bias: Optional[float] = None,
weight_decay_embed: Optional[float] = None, weight_decay_embed: Optional[float] = None,
weight_decay_overwrite: Optional[Dict[str, float]] = None,
): ):
""" """
Allow setting up weight decay for normalization, embedding and bias Allow setting up weight decay for normalization, embedding and bias
...@@ -162,6 +164,11 @@ def get_optimizer_param_groups_weight_decay( ...@@ -162,6 +164,11 @@ def get_optimizer_param_groups_weight_decay(
cur_wd = weight_decay_embed cur_wd = weight_decay_embed
elif module_param_name == "bias": elif module_param_name == "bias":
cur_wd = weight_decay_bias cur_wd = weight_decay_bias
if weight_decay_overwrite is not None:
for kname, wd in weight_decay_overwrite.items():
if kname in module_param_name:
cur_wd = wd
if cur_wd is not None: if cur_wd is not None:
params += [ params += [
{ {
......
...@@ -68,6 +68,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None: ...@@ -68,6 +68,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
_C.SOLVER.OPTIMIZER = "sgd" _C.SOLVER.OPTIMIZER = "sgd"
_C.SOLVER.LR_MULTIPLIER_OVERWRITE = [] _C.SOLVER.LR_MULTIPLIER_OVERWRITE = []
_C.SOLVER.WEIGHT_DECAY_EMBED = 0.0 _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
_C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
# Betas are used in the AdamW optimizer # Betas are used in the AdamW optimizer
_C.SOLVER.BETAS = (0.9, 0.999) _C.SOLVER.BETAS = (0.9, 0.999)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment