Commit b9be57ed authored by Xiaoliang Dai, committed by Facebook GitHub Bot

Add vit det to d2go

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/352

Reviewed By: newstzpz

Differential Revision: D37872639

fbshipit-source-id: 61acdaa669bc541dcb715af1172926efb53c0b2b
parent dba54f21
@@ -49,6 +49,7 @@ def get_optimizer_param_groups(model: OptimizerModelsType, cfg):
         weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
         weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
         weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
+        weight_decay_overwrite=_merge_dict(cfg.SOLVER.WEIGHT_DECAY_OVERWRITE),
     )
     # parameter groups from model function `model.get_optimizer_param_groups(opts)`
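The `_merge_dict` helper used above is defined elsewhere in the file and is not part of this hunk; the sketch below only illustrates the behavior assumed here, namely folding the list of per-key dicts stored in `cfg.SOLVER.WEIGHT_DECAY_OVERWRITE` into a single `Dict[str, float]` before it is passed down.

from typing import Dict, List


def _merge_dict_sketch(list_of_dicts: List[Dict[str, float]]) -> Dict[str, float]:
    # Hypothetical re-implementation for illustration only, not the d2go helper:
    # fold e.g. [{"pos_embed": 0.0}, {"cls_token": 0.0}] into one dict,
    # letting later entries overwrite earlier ones on key collision.
    merged: Dict[str, float] = {}
    for d in list_of_dicts:
        merged.update(d)
    return merged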
@@ -125,6 +126,7 @@ def get_optimizer_param_groups_weight_decay(
     weight_decay_norm: Optional[float] = None,
     weight_decay_bias: Optional[float] = None,
     weight_decay_embed: Optional[float] = None,
+    weight_decay_overwrite: Optional[Dict[str, float]] = None,
 ):
     """
     Allow setting up weight decay for normalization, embedding and bias
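A minimal call sketch for the extended signature; the import path and the example parameter names are assumptions, not taken from this diff.

import torch

# Assumed import path; the function lives in d2go's optimizer module.
from d2go.optimizer import get_optimizer_param_groups_weight_decay

model = torch.nn.Linear(8, 8)
# Overwrite weight decay for any parameter whose name contains "bias",
# keeping the default 1e-4 for everything else.
param_groups = get_optimizer_param_groups_weight_decay(
    model,
    weight_decay=1e-4,
    weight_decay_overwrite={"bias": 0.0},
)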
@@ -162,6 +164,11 @@ def get_optimizer_param_groups_weight_decay(
                 cur_wd = weight_decay_embed
             elif module_param_name == "bias":
                 cur_wd = weight_decay_bias
+            if weight_decay_overwrite is not None:
+                for kname, wd in weight_decay_overwrite.items():
+                    if kname in module_param_name:
+                        cur_wd = wd
             if cur_wd is not None:
                 params += [
                     {
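Re-stated outside the parameter loop, the matching rule added above is a plain substring check against each parameter's (non-recursive) name; a standalone sketch with hypothetical parameter names follows.

from typing import Dict, Optional


def resolve_weight_decay(
    param_name: str, default_wd: float, overwrite: Optional[Dict[str, float]]
) -> float:
    # Mirrors the hunk above: any overwrite key that appears as a substring
    # of the parameter name replaces the default weight decay; if several
    # keys match, the last one iterated wins.
    cur_wd = default_wd
    if overwrite is not None:
        for kname, wd in overwrite.items():
            if kname in param_name:
                cur_wd = wd
    return cur_wd


# e.g. zero out weight decay for a ViT positional embedding, leave "weight" untouched
assert resolve_weight_decay("pos_embed", 1e-4, {"pos_embed": 0.0}) == 0.0
assert resolve_weight_decay("weight", 1e-4, {"pos_embed": 0.0}) == 1e-4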
@@ -68,6 +68,7 @@ def _add_detectron2go_runner_default_cfg(_C: CN) -> None:
     _C.SOLVER.OPTIMIZER = "sgd"
     _C.SOLVER.LR_MULTIPLIER_OVERWRITE = []
     _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+    _C.SOLVER.WEIGHT_DECAY_OVERWRITE = []
     # Betas are used in the AdamW optimizer
     _C.SOLVER.BETAS = (0.9, 0.999)
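A hedged configuration sketch: it assumes `_merge_dict` folds a list of per-key dicts (a format inferred from the helper's name, not shown in this diff), that the runner entry point below is the usual way to build a default cfg, and that ViT parameter names such as pos_embed and cls_token are the intended targets.

from d2go.runner import GeneralizedRCNNRunner  # assumed entry point for building a default cfg

cfg = GeneralizedRCNNRunner().get_default_cfg()
# Hypothetical override: drop weight decay for positional embeddings and the
# class token (substring match on parameter names), keep SOLVER.WEIGHT_DECAY elsewhere.
cfg.SOLVER.WEIGHT_DECAY_OVERWRITE = [{"pos_embed": 0.0}, {"cls_token": 0.0}]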