Unverified Commit c4d1fd77 authored by Yanming Wang, committed by GitHub

Set syncfree AdamW as the default optimizer for xla:gpu device in amp mode (#15361)

* Use syncfree AdamW for xla:gpu device by default

* Make syncfree AdamW optional
parent 2e4559fa
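
The change targets torch_xla's automatic mixed precision path: with a stock optimizer, the GradScaler's inf/nan gradient check forces a device-to-host synchronization, while the syncfree variants are designed to keep that check on the device. Below is a rough sketch of the pattern the Trainer is wiring up, assuming `torch_xla` with GPU support is installed; only the `torch_xla.amp.syncfree.AdamW` import is taken directly from this commit, and the toy model and exact `autocast` signature are assumptions for illustration.

```python
# Rough sketch of syncfree AdamW in a torch_xla amp training step.
# Assumption: torch_xla.amp.GradScaler/autocast are available in the installed
# torch_xla version; only the syncfree AdamW import comes from this commit.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import GradScaler, autocast
from torch_xla.amp.syncfree import AdamW

device = xm.xla_device()                        # xla:gpu when running on GPU via XLA
model = torch.nn.Linear(8, 2).to(device)        # toy model, purely illustrative
optimizer = AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()

inputs = torch.randn(4, 8, device=device)
labels = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with autocast(device):                          # amp mode on the XLA device
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
scaler.scale(loss).backward()
scaler.step(optimizer)                          # syncfree AdamW: inf/nan check stays on device
scaler.update()
xm.mark_step()                                  # materialize the lazily traced step
```
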
@@ -868,6 +868,14 @@ class Trainer:
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+            try:
+                from torch_xla.amp.syncfree import AdamW
+
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
         elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
             try:
                 from apex.optimizers import FusedAdam
...
@@ -77,6 +77,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_HF = "adamw_hf"
     ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
...
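
With the enum entry in place, the syncfree path is reached through `TrainingArguments.optim`. A minimal sketch, assuming a standard fine-tuning setup with `torch_xla` installed; the function name, `output_dir`, and the placeholder model/dataset arguments are illustrative, not part of the commit:

```python
# Minimal sketch: selecting the syncfree AdamW via the new "adamw_torch_xla"
# value. Assumes torch_xla is installed and training runs on an xla:gpu device;
# build_xla_trainer and output_dir are hypothetical names, not from the commit.
from transformers import Trainer, TrainingArguments


def build_xla_trainer(model, train_dataset):
    training_args = TrainingArguments(
        output_dir="xla_run",        # hypothetical output path
        optim="adamw_torch_xla",     # value of OptimizerNames.ADAMW_TORCH_XLA above
        fp16=True,                   # amp mode, where the syncfree variant matters
    )
    return Trainer(model=model, args=training_args, train_dataset=train_dataset)
```

If `torch_xla` cannot be imported, the branch added in the first hunk raises the `ValueError` shown above rather than silently falling back to another AdamW implementation.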