Unverified commit 796162c5, authored by Tim Dettmers, committed by GitHub

Paged Optimizer + Lion Optimizer for Trainer (#23217)



* Added lion and paged optimizers and made original tests pass.

* Added tests for paged and lion optimizers.

* Added and fixed optimizer tests.

* Style and quality checks.

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 9d73b922
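A minimal usage sketch of what this commit enables (assumes a transformers build containing this PR plus bitsandbytes installed; the output directory and learning rate below are placeholders, not part of the commit):

from transformers import TrainingArguments

# Any of the new OptimizerNames values added below can be passed as a string, e.g.
# "paged_adamw_32bit", "paged_adamw_8bit", "lion_32bit", "lion_8bit",
# "paged_lion_32bit", "paged_lion_8bit", or the alias "adamw_8bit".
args = TrainingArguments(
    output_dir="out",          # placeholder path
    optim="paged_adamw_8bit",  # paged 8-bit AdamW from bitsandbytes
    learning_rate=2e-5,        # placeholder hyperparameter
)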
@@ -1170,6 +1170,38 @@ class Trainer:
                optimizer_kwargs.update(adam_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
        elif args.optim in [
            OptimizerNames.ADAMW_BNB,
            OptimizerNames.ADAMW_8BIT,
            OptimizerNames.PAGED_ADAMW,
            OptimizerNames.PAGED_ADAMW_8BIT,
            OptimizerNames.LION,
            OptimizerNames.LION_8BIT,
            OptimizerNames.PAGED_LION,
            OptimizerNames.PAGED_LION_8BIT,
        ]:
            try:
                from bitsandbytes.optim import AdamW, Lion

                is_paged = False
                optim_bits = 32
                optimizer_cls = None
                additional_optim_kwargs = adam_kwargs
                if "paged" in args.optim:
                    is_paged = True
                if "8bit" in args.optim:
                    optim_bits = 8
                if "adam" in args.optim:
                    optimizer_cls = AdamW
                elif "lion" in args.optim:
                    optimizer_cls = Lion
                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                optimizer_kwargs.update(additional_optim_kwargs)
                optimizer_kwargs.update(bnb_kwargs)
            except ImportError:
                raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
        elif args.optim == OptimizerNames.ADAMW_BNB:
            try:
                from bitsandbytes.optim import Adam8bit
...
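A rough illustration of what the new branch resolves to (a sketch only, assuming bitsandbytes is installed; "out" is a placeholder path):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", optim="paged_lion_8bit")
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
# Per the branch above, optimizer_cls is bitsandbytes.optim.Lion, and optimizer_kwargs
# contains the Lion betas plus {"is_paged": True, "optim_bits": 8} in addition to the
# learning rate set earlier in the method.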
@@ -139,10 +139,17 @@ class OptimizerNames(ExplicitEnum):
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"


@dataclass
...
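The new members can also be passed to optim as enum values rather than strings (a small sketch based on the additions above; "out" is a placeholder path):

from transformers.training_args import OptimizerNames, TrainingArguments

assert OptimizerNames.PAGED_LION_8BIT.value == "paged_lion_8bit"
args = TrainingArguments(output_dir="out", optim=OptimizerNames.LION_8BIT)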
@@ -2474,6 +2474,11 @@ if is_torch_available():
        "lr": TrainingArguments.learning_rate,
    }

    default_lion_kwargs = {
        "betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2),
        "lr": TrainingArguments.learning_rate,
    }

    default_anyprecision_kwargs = {
        "use_kahan_summation": False,
        "momentum_dtype": torch.float32,
@@ -2525,11 +2530,59 @@ if is_torch_available():
        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
                bnb.optim.AdamW,
                default_adam_kwargs,
            )
        )
        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
                bnb.optim.AdamW,
                default_adam_kwargs,
            )
        )

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
                bnb.optim.AdamW,
                default_adam_kwargs,
            )
        )

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
                bnb.optim.AdamW,
                default_adam_kwargs,
            )
        )

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
                bnb.optim.Lion,
                default_lion_kwargs,
            )
        )

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
                bnb.optim.Lion,
                default_lion_kwargs,
            )
        )

        optim_test_params.append(
            (
                TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
                bnb.optim.Lion,
                default_lion_kwargs,
            )
        )
    if is_torchdistx_available():
        import torchdistx

@@ -2598,15 +2651,113 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None"),
                mock.optim.AdamW,
                default_adam_kwargs,
            )
    def test_bnb_paged_adam8bit_alias(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.ADAMW_8BIT, output_dir="None"),
                mock.optim.AdamW,
                default_adam_kwargs,
            )

    def test_bnb_paged_adam(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None"),
                mock.optim.AdamW,
                default_adam_kwargs,
            )

    def test_bnb_paged_adam8bit(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.AdamW": mock.optim.AdamW,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None"),
                mock.optim.AdamW,
                default_adam_kwargs,
            )

    def test_bnb_lion(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.Lion": mock.optim.Lion,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.LION, output_dir="None"),
                mock.optim.Lion,
                default_lion_kwargs,
            )

    def test_bnb_lion8bit(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.Lion": mock.optim.Lion,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.LION_8BIT, output_dir="None"),
                mock.optim.Lion,
                default_lion_kwargs,
            )

    def test_bnb_paged_lion8bit(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.Lion": mock.optim.Lion,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None"),
                mock.optim.Lion,
                default_lion_kwargs,
            )

    def test_bnb_paged_lion(self):
        mock = Mock()
        modules = {
            "bitsandbytes": mock,
            "bitsandbytes.optim": mock.optim,
            "bitsandbytes.optim.Lion": mock.optim.Lion,
        }
        with patch.dict("sys.modules", modules):
            self.check_optim_and_kwargs(
                TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None"),
                mock.optim.Lion,
                default_lion_kwargs,
            )
    def test_bnb_adam8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")

@@ -2616,6 +2767,42 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)
    def test_bnb_paged_adam_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW, output_dir="None")
        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if bnb is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_adam8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_ADAMW_8BIT, output_dir="None")
        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if bnb is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_lion_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")
        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if bnb is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)

    def test_bnb_paged_lion8bit_no_bnb(self):
        args = TrainingArguments(optim=OptimizerNames.PAGED_LION_8BIT, output_dir="None")
        # Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
        # bnb will fail even if bnb is installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ValueError):
                Trainer.get_optimizer_cls_and_kwargs(args)
    def test_anyprecision_adamw(self):
        # Pretend that torchdistx is installed and mock torchdistx.optimizers.AnyPrecisionAdamW exists.
        # Trainer.get_optimizer_cls_and_kwargs does not use AnyPrecisionAdamW. It only has to return the
...
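For reference, the branch added to Trainer above is roughly equivalent to constructing the bitsandbytes optimizer by hand (a sketch only; requires a CUDA device and a bitsandbytes version providing paged optimizers, i.e. one that accepts the is_paged and optim_bits arguments the Trainer code above passes; the toy model and learning rate are placeholders):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(16, 16).cuda()  # placeholder toy model
# Mirrors optim="paged_adamw_8bit": paged optimizer state, quantized to 8 bits.
optimizer = bnb.optim.AdamW(model.parameters(), lr=2e-5, is_paged=True, optim_bits=8)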