Commit 737d099b authored by Valentin Andrei, committed by Facebook GitHub Bot

Reduce number of parameter groups to make optimizer more efficient

Summary:
`torch.optim._multi_tensor` provides faster optimizer implementations because it uses the foreach APIs. We can enable it by changing `OPTIMIZER: "ADAMW"` to `OPTIMIZER: "ADAMW_MT"` in the config file.
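Concretely, with a detectron2-style `CfgNode` the switch is just the value of `SOLVER.OPTIMIZER`. A minimal sketch, assuming a hand-built config (real configs carry many more `SOLVER` keys):

```python
from detectron2.config import CfgNode as CN

cfg = CN()
cfg.SOLVER = CN()
# Switch from the plain AdamW mapper to the multi-tensor one.
cfg.SOLVER.OPTIMIZER = "ADAMW_MT"  # previously "ADAMW"
```

`build_optimizer_mapper` (at the bottom of the diff) lower-cases this name and looks it up in `D2GO_OPTIM_MAPPER_REGISTRY`, so `"ADAMW_MT"` dispatches to the new `adamw_mt` function.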

To benefit from the speedup, we need to reduce the number of parameter groups, as suggested in this post: https://fb.workplace.com/groups/1405155842844877/permalink/4971600462867046/

The current implementation uses one parameter group per parameter, which is not optimal. The proposed change groups parameters by their (learning rate, weight decay) combination, as sketched below.
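The effect is easy to see on a toy example. The sketch below mimics the grouping logic that this diff adds as `reduce_param_groups`; the helper name `_group_by_lr_and_wd`, the parameter tensors, and the hyperparameter values are made up for illustration:

```python
from typing import Any, Dict, List, Tuple

import torch


def _group_by_lr_and_wd(param_groups: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Merge all groups that share the same (lr, weight_decay) pair into one group.
    merged: Dict[Tuple[float, float], Dict[str, Any]] = {}
    for group in param_groups:
        key = (group["lr"], group["weight_decay"])
        if key not in merged:
            merged[key] = {
                "params": list(group["params"]),
                "lr": group["lr"],
                "weight_decay": group["weight_decay"],
            }
        else:
            merged[key]["params"].extend(group["params"])
    return list(merged.values())


# One group per parameter, as the previous implementation produced.
per_param_groups = [
    {"params": [torch.nn.Parameter(torch.zeros(4))], "lr": 1e-3, "weight_decay": 1e-4}
    for _ in range(100)
] + [
    {"params": [torch.nn.Parameter(torch.zeros(4))], "lr": 1e-3, "weight_decay": 0.0}
    for _ in range(50)
]

reduced = _group_by_lr_and_wd(per_param_groups)
print(len(per_param_groups), "->", len(reduced))  # prints: 150 -> 2
```

With 150 single-parameter groups collapsed into two, the foreach-based optimizer can batch its tensor operations instead of looping group by group.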

Reviewed By: zhanghang1989

Differential Revision: D30272112

fbshipit-source-id: d8d24298a59b52c2fc2930f7d614a0c6380a432f
parent 6140395f
@@ -12,6 +12,40 @@ from detectron2.utils.registry import Registry
D2GO_OPTIM_MAPPER_REGISTRY = Registry("D2GO_OPTIM_MAPPER")

def reduce_param_groups(
    param_groups: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    # The number of parameter groups needs to be as small as possible in order
    # to efficiently use the PyTorch multi-tensor optimizer. Therefore, instead
    # of using one parameter group per single parameter, we group all the params
    # with the same lr and weight_decay into a single group. This approach speeds
    # up the optimizer step significantly.
    dict_new_groups: Dict[Tuple[float, float], Dict[str, Any]] = {}
    for param_group in param_groups:
        # value is the list of parameters from the previous group
        value = param_group["params"]
        # lr and weight_decay are floating point values
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        # Group parameters by their (lr, weight_decay) combination
        group_key = (lr, weight_decay)
        if group_key not in dict_new_groups:
            dict_new_groups[group_key] = {
                "params": value,
                "lr": lr,
                "weight_decay": weight_decay,
            }
        else:
            # Add elements from an existing group to the new larger group
            dict_new_groups[group_key]["params"].extend(value)
    return list(dict_new_groups.values())

def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr,
@@ -20,6 +54,7 @@ def get_default_optimizer_params(
    weight_decay_embed,
    bias_lr_factor=1.0,
    weight_decay_bias=None,
    use_param_group_reduction=False,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
    lr_multipliers_overwrite: Optional[Dict[str, float]] = None,
):
@@ -36,6 +71,10 @@ def get_default_optimizer_params(
            containing certain keys. For example, if lr_multipliers_overwrite={'backbone': 0.1},
            the LR for the parameters whose names contain 'backbone' will be scaled to 0.1x.
            Set lr_multipliers_overwrite={} if no multipliers are required.
        use_param_group_reduction:
            if set to `False`, a separate parameter group is created for each parameter,
            which makes the optimizer step very slow. Keep it set to `False` only when
            using checkpoints of models that were created with one parameter group per
            parameter.
""" """
if weight_decay_bias is None: if weight_decay_bias is None:
weight_decay_bias = weight_decay weight_decay_bias = weight_decay
...@@ -93,6 +132,10 @@ def get_default_optimizer_params( ...@@ -93,6 +132,10 @@ def get_default_optimizer_params(
} }
] ]

    if use_param_group_reduction:
        # Reduce the number of param groups to speed up the optimizer step
        return reduce_param_groups(params)

    return params
@@ -167,6 +210,30 @@ def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor adamw optimizer that works significantly faster.
    This version is expected to become the default implementation for the
    adamw optimizer by the end of H1'21. To benefit from the speedup, the
    number of parameter groups needs to be reduced using `reduce_param_groups`.
    """
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        weight_decay_embed=cfg.SOLVER.WEIGHT_DECAY_EMBED,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        use_param_group_reduction=True,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        lr_multipliers_overwrite=_merge_dict(cfg.SOLVER.LR_MULTIPLIER_OVERWRITE),
    )
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.AdamW)(
        params, cfg.SOLVER.BASE_LR
    )

def build_optimizer_mapper(cfg, model):
    name = cfg.SOLVER.OPTIMIZER
    return D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)