Commit cbb6843e authored by Hang Zhang, committed by Facebook GitHub Bot

Add MaskFormer to d2go

Summary: Add MaskFormer to d2go

Reviewed By: bichenwu09

Differential Revision: D30006691

fbshipit-source-id: 15c85f4ab8b3d515805d639ad8cf47532af81f5e
parent 610d2d03
@@ -17,6 +17,7 @@ def get_default_optimizer_params(
     base_lr,
     weight_decay,
     weight_decay_norm,
+    weight_decay_embed,
     bias_lr_factor=1.0,
     weight_decay_bias=None,
     overrides: Optional[Dict[str, Dict[str, float]]] = None,
@@ -75,6 +76,8 @@ def get_default_optimizer_params(
                 # weights.
                 schedule_params["lr"] = base_lr * bias_lr_factor
                 schedule_params["weight_decay"] = weight_decay_bias
+            if isinstance(module, torch.nn.Embedding):
+                schedule_params["weight_decay"] = weight_decay_embed
             if overrides is not None and module_param_name in overrides:
                 schedule_params.update(overrides[module_param_name])
             if lr_multipliers_overwrite is not None:
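
For orientation, the two hunks above implement one rule: parameters owned by a torch.nn.Embedding module get their own weight-decay value instead of the global one. Below is a minimal self-contained sketch of that rule, assuming a detectron2-style loop over modules and their direct parameters; the helper name build_param_groups and the simplified loop are illustrative, not the actual d2go code.

import torch
import torch.nn as nn

def build_param_groups(model, base_lr, weight_decay, weight_decay_embed):
    # Sketch only: the real get_default_optimizer_params also handles norm
    # layers, bias lr factors, and per-name overrides.
    params, seen = [], set()
    for module in model.modules():
        for _, p in module.named_parameters(recurse=False):
            if not p.requires_grad or id(p) in seen:
                continue  # each parameter goes into exactly one group
            seen.add(id(p))
            decay = weight_decay
            if isinstance(module, nn.Embedding):
                # The rule added in this diff: embedding tables (e.g. learned
                # queries in MaskFormer) typically take zero weight decay.
                decay = weight_decay_embed
            params.append({"params": [p], "lr": base_lr, "weight_decay": decay})
    return params

model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 10))
optimizer = torch.optim.AdamW(build_param_groups(model, 1e-4, 0.05, 0.0))
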
@@ -131,6 +134,7 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         base_lr=cfg.SOLVER.BASE_LR,
         weight_decay=cfg.SOLVER.WEIGHT_DECAY,
         weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+        weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
         bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
         weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
@@ -153,6 +157,7 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
         base_lr=cfg.SOLVER.BASE_LR,
         weight_decay=cfg.SOLVER.WEIGHT_DECAY,
         weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+        weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
         bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
         weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
         lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
...
@@ -68,6 +68,7 @@ def get_default_cfg(_C):
     # Set default optimizer
     _C.SOLVER.OPTIMIZER = "sgd"
     _C.SOLVER.LR_MULTIPLIER_OVERWRITE = []
+    _C.SOLVER.WEIGHT_DECAY_EMBED = 0.0
     # Default world size in D2 is 0, which means scaling is not applied. For D2Go
     # auto scale is encouraged, setting it to 8
...
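
Usage-wise, the new key plugs into the existing builders unchanged. A hedged sketch, assuming cfg is the d2go default config and adamw is the builder shown in the diff above (the numeric values are illustrative, not recommendations):

cfg.SOLVER.OPTIMIZER = "adamw"
cfg.SOLVER.BASE_LR = 1e-4            # illustrative value
cfg.SOLVER.WEIGHT_DECAY = 0.05       # illustrative value
cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0  # exempt nn.Embedding weights from decay
optimizer = adamw(cfg, model)        # builds AdamW with per-group decay values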