"...text-generation-inference.git" did not exist on "724199aaf172590c3658018c0e6bc6152cda4c2f"
Commit 7ace1ef0 authored by Matteo Presutto, committed by Facebook GitHub Bot

Adding support for all torch multi-tensor optimizers

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/629

This diff adds all of torch's multi-tensor optimizers to d2go, which in its current form supports only AdamW, Adam, and SGD.
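For context, `torch.optim._multi_tensor` holds private optimizer implementations that apply each update step to all parameters of a group in one fused (foreach) call rather than looping per parameter, which is where the speedup comes from. A minimal sketch of the equivalent public API in newer PyTorch releases, where the same behavior is exposed through a `foreach` flag (the flag's availability across optimizers, roughly PyTorch 1.12 onward, is an assumption and not part of this diff):

    import torch

    model = torch.nn.Linear(8, 8)
    # Public multi-tensor (foreach) variant; equivalent in spirit to the
    # private torch.optim._multi_tensor.NAdam used in this diff.
    opt = torch.optim.NAdam(model.parameters(), lr=1e-3, foreach=True)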

Reviewed By: mlopezantequera

Differential Revision: D50498623

fbshipit-source-id: 5a38509354e565dd22256261bf1a688bcdc94951
parent b18c078a
@@ -326,6 +326,126 @@ def adamw_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def nadam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor NAdam optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.NAdam)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def radam_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor RAdam optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RAdam)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def rmsprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor RMSprop optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.RMSprop)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def rprop_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor Rprop optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Rprop)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def asgd_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor ASGD optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.ASGD)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adamax_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor Adamax optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adamax)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adadelta_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor Adadelta optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adadelta)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

@D2GO_OPTIM_MAPPER_REGISTRY.register()
def adagrad_mt(cfg, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build a multi_tensor Adagrad optimizer that runs significantly faster
    than the single-tensor implementation. To benefit from the speedup,
    the number of parameter groups needs to be reduced using
    `reduce_param_groups`.
    """
    params = get_optimizer_param_groups(model, cfg)
    return maybe_add_gradient_clipping(cfg, torch.optim._multi_tensor.Adagrad)(
        params=params,
        lr=cfg.SOLVER.BASE_LR,
    )

def build_optimizer_mapper(cfg, model):
    name = cfg.SOLVER.OPTIMIZER
    optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)
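For illustration, a minimal usage sketch of the mapper with one of the newly registered optimizers; `cfg` and `model` are assumed to be a valid d2go config and an `nn.Module`, and the registry key is the (lowercased) function name, so config values are case-insensitive:

    cfg.SOLVER.OPTIMIZER = "nadam_mt"  # any name in D2GO_OPTIM_MAPPER_REGISTRY
    cfg.SOLVER.BASE_LR = 1e-3
    optimizer = build_optimizer_mapper(cfg, model)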