Unverified commit 085bf5c1 authored by Stas Bekman, committed by GitHub

[trainer] add `--optim adamw_torch_fused` for pt-2.0+ (#22144)

* [trainer] add --optim adamw_torch_fused

* change optim default

* deal with non-torch

* revert default change; prep; add fp16/amp assert

* typo

* typo
parent c6318c37
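
Usage-wise the change is just a new accepted value for `optim`; a minimal sketch of opting in (assuming PyTorch >= 2.0 is installed; `output_dir`, `model` and `train_dataset` are placeholders, not part of this PR):

from transformers import Trainer, TrainingArguments

# Select the fused torch.optim.AdamW kernel introduced by this PR.
# Requires PyTorch 2.0+; on exactly torch==2.0.0 do not combine with --fp16
# (see the __post_init__ guard further down in the diff).
args = TrainingArguments(
    output_dir="out",           # placeholder path
    optim="adamw_torch_fused",  # new value added by this PR
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()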
@@ -1122,11 +1122,13 @@ class Trainer:
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
-        elif args.optim == OptimizerNames.ADAMW_TORCH:
+        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
             from torch.optim import AdamW

             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+                optimizer_kwargs.update({"fused": True})
         elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
             try:
                 from torch_xla.amp.syncfree import AdamW
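
In effect the new branch only forwards an extra `fused=True` kwarg to the stock optimizer; a standalone sketch of what the Trainer ends up constructing (the hyperparameter values below are illustrative, not the Trainer's resolved `adam_kwargs`):

import torch

# The fused implementation is CUDA-only, so the parameters must live on a GPU.
if torch.cuda.is_available():
    model = torch.nn.Linear(8, 2).to("cuda")  # toy module for illustration
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=5e-5,
        betas=(0.9, 0.999),
        eps=1e-8,
        fused=True,  # the kwarg added to optimizer_kwargs for adamw_torch_fused
    )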
@@ -121,6 +121,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_HF = "adamw_hf"
     ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
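
`OptimizerNames` extends `ExplicitEnum`, a string-valued enum, so the new CLI string maps straight to the new member; a minimal sketch using a plain `str` Enum as a stand-in for `ExplicitEnum`:

from enum import Enum

class OptimizerNamesSketch(str, Enum):  # stand-in for transformers' ExplicitEnum
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"

assert OptimizerNamesSketch("adamw_torch_fused") is OptimizerNamesSketch.ADAMW_TORCH_FUSED
assert OptimizerNamesSketch.ADAMW_TORCH_FUSED.value == "adamw_torch_fused"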
@@ -457,7 +458,8 @@ class TrainingArguments:
             The options should be separated by whitespaces.
         optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
-            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor.
+            The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or
+            adafactor.
         optim_args (`str`, *optional*):
             Optional arguments that are supplied to AnyPrecisionAdamW.
         group_by_length (`bool`, *optional*, defaults to `False`):
@@ -940,8 +942,15 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
+
+    default_optim = "adamw_hf"
+    # XXX: enable when pytorch==2.0.1 comes out - we want to give it time to get all the bugs sorted out
+    # if is_torch_available() and version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.1.0"):
+    #     default_optim = "adamw_torch_fused"
+    # and update the doc above to:
+    # optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch_fused"` (for torch<2.1.0 `"adamw_hf"`):
     optim: Union[OptimizerNames, str] = field(
-        default="adamw_hf",
+        default=default_optim,
         metadata={"help": "The optimizer to use."},
     )
     optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
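
The disabled gate in the comment above amounts to the following sketch (assuming torch is importable, so the `is_torch_available()` guard is dropped; `base_version` strips suffixes such as "+cu117" or dev tags before comparing):

import torch
from packaging import version

default_optim = "adamw_hf"
# Flip the default only once the installed PyTorch reaches the 2.1.0 floor the comment waits for.
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.1.0"):
    default_optim = "adamw_torch_fused"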
@@ -1205,6 +1214,12 @@ class TrainingArguments:
                 FutureWarning,
             )
             self.optim = OptimizerNames.ADAFACTOR
+        if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available():
+            if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"):
+                raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher")
+            # there is a bug in fp16/AMP in pt-2.0.0
+            if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
+                raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")
         if (
             self.framework == "pt"
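
The new `__post_init__` guard can be exercised in isolation; a minimal sketch with a hypothetical `check_fused_support` helper (not part of the PR) that mirrors the two checks above:

import torch
from packaging import version

def check_fused_support(fp16: bool) -> None:
    # Same comparison as the guard above; base_version ignores local suffixes like "+cu117".
    pt = version.parse(version.parse(torch.__version__).base_version)
    if pt < version.parse("2.0.0"):
        raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher")
    # fp16/AMP with the fused optimizer is broken in exactly 2.0.0, per the PR.
    if pt == version.parse("2.0.0") and fp16:
        raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")

check_fused_support(fp16=False)  # raises on torch < 2.0, otherwise passes silently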
@@ -2275,8 +2290,8 @@ class TrainingArguments:
         Args:
             name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
-                The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_apex_fused"`, `"adamw_anyprecision"` or
-                `"adafactor"`.
+                The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+                `"adamw_anyprecision"` or `"adafactor"`.
             learning_rate (`float`, *optional*, defaults to 5e-5):
                 The initial learning rate.
             weight_decay (`float`, *optional*, defaults to 0):