Commit 08a0f260 authored by Pascual Martinez Gomez's avatar Pascual Martinez Gomez Committed by Facebook GitHub Bot
Browse files

Add Adam optimizer in D2Go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/359

Currently, D2Go is missing the Adam optimizer. This diff addresses the gap.

Reviewed By: tglik, asanakoy

Differential Revision: D38492151

fbshipit-source-id: 27791c23c73942b7a466f2ca91f6b3631733ba16
parent b9be57ed
...@@ -254,6 +254,19 @@ def sgd(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    )
@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adam(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an Adam optimizer (optionally wrapped with gradient clipping) from config.

    The learning rate is taken from cfg.SOLVER.BASE_LR and the beta
    coefficients from cfg.SOLVER.BETAS; per-parameter-group settings are
    derived from the model via get_optimizer_param_groups.
    """
    param_groups = get_optimizer_param_groups(model, cfg)
    optimizer_cls = maybe_add_gradient_clipping(cfg, torch.optim.Adam)
    return optimizer_cls(param_groups, lr=cfg.SOLVER.BASE_LR, betas=cfg.SOLVER.BETAS)
@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adamw(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    ......
...@@ -163,7 +163,7 @@ class TestOptimizer(unittest.TestCase):
        cfg = runner.get_default_cfg()
        multipliers = [None, [{"conv": 0.1}]]
        for optimizer_name in ["SGD", "AdamW", "SGD_MT", "AdamW_MT", "Adam"]:
            for mult in multipliers:
                cfg.SOLVER.BASE_LR = 0.01
                cfg.SOLVER.OPTIMIZER = optimizer_name
...@@ -174,7 +174,7 @@ class TestOptimizer(unittest.TestCase):
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        for optimizer_name in ["SGD", "AdamW", "SGD_MT", "AdamW_MT", "Adam"]:
            cfg.SOLVER.BASE_LR = 0.02
            cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.2
            cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment