Unverified commit 085bf5c1, authored by Stas Bekman and committed by GitHub

[trainer] add `--optim adamw_torch_fused` for pt-2.0+ (#22144)

* [trainer] add --optim adamw_torch_fused

* change optim default

* deal with non-torch

* revert default change; prep; add fp16/amp assert

* typo

* typo
parent c6318c37
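The diff below wires the new option through `Trainer.get_optimizer_cls_and_kwargs` and `TrainingArguments`. As a quick orientation, here is a minimal sketch of how the option is selected from user code once this change is in place (PyTorch >= 2.0 is assumed to be installed; `output_dir="out"` is just a placeholder):

```python
from transformers import TrainingArguments

# The new fused optimizer is selected like any other optimizer name,
# either on the command line via `--optim adamw_torch_fused` or programmatically:
args = TrainingArguments(output_dir="out", optim="adamw_torch_fused")
```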
@@ -1122,11 +1122,13 @@ class Trainer:
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
-        elif args.optim == OptimizerNames.ADAMW_TORCH:
+        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
             from torch.optim import AdamW

             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+                optimizer_kwargs.update({"fused": True})
         elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
             try:
                 from torch_xla.amp.syncfree import AdamW
...
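For context on what the added `{"fused": True}` kwarg ends up doing, here is a rough sketch of the optimizer the Trainer constructs for `--optim adamw_torch_fused`. The hyperparameter values below are placeholders standing in for `adam_kwargs`; fused AdamW executes the step as fused CUDA kernels, so the parameters must live on GPU:

```python
import torch

# fused AdamW requires CUDA parameters
model = torch.nn.Linear(8, 8).cuda()

# roughly what get_optimizer_cls_and_kwargs resolves to for --optim adamw_torch_fused;
# lr/betas/eps normally come from TrainingArguments via adam_kwargs
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, fused=True)
```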
@@ -121,6 +121,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_HF = "adamw_hf"
     ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
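Because `OptimizerNames` is a string-valued enum, the new member is addressable directly by the CLI string. A simplified stand-in (not the actual `ExplicitEnum` from transformers) illustrating the round-trip:

```python
from enum import Enum

class OptimizerNamesSketch(str, Enum):
    # simplified stand-in for transformers' OptimizerNames (ExplicitEnum)
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"

# the string passed to --optim maps straight to the new member
assert OptimizerNamesSketch("adamw_torch_fused") is OptimizerNamesSketch.ADAMW_TORCH_FUSED
```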
@@ -457,7 +458,8 @@ class TrainingArguments:
             The options should be separated by whitespaces.
         optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
-            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor.
+            The optimizer to use: adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or
+            adafactor.
         optim_args (`str`, *optional*):
             Optional arguments that are supplied to AnyPrecisionAdamW.
         group_by_length (`bool`, *optional*, defaults to `False`):
@@ -940,8 +942,15 @@ class TrainingArguments:
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
+    default_optim = "adamw_hf"
+    # XXX: enable when pytorch==2.0.1 comes out - we want to give it time to get all the bugs sorted out
+    # if is_torch_available() and version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.1.0"):
+    #     default_optim = "adamw_torch_fused"
+    #     and update the doc above to:
+    #     optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch_fused"` (for torch<2.1.0 `"adamw_hf"`):
     optim: Union[OptimizerNames, str] = field(
-        default="adamw_hf",
+        default=default_optim,
         metadata={"help": "The optimizer to use."},
     )
     optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
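The commented-out gate above relies on `base_version` to compare against a clean release number. A small sketch of what it would do once enabled (the `2.1.0` threshold is the one named in the comment):

```python
import torch
from packaging import version

# base_version drops pre/dev/local suffixes, e.g. "2.1.0.dev20230313+cu118" -> "2.1.0",
# so nightly and locally-built wheels compare cleanly against the release threshold
torch_version = version.parse(version.parse(torch.__version__).base_version)
default_optim = "adamw_torch_fused" if torch_version >= version.parse("2.1.0") else "adamw_hf"
```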
@@ -1205,6 +1214,12 @@ class TrainingArguments:
                 FutureWarning,
             )
             self.optim = OptimizerNames.ADAFACTOR
+        if self.optim == OptimizerNames.ADAMW_TORCH_FUSED and is_torch_available():
+            if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.0.0"):
+                raise ValueError("--optim adamw_torch_fused requires PyTorch 2.0 or higher")
+            # there is a bug in fp16/AMP in pt-2.0.0
+            if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
+                raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")

         if (
             self.framework == "pt"
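From the user's side, the guard added above surfaces the unsupported combination as an early `ValueError` rather than a failure mid-training. A hedged illustration of that behavior, assuming PyTorch 2.0.0 exactly is installed and fp16 is requested (`output_dir="out"` is a placeholder):

```python
from transformers import TrainingArguments

# On PyTorch 2.0.0 the fused AdamW + fp16/AMP combination is rejected up front:
try:
    TrainingArguments(output_dir="out", optim="adamw_torch_fused", fp16=True)
except ValueError as err:
    print(err)  # --optim adamw_torch_fused with --fp16 requires PyTorch>2.0
```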
@@ -2275,8 +2290,8 @@ class TrainingArguments:
         Args:
             name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
-                The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_apex_fused"`, `"adamw_anyprecision"` or
-                `"adafactor"`.
+                The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+                `"adamw_anyprecision"` or `"adafactor"`.
             learning_rate (`float`, *optional*, defaults to 5e-5):
                 The initial learning rate.
             weight_decay (`float`, *optional*, defaults to 0):
...