Unverified Commit 309a9066 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[FEAT] Add Neftune into transformers Trainer (#27141)



* add v1 neftune

* use `unwrap_model` instead

* add test + docs

* Apply suggestions from code review
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>

* more details

* fixup

* Update docs/source/en/main_classes/trainer.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* refactor a bit

* more elaborated test

* fix unwrap issue

---------
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent f53041a7
......@@ -740,3 +740,27 @@ Sections that were moved:
| <a href="./deepspeed#deepspeed-grad-clip">Gradient Clipping</a><a id="gradient-clipping"></a>
| <a href="./deepspeed#deepspeed-weight-extraction">Getting The Model Weights Out</a><a id="getting-the-model-weights-out"></a>
]
## Boost your fine-tuning performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper “NEFTune: Noisy Embeddings Improve Instruction Finetuning” from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
</div>
To use it in `Trainer` simply pass `neftune_noise_alpha` when creating your `TrainingArguments` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
```python
from transformers import Trainer, TrainingArguments
args = TrainingArguments(..., neftune_noise_alpha=0.1)
trainer = Trainer(..., args=args)
...
trainer.train()
```
......@@ -113,6 +113,7 @@ from .trainer_utils import (
find_executable_batch_size,
get_last_checkpoint,
has_length,
neftune_post_forward_hook,
number_of_arguments,
seed_worker,
set_seed,
......@@ -486,6 +487,8 @@ class Trainer:
self.model_wrapped = model
self.model = model
self.neftune_noise_alpha = args.neftune_noise_alpha
self.compute_metrics = compute_metrics
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
self.optimizer, self.lr_scheduler = optimizers
......@@ -634,6 +637,42 @@ class Trainer:
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
https://arxiv.org/abs/2310.05914
"""
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
del unwrapped_model
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
self.neftune_hook_handle = hook_handle
return model
def _deactivate_neftune(self, model):
"""
Deactivates the neftune method. Make sure to call `_activate_neftune` first.
"""
if not hasattr(self, "neftune_hook_handle"):
raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
self.neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha, unwrapped_model
def add_callback(self, callback):
"""
Add a callback to the current list of [`~transformer.TrainerCallback`].
......@@ -1444,6 +1483,10 @@ class Trainer:
self.is_in_train = True
# Attach NEFTune hooks if necessary
if self.neftune_noise_alpha is not None:
self.model = self._activate_neftune(self.model)
# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
......@@ -1956,6 +1999,11 @@ class Trainer:
# Wait for the checkpoint to be uploaded.
self._finish_current_push()
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
self._deactivate_neftune(self.model)
return TrainOutput(self.state.global_step, train_loss, metrics)
def _get_output_dir(self, trial):
......
......@@ -105,6 +105,32 @@ def set_seed(seed: int):
tf.random.set_seed(seed)
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
layers. This method is slightly adapted from the original source code that can be found here:
https://github.com/neelsjain/NEFTune Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
class EvalPrediction:
"""
Evaluation output (always contains labels), to be used to compute metrics.
......
......@@ -627,6 +627,11 @@ class TrainingArguments:
This will iterate over the entire training dataloader once beforehand,
and will slow down the entire process.
neftune_noise_alpha (`Optional[float]`):
If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft.
"""
framework = "pt"
......@@ -1226,6 +1231,13 @@ class TrainingArguments:
metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
)
neftune_noise_alpha: float = field(
default=None,
metadata={
"help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instrcution fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes."
},
)
def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
......
......@@ -838,6 +838,50 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
def test_neftune(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
# Trainer without inf/nan filter
args = TrainingArguments(
"./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
trainer.model = trainer._activate_neftune(trainer.model)
dummy_input = torch.LongTensor([[1, 0, 1]]).to(torch_device)
emb1 = trainer.model.get_input_embeddings()(dummy_input)
emb2 = trainer.model.get_input_embeddings()(dummy_input)
self.assertFalse(torch.allclose(emb1, emb2), "Neftune noise is not applied!")
# redefine the model
tiny_gpt2 = GPT2LMHeadModel(config)
# Trainer without inf/nan filter
args = TrainingArguments(
"./test", learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, neftune_noise_alpha=0.4
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
# Check that it trains without errors
trainer.train()
# Make sure forward pass works fine
_ = trainer.model(dummy_input)
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
trainer.model.eval()
# Check that we get identical embeddings just in case
emb1 = trainer.model.get_input_embeddings()(dummy_input)
emb2 = trainer.model.get_input_embeddings()(dummy_input)
self.assertTrue(torch.allclose(emb1, emb2), "Neftune noise is still applied!")
def test_logging_inf_nan_filter(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment