"scripts/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "c7bfdb1cca2c411fc9e9c8ca87847f52c29ba31c"
Unverified commit 3104036e authored by Manuel R. Ciosici, committed by GitHub

Add support for bitsandbytes (#15622)



* Add initial BNB integration

* fixup! Add initial BNB integration

* Add bnb test decorator

* Update Adamw8bit option name

* Use the full bnb package name

* Override bnb for all embedding layers

* Fix package name

* Formatting

* Remove unnecessary import

* Update src/transformers/trainer.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Rename AdamwBNB optimizer option

* Add training test checking that bnb memory utilization is lower

* fix merge

* fix merge; fix + extend new test

* cleanup

* expand bnb

* move all require_* candidates to testing_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent e6d23a4b
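
Before the diff below, a quick orientation: the commit routes bitsandbytes' 8-bit Adam through `Trainer` via a new `adamw_bnb_8bit` choice for `--optim` and keeps embedding weights in 32-bit optimizer state. A minimal usage sketch of the feature being added (not part of the commit; it assumes `bitsandbytes` is installed and a CUDA GPU is available, and the dataset wiring is omitted):

# Minimal usage sketch (illustrative, not from the commit): opting into the new
# 8-bit Adam optimizer this PR adds. The model name is the same Marian checkpoint
# the new test uses; any other model would work the same way.
from transformers import AutoModelForSeq2SeqLM, Trainer, TrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/student_marian_en_ro_6_1")

args = TrainingArguments(
    output_dir="out",
    optim="adamw_bnb_8bit",     # new OptimizerNames.ADAMW_BNB value added below
    skip_memory_metrics=False,  # keep memory deltas in the logs, as the new test does
)

# A real run would pass a tokenized train_dataset:
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()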
@@ -31,8 +31,16 @@ from unittest import mock
 from transformers import logging as transformers_logging

 from .deepspeed import is_deepspeed_available
-from .integrations import is_optuna_available, is_ray_available, is_sigopt_available, is_wandb_available
+from .integrations import (
+    is_fairscale_available,
+    is_optuna_available,
+    is_ray_available,
+    is_sigopt_available,
+    is_wandb_available,
+)
 from .utils import (
+    is_apex_available,
+    is_bitsandbytes_available,
     is_detectron2_available,
     is_faiss_available,
     is_flax_available,
@@ -638,6 +646,36 @@ def require_deepspeed(test_case):
         return test_case


+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    if not is_apex_available():
+        return unittest.skip("test requires apex")(test_case)
+    else:
+        return test_case
+
+
+def require_bitsandbytes(test_case):
+    """
+    Decorator for bits and bytes (bnb) dependency
+    """
+    if not is_bitsandbytes_available():
+        return unittest.skip("test requires bnb")(test_case)
+    else:
+        return test_case
+
+
 def require_phonemizer(test_case):
     """
     Decorator marking a test that requires phonemizer
...
@@ -867,6 +867,15 @@ class Trainer:
                 )
             else:
                 self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    for module in self.model.modules():
+                        if isinstance(module, nn.Embedding):
+                            manager.register_module_override(module, "weight", {"optim_bits": 32})
+                            logger.debug(f"bitsandbytes: will optimize {module} in fp32")

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -917,6 +926,14 @@
                 optimizer_kwargs.update(adam_kwargs)
             except ImportError:
                 raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim == OptimizerNames.ADAMW_BNB:
+            try:
+                from bitsandbytes.optim import Adam8bit
+
+                optimizer_cls = Adam8bit
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
...
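
The first trainer.py hunk above is the piece that keeps embedding weights in 32-bit optimizer state when Adam8bit is selected, using bitsandbytes' `GlobalOptimManager`. A standalone sketch of that registration pattern outside `Trainer` (the toy model is illustrative only; the calls mirror the ones the hunk makes):

# Illustrative sketch of the override registered in the hunk above: embedding
# parameters are flagged so their Adam state stays in fp32 even though the
# optimizer itself is bnb.optim.Adam8bit.
import bitsandbytes as bnb
from torch import nn

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 2))  # toy model

manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        # same call the Trainer makes: keep this parameter's optimizer state in 32-bit
        manager.register_module_override(module, "weight", {"optim_bits": 32})

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

In the Trainer change the override is registered right after the optimizer is created and applied to every `nn.Embedding` module, as the debug message in the hunk notes; this sketch only reproduces the registration call itself.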
@@ -79,6 +79,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_TORCH_XLA = "adamw_torch_xla"
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
+    ADAMW_BNB = "adamw_bnb_8bit"


 @dataclass
...
@@ -85,6 +85,7 @@ from .import_utils import (
     DummyObject,
     _LazyModule,
     is_apex_available,
+    is_bitsandbytes_available,
     is_coloredlogs_available,
     is_datasets_available,
     is_detectron2_available,
...
@@ -400,6 +400,10 @@ def is_apex_available():
     return importlib.util.find_spec("apex") is not None


+def is_bitsandbytes_available():
+    return importlib.util.find_spec("bitsandbytes") is not None
+
+
 def is_faiss_available():
     return _faiss_available
...
@@ -17,10 +17,11 @@ import os
 import re
 import sys
 import unittest
+from typing import Tuple
 from unittest.mock import patch

 from parameterized import parameterized
-from transformers.integrations import is_fairscale_available
+from transformers import AutoModel
 from transformers.testing_utils import (
     CaptureStderr,
     ExtendSysPath,
@@ -28,6 +29,9 @@ from transformers.testing_utils import (
     execute_subprocess_async,
     get_gpu_count,
     get_torch_dist_unique_port,
+    require_apex,
+    require_bitsandbytes,
+    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -36,7 +40,6 @@ from transformers.testing_utils import (
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
-from transformers.utils import is_apex_available


 bindir = os.path.abspath(os.path.dirname(__file__))
@@ -49,28 +52,6 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 MBART_TINY = "sshleifer/tiny-mbart"


-# a candidate for testing_utils
-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    if not is_fairscale_available():
-        return unittest.skip("test requires fairscale")(test_case)
-    else:
-        return test_case
-
-
-# a candidate for testing_utils
-def require_apex(test_case):
-    """
-    Decorator marking a test that requires apex
-    """
-    if not is_apex_available():
-        return unittest.skip("test requires apex")(test_case)
-    else:
-        return test_case
-
-
 @require_torch
 class TestTrainerExt(TestCasePlus):
     def run_seq2seq_quick(
@@ -193,7 +174,7 @@ class TestTrainerExt(TestCasePlus):
         self.assertEqual(n_matches, data["n_matches"])

     @slow
-    def test_run_seq2seq_slow(self):
+    def test_run_seq2seq(self):
         output_dir = self.run_trainer(
             eval_steps=2,
             max_len=128,
@@ -218,6 +199,88 @@ class TestTrainerExt(TestCasePlus):
         assert "generated_predictions.txt" in contents
         assert "predict_results.json" in contents

+    @slow
+    @require_bitsandbytes
+    def test_run_seq2seq_bnb(self):
+        from transformers.training_args import OptimizerNames
+
+        def train_and_return_metrics(optim: str) -> Tuple[int, float]:
+            from pathlib import Path
+
+            extra_args = (
+                f"--skip_memory_metrics 0 --optim {optim} --do_eval False --do_predict "
+                "False --adafactor False --log_level debug"
+            )
+
+            output_dir = self.run_trainer(
+                eval_steps=2,
+                max_len=128,
+                model_name=MARIAN_MODEL,
+                learning_rate=3e-4,
+                num_train_epochs=1,
+                distributed=True,  # force run in a new process
+                extra_args_str=extra_args,
+                do_eval=False,
+                do_predict=False,
+            )
+
+            # Check metrics
+            logs = TrainerState.load_from_json(Path(output_dir, "trainer_state.json")).log_history
+            gpu_peak_mem = logs[0]["train_mem_gpu_peaked_delta"]
+            gpu_alloc_mem = logs[0]["train_mem_gpu_alloc_delta"]
+
+            loss = logs[0]["train_loss"]
+            return gpu_peak_mem, gpu_alloc_mem, loss
+
+        gpu_peak_mem_orig, gpu_alloc_mem_orig, loss_orig = train_and_return_metrics(OptimizerNames.ADAMW_TORCH.value)
+        gpu_peak_mem_bnb, gpu_alloc_mem_bnb, loss_bnb = train_and_return_metrics(OptimizerNames.ADAMW_BNB.value)
+
+        gpu_peak_mem_diff_bytes = gpu_peak_mem_orig - gpu_peak_mem_bnb
+        gpu_peak_mem_diff_percent = gpu_peak_mem_diff_bytes / gpu_peak_mem_bnb
+
+        gpu_total_mem_orig = gpu_peak_mem_orig + gpu_alloc_mem_orig
+        gpu_total_mem_bnb = gpu_peak_mem_bnb + gpu_alloc_mem_bnb
+
+        gpu_total_mem_diff_bytes = gpu_total_mem_orig - gpu_total_mem_bnb
+        gpu_total_mem_diff_percent = gpu_total_mem_diff_bytes / gpu_total_mem_bnb
+
+        # leave this for now if CI gets very different results
+        # print(f"{gpu_alloc_mem_orig=:010d} {gpu_peak_mem_orig=:010d} {gpu_alloc_mem_orig+gpu_peak_mem_orig=:010d}")
+        # print(f" {gpu_alloc_mem_bnb=:010d} {gpu_peak_mem_bnb=:010d} {gpu_alloc_mem_bnb+gpu_peak_mem_bnb=:010d}")
+        # print(f"{gpu_peak_mem_diff_bytes=}, {gpu_peak_mem_diff_percent=}")
+        # print(f"{gpu_total_mem_orig=}, {gpu_total_mem_bnb=}")
+        # print(f"{gpu_total_mem_diff_bytes=}, {gpu_total_mem_diff_percent=}")
+
+        self.assertGreater(
+            gpu_peak_mem_diff_percent,
+            10,  # basically a huge difference - got ~30x on my desktop
+            "should use very little peak gpu memory with BNB, compared to without it, "
+            f"but got gpu_peak_mem_orig={gpu_peak_mem_orig} and gpu_peak_mem_bnb={gpu_peak_mem_bnb}",
+        )
+
+        self.assertGreater(
+            gpu_total_mem_diff_percent,
+            0.20,  # could easily be 0.50, but let's stay on the safe side
+            "Using BNB should use less total GPU memory than without it, "
+            f"but got gpu_total_mem_orig={gpu_total_mem_orig} and gpu_total_mem_bnb={gpu_total_mem_bnb}",
+        )
+
+        self.assertEqual(
+            loss_orig, loss_bnb, f"loss should be the same, but got loss_orig={loss_orig}, loss_bnb={loss_bnb}"
+        )
+
+        # Additionally let's test that the absolute gpu memory difference is larger or about the
+        # same as the expected saving coming from BNB (6 bytes per param)
+        model = AutoModel.from_pretrained(MARIAN_MODEL)
+        total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        bnb_saved_bytes = total_numel * 6  # 324MB
+
+        self.assertGreater(
+            gpu_total_mem_diff_bytes,
+            bnb_saved_bytes * 0.8,  # add a safety margin, if it saved slightly less
+            f"BNB should have saved about {bnb_saved_bytes} bytes, but the saved bytes were {gpu_total_mem_diff_bytes}",
+        )
+
     def run_trainer(
         self,
         eval_steps: int,
@@ -300,6 +363,8 @@ class TestTrainerExt(TestCasePlus):
            {self.examples_dir_str}/pytorch/translation/run_translation.py
            """.split()
             cmd = [sys.executable] + distributed_args + args
+            # keep for quick debug
+            # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
             execute_subprocess_async(cmd, env=self.get_env())
         else:
             testargs = ["run_translation.py"] + args
...
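
The new `test_run_seq2seq_bnb` above checks the measured saving against a rough constant of 6 bytes per parameter. One plausible reading of that figure (the commit does not spell it out): plain AdamW keeps two fp32 moment tensors per parameter (2 × 4 bytes), while Adam8bit keeps both moments in 8-bit (2 × 1 byte), so roughly 8 − 2 = 6 bytes of optimizer state are saved per parameter. A sketch of the same arithmetic:

# Back-of-the-envelope version of the test's `bnb_saved_bytes = total_numel * 6`.
# Assumption: Adam stores exp_avg and exp_avg_sq per parameter; fp32 state costs
# 2 * 4 bytes, 8-bit state costs 2 * 1 byte (quantization constants ignored).
def expected_bnb_saving_bytes(num_params: int) -> int:
    fp32_state = 2 * 4
    int8_state = 2 * 1
    return num_params * (fp32_state - int8_state)

# ~54M parameters (inferred from the "324MB" comment in the test, not measured)
print(expected_bnb_saving_bytes(54_000_000) / 1e6, "MB")  # -> 324.0 MB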
@@ -65,7 +65,7 @@ from transformers.testing_utils import (
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.training_args import OptimizerNames
-from transformers.utils import WEIGHTS_NAME, is_apex_available
+from transformers.utils import WEIGHTS_NAME, is_apex_available, is_bitsandbytes_available
 from transformers.utils.hp_naming import TrialShortNamer
@@ -1870,6 +1870,7 @@ if is_torch_available():
             },
         ),
     ]
+
     if is_apex_available():
         import apex
@@ -1881,6 +1882,17 @@ if is_torch_available():
             )
         )

+    if is_bitsandbytes_available():
+        import bitsandbytes as bnb
+
+        optim_test_params.append(
+            (
+                OptimizerNames.ADAMW_BNB,
+                bnb.optim.Adam8bit,
+                default_adam_kwargs,
+            )
+        )
+

 @require_torch
 class TrainerOptimizerChoiceTest(unittest.TestCase):
@@ -1905,8 +1917,8 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
     def test_fused_adam(self):
         # Pretend that apex is installed and mock apex.optimizers.FusedAdam exists.
-        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam, but only has to return a
-        # class called, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
+        # Trainer.get_optimizer_cls_and_kwargs does not use FusedAdam. It only has to return the
+        # class given, so mocking apex.optimizers.FusedAdam should be fine for testing and allow
         # the test to run without requiring an apex installation.
         mock = Mock()
         modules = {
@@ -1930,6 +1942,33 @@ class TrainerOptimizerChoiceTest(unittest.TestCase):
         with self.assertRaises(ValueError):
             Trainer.get_optimizer_cls_and_kwargs(args)
+
+    def test_bnb_adam8bit(self):
+        # Pretend that Bits and Bytes is installed and mock bnb.optim.Adam8bit exists.
+        # Trainer.get_optimizer_cls_and_kwargs does not use Adam8bit. It only has to return the
+        # class given, so mocking bnb.optim.Adam8bit should be fine for testing and allow
+        # the test to run without requiring a bnb installation.
+        mock = Mock()
+        modules = {
+            "bitsandbytes": mock,
+            "bitsandbytes.optim": mock.optim,
+            "bitsandbytes.optim.Adam8bit": mock.optim.Adam8bit,
+        }
+        with patch.dict("sys.modules", modules):
+            self.check_optim_and_kwargs(
+                OptimizerNames.ADAMW_BNB,
+                default_adam_kwargs,
+                mock.optim.Adam8bit,
+            )
+
+    def test_bnb_adam8bit_no_bnb(self):
+        args = TrainingArguments(optim=OptimizerNames.ADAMW_BNB, output_dir="None")
+
+        # Pretend that bnb does not exist, even if installed. Setting bitsandbytes.optim to None
+        # in sys.modules makes importing from bitsandbytes.optim fail even if bnb is installed.
+        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
+            with self.assertRaises(ValueError):
+                Trainer.get_optimizer_cls_and_kwargs(args)
+

 @require_torch
 @require_wandb
...
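
A side note on `test_bnb_adam8bit_no_bnb` above: patching an entry of `sys.modules` to `None` is a standard way to force an `ImportError` for a module even when the package is installed, which is what lets the test exercise the "bnb is missing" error path on any CI machine. A minimal, self-contained illustration of the trick (not from the commit; the test itself asserts the `ValueError` that Trainer raises after catching this `ImportError`):

# Tiny illustration of the sys.modules-to-None mocking trick used by the last test.
import unittest
from unittest.mock import patch


class SysModulesNoneTrick(unittest.TestCase):
    def test_import_fails_when_module_is_none(self):
        # With the entry set to None, the import machinery raises ImportError,
        # regardless of whether bitsandbytes is actually installed.
        with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
            with self.assertRaises(ImportError):
                from bitsandbytes.optim import Adam8bit  # noqa: F401


if __name__ == "__main__":
    unittest.main()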