Unverified Commit 84c9cc6d authored by atturaioe, committed by GitHub

Add AnyPrecisionAdamW optimizer (#18961)

* Add AnyPrecisionAdamW optimizer

* Add optim_args argument to TrainingArgs

* Add tests for AnyPrecisionOptimizer

* Change AnyPrecisionAdam default params to float32

* Move default_anyprecision_kwargs in trainer test

* Rename AnyPrecisionAdamW
parent 37e01633
...@@ -29,6 +29,7 @@ import sys
import time
import warnings
from collections.abc import Mapping
from distutils.util import strtobool
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
...@@ -1081,7 +1082,16 @@ class Trainer:
                The training arguments for the training session.
        """
        # parse args.optim_args
        optim_args = {}
        if args.optim_args:
            for mapping in args.optim_args.replace(" ", "").split(","):
                key, value = mapping.split("=")
                optim_args[key] = value

        optimizer_kwargs = {"lr": args.learning_rate}
        adam_kwargs = {
            "betas": (args.adam_beta1, args.adam_beta2),
            "eps": args.adam_epsilon,
...@@ -1123,6 +1133,26 @@ class Trainer:
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
            try:
                from torchdistx.optimizers import AnyPrecisionAdamW

                optimizer_cls = AnyPrecisionAdamW
                optimizer_kwargs.update(adam_kwargs)

                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
                optimizer_kwargs.update(
                    {
                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
                        "compensation_buffer_dtype": getattr(
                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
                        ),
                    }
                )
            except ImportError:
                raise ValueError("Please install https://github.com/pytorch/torchdistx")
        elif args.optim == OptimizerNames.SGD:
            optimizer_cls = torch.optim.SGD
        elif args.optim == OptimizerNames.ADAGRAD:
......
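For reference, a minimal standalone sketch of what the new parsing branch does with an `--optim_args` string; the helper name `parse_optim_args` is purely illustrative (the Trainer inlines this loop rather than exposing such a function):

import torch
from distutils.util import strtobool

def parse_optim_args(optim_args: str) -> dict:
    # Mirror of the loop added above: split "key1=val1,key2=val2" into a str-to-str dict.
    parsed = {}
    if optim_args:
        for mapping in optim_args.replace(" ", "").split(","):
            key, value = mapping.split("=")
            parsed[key] = value
    return parsed

optim_args = parse_optim_args("use_kahan_summation=True, momentum_dtype=float32, variance_dtype=bfloat16")
anyprecision_kwargs = {
    # Defaults match the diff: everything in float32 except the Kahan compensation buffer.
    "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
    "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
    "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
    "compensation_buffer_dtype": getattr(torch, optim_args.get("compensation_buffer_dtype", "bfloat16")),
}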
...@@ -113,6 +113,7 @@ class OptimizerNames(ExplicitEnum):
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
...@@ -401,7 +402,9 @@ class TrainingArguments:
            The options should be separated by whitespaces.
        optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor.
        optim_args (`str`, *optional*):
            Optional arguments that are supplied to AnyPrecisionAdamW.
        adafactor (`bool`, *optional*, defaults to `False`):
            This argument is deprecated. Use `--optim adafactor` instead.
        group_by_length (`bool`, *optional*, defaults to `False`):
...@@ -857,6 +860,7 @@ class TrainingArguments:
        default="adamw_hf",
        metadata={"help": "The optimizer to use."},
    )
    optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    group_by_length: bool = field(
        default=False,
......
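Taken together, the new enum value and the `optim_args` field let the optimizer be selected through the usual `TrainingArguments` interface. A minimal usage sketch, assuming torchdistx is installed and using a placeholder output directory:

from transformers import TrainingArguments

# "adamw_anyprecision" is the enum value added above; the key=value string is parsed by the Trainer.
args = TrainingArguments(
    output_dir="out",  # placeholder path
    optim="adamw_anyprecision",
    optim_args="use_kahan_summation=True,momentum_dtype=float32,variance_dtype=bfloat16",
)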
...@@ -153,6 +153,7 @@ from .import_utils import (
    is_torch_tf32_available,
    is_torch_tpu_available,
    is_torchaudio_available,
    is_torchdistx_available,
    is_torchdynamo_available,
    is_training_run_on_sagemaker,
    is_vision_available,
......
...@@ -508,6 +508,10 @@ def is_bitsandbytes_available():
    return importlib.util.find_spec("bitsandbytes") is not None


def is_torchdistx_available():
    return importlib.util.find_spec("torchdistx") is not None


def is_faiss_available():
    return _faiss_available
......
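A short sketch of how the new availability helper is intended to guard the optional dependency, in the same style as the other `is_*_available` checks; the `None` fallback here is only an illustration, not part of the library:

from transformers.utils import is_torchdistx_available

if is_torchdistx_available():
    # Only attempt the import when the package can actually be found.
    from torchdistx.optimizers import AnyPrecisionAdamW
else:
    AnyPrecisionAdamW = None  # callers must handle the missing-backend case themselves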
...@@ -71,7 +71,13 @@ from transformers.testing_utils import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.training_args import OptimizerNames
from transformers.utils import (
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    is_apex_available,
    is_bitsandbytes_available,
    is_torchdistx_available,
)
from transformers.utils.hp_naming import TrialShortNamer
...@@ -2287,24 +2293,31 @@ if is_torch_available():
        "lr": TrainingArguments.learning_rate,
    }

    default_anyprecision_kwargs = {
        "use_kahan_summation": False,
        "momentum_dtype": torch.float32,
        "variance_dtype": torch.float32,
        "compensation_buffer_dtype": torch.bfloat16,
    }

    optim_test_params = [
        (
            TrainingArguments(optim=OptimizerNames.ADAMW_HF, output_dir="None"),
            transformers.optimization.AdamW,
            default_adam_kwargs,
        ),
        (
            TrainingArguments(optim=OptimizerNames.ADAMW_HF.value, output_dir="None"),
            transformers.optimization.AdamW,
            default_adam_kwargs,
        ),
        (
            TrainingArguments(optim=OptimizerNames.ADAMW_TORCH, output_dir="None"),
            torch.optim.AdamW,
            default_adam_kwargs,
        ),
        (
            TrainingArguments(optim=OptimizerNames.ADAFACTOR, output_dir="None"),
            transformers.optimization.Adafactor,
            {
                "scale_parameter": False,
...@@ -2319,7 +2332,7 @@ if is_torch_available():
        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
                apex.optimizers.FusedAdam,
                default_adam_kwargs,
            )
...@@ -2330,32 +2343,42 @@ if is_torch_available():
        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
                bnb.optim.Adam8bit,
                default_adam_kwargs,
            )
        )

    if is_torchdistx_available():
        import torchdistx

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
                torchdistx.optimizers.AnyPrecisionAdamW,
                dict(default_adam_kwargs, **default_anyprecision_kwargs),
            )
        )


@require_torch
class TrainerOptimizerChoiceTest(unittest.TestCase):
    def check_optim_and_kwargs(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
        actual_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
        self.assertEqual(expected_cls, actual_cls)
        self.assertIsNotNone(optim_kwargs)

        for p, v in expected_kwargs.items():
            self.assertTrue(p in optim_kwargs)
            actual_v = optim_kwargs[p]
            self.assertTrue(actual_v == v, f"Failed check for {p}. Expected {v}, but got {actual_v}.")

    @parameterized.expand(optim_test_params, skip_on_empty=True)
    def test_optim_supported(self, training_args: TrainingArguments, expected_cls, expected_kwargs):
        # exercises all the valid --optim options
        self.check_optim_and_kwargs(training_args, expected_cls, expected_kwargs)

        trainer = get_regression_trainer(**training_args.to_dict())
        trainer.train()

    def test_fused_adam(self):
...@@ -2371,9 +2394,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"),
                mock.optimizers.FusedAdam,
                default_adam_kwargs,
            )

    def test_fused_adam_no_apex(self):
...@@ -2398,9 +2421,9 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
                mock.optim.Adam8bit,
                default_adam_kwargs,
            )

    def test_bnb_adam8bit_no_bnb(self):
...@@ -2412,6 +2435,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        with self.assertRaises(ValueError):
            Trainer.get_optimizer_cls_and_kwargs(args)
    def test_anyprecision_adamw(self):
        # Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
        # Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
        # class given, so mocking torchdistx.optimizers.AnyPrecisionAdamW should be fine for testing and allow
        # the test to run without requiring a torchdistx installation.
        mock = Mock()
        modules = {
            "torchdistx": mock,
            "torchdistx.optimizers": mock.optimizers,
            "torchdistx.optimizers.AnyPrecisionAdamW": mock.optimizers.AnyPrecisionAdamW,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None"),
                mock.optimizers.AnyPrecisionAdamW,
                dict(default_adam_kwargs, **default_anyprecision_kwargs),
            )
    def test_no_torchdistx_anyprecision_adamw(self):
        args = TrainingArguments(optim=OptimizerNames.ADAMW_ANYPRECISION, output_dir="None")

        # Pretend that torchdistx does not exist, even if installed. By setting torchdistx to None, importing
        # torchdistx.optimizers will fail even if torchdistx is installed.
        with patch.dict("sys.modules", {"torchdistx.optimizers": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)
@require_torch
@require_wandb
......
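Both new tests lean on `unittest.mock.patch.dict` over `sys.modules`; a condensed, standalone sketch of that pattern (outside the Trainer test harness) is:

from unittest.mock import Mock, patch

mock = Mock()

# Pretend torchdistx is importable: the from-import resolves against the mock modules.
with patch.dict("sys.modules", {"torchdistx": mock, "torchdistx.optimizers": mock.optimizers}):
    from torchdistx.optimizers import AnyPrecisionAdamW

    assert AnyPrecisionAdamW is mock.optimizers.AnyPrecisionAdamW

# Pretend torchdistx is absent: a None entry makes the import raise ImportError,
# which drives the ValueError path exercised by test_no_torchdistx_anyprecision_adamw.
with patch.dict("sys.modules", {"torchdistx.optimizers": None}):
    try:
        from torchdistx.optimizers import AnyPrecisionAdamW
    except ImportError:
        pass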