Commit 847c6025 authored by Denis Savenkov, committed by Facebook GitHub Bot

Adds LAMB optimizer for large batch training

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/500

Adds the LAMB optimizer from Apex to D2Go. LAMB is particularly helpful in large-batch settings; see, e.g., [scaling of the XRay Video model](https://fb.workplace.com/notes/1569293900138973).

NOTE: this diff only adds the optimizer. Quality experiments have not finished yet, so the default optimizer is not switched.
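
For anyone who wants to opt in, a minimal config sketch (assuming the D2Go solver keys used by the mapper in this diff; the BASE_LR/BETAS/EPS values are illustrative, not tuned recommendations):

```python
# Hypothetical opt-in sketch; values are illustrative, not tuned defaults.
cfg.SOLVER.OPTIMIZER = "lamb"    # looked up case-insensitively by build_optimizer_mapper
cfg.SOLVER.FUSED = True          # the lamb() mapper only supports the fused Apex path
cfg.SOLVER.BASE_LR = 1e-3
cfg.SOLVER.BETAS = (0.9, 0.999)
cfg.SOLVER.EPS = 1e-6

optimizer = build_optimizer_mapper(cfg, model)
```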

Reviewed By: ertrue

Differential Revision: D43920637

fbshipit-source-id: 5dbbc79bbe34ddc36b422f9746cffed2991b2512
parent 1506551f
@@ -4,6 +4,8 @@ import itertools
import logging
from typing import Any, Dict, List, Optional, Union
import apex
import torch
# FIXME: optimizer should not depend on quantization (or vice versa)
@@ -316,6 +318,24 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
)
@D2GO_OPTIM_MAPPER_REGISTRY.register()
def lamb(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
"""
LAMB optimizer has been proposed in `Large Batch Optimization for Deep Learning:
Training BERT in 76 minutes` (https://arxiv.org/abs/1904.00962). It helped scale
BERT pre-training to batch sizes of 32K samples.
"""
params = get_optimizer_param_groups(model, cfg)
assert cfg.SOLVER.FUSED, "Only fused version of LAMB optimizer is supported"
return maybe_add_gradient_clipping(cfg, apex.optimizers.FusedLAMB)(
params=params,
lr=cfg.SOLVER.BASE_LR,
betas=cfg.SOLVER.BETAS,
eps=cfg.SOLVER.EPS,
)
def build_optimizer_mapper(cfg, model):
name = cfg.SOLVER.OPTIMIZER
optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)
......
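
For reference, FusedLAMB implements the trust-ratio-scaled, Adam-style update from the paper cited in the docstring above. Below is a rough, unfused single-tensor sketch in plain PyTorch (not the Apex kernel; details such as bias correction and trust-ratio clamping differ between implementations):

```python
import torch


@torch.no_grad()
def lamb_step_sketch(p, grad, m, v, step, lr, betas=(0.9, 0.999), eps=1e-6, wd=0.01):
    # Illustrative single-tensor LAMB step, not apex.optimizers.FusedLAMB.
    beta1, beta2 = betas
    m.mul_(beta1).add_(grad, alpha=1 - beta1)            # first moment
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    m_hat = m / (1 - beta1**step)                        # bias correction
    v_hat = v / (1 - beta2**step)
    update = m_hat / (v_hat.sqrt() + eps) + wd * p       # Adam direction + weight decay
    w_norm, u_norm = p.norm(), update.norm()
    # Trust ratio: rescale the step by ||w|| / ||update||; fall back to 1 if either norm is 0.
    trust = (w_norm / u_norm).item() if w_norm > 0 and u_norm > 0 else 1.0
    p.add_(update, alpha=-lr * trust)
```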