Unverified commit 2dbfbc74 authored by Santosh Bhavani, committed by GitHub

fix(examples): te_llama compatibility with transformers >= 4.57 (#2572)



* fix(examples): te_llama compatibility with HuggingFace transformers >= 4.57

The te_llama.py example was failing with HuggingFace transformers 4.57+
due to API changes in how decoder layer outputs are handled.

Changes:
- Handle case where hidden_states is passed as a tuple (older HF versions)
- Return tensor directly instead of wrapped in tuple (HF 4.57+ expects this)
- Fix regex pattern to use raw string (fixes SyntaxWarning)

Error fixed:
  AttributeError: 'tuple' object has no attribute 'contiguous'
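The failure is reproducible in isolation; a minimal sketch (illustrative only: `old_style_layer` stands in for the pre-fix decoder layer, and the `.contiguous()` call site is simplified from the actual transformers code):

```
import torch

def old_style_layer(x):
    # Pre-fix convention: wrap the output in a tuple
    return (x,)

out = old_style_layer(torch.zeros(2, 3))
try:
    # transformers >= 4.57 consumes the output as a bare tensor
    out.contiguous()
except AttributeError as err:
    print(err)  # 'tuple' object has no attribute 'contiguous'
```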

Tested with:
- transformer_engine 2.5.0
- transformers 4.57.3
- PyTorch container nvcr.io/nvidia/pytorch:25.08-py3
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* docs(te_llama): add requirements.txt
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

* fix(docs): add missing notebook output names
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>

---------
Signed-off-by: Santosh Bhavani <santosh.bhavani@live.com>
parent 72592763
transformers==4.57.0
accelerate==1.10.0
peft==0.15.2
datasets==4.0.0
sentencepiece==0.2.1
@@ -72,10 +72,15 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
         forward pass of the `TransformerLayer`. Also, make sure the output
         format matches the output of the HF's `LlamaDecoderLayer`.
         """
-        return (
-            super().forward(
-                hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
-            ),
-        )
+        # Handle case where hidden_states might be a tuple (from previous layer output)
+        # This can happen with older versions of HuggingFace transformers
+        if isinstance(hidden_states, tuple):
+            hidden_states = hidden_states[0]
+        # Return tensor directly for HuggingFace transformers >= 4.57
+        # (older versions wrapped output in tuple and extracted with layer_outputs[0])
+        return super().forward(
+            hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
+        )
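As context for the hunk above, a minimal sketch of why the new return convention satisfies both generations of transformers (`_ToyLayer` is illustrative only and assumes just PyTorch; it is not the actual TE class):

```
import torch

class _ToyLayer(torch.nn.Module):
    """Illustrative stand-in for TELlamaDecoderLayer after this change."""
    def forward(self, hidden_states):
        # Older HF versions may pass the previous layer's output as a tuple
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        # Return the tensor directly, as transformers >= 4.57 expects
        return hidden_states

layer = _ToyLayer()
x = torch.zeros(1, 4, 8)
assert layer((x,)).shape == x.shape        # tuple input is unwrapped
assert layer(x).contiguous() is not None   # bare tensor works at newer call sites
```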
@@ -162,7 +167,7 @@ def replace_params(hf_state_dict, te_state_dict, config):
     # collect all layer prefixes to update
     all_layer_prefixes = set()
     for param_key in hf_state_dict.keys():
-        layer_prefix_pat = "model.layers.\d+."
+        layer_prefix_pat = r"model.layers.\d+."
         m = re.match(layer_prefix_pat, param_key)
         if m is not None:
             all_layer_prefixes.add(m.group())
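A quick check of the corrected pattern (the state-dict key below is a representative Hugging Face Llama key, used here only for illustration):

```
import re

# Raw string avoids the SyntaxWarning that "\d" triggers in a plain string
layer_prefix_pat = r"model.layers.\d+."
key = "model.layers.0.self_attn.q_proj.weight"
m = re.match(layer_prefix_pat, key)
print(m.group())  # prints "model.layers.0." (the unescaped dots are regex wildcards)
```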
......
@@ -2,7 +2,6 @@
"cells": [
{
"cell_type": "markdown",
"id": "6a5b2993",
"metadata": {},
"source": [
"# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
@@ -14,11 +13,11 @@
"This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"\n",
"</div>\n"
]
],
"id": "6a5b2993"
},
{
"cell_type": "markdown",
"id": "331f476a",
"metadata": {},
"source": [
"## Dependencies for this tutorial\n",
@@ -29,12 +28,11 @@
" - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
"2. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
"3. `media/`\n",
"3. `requirements.txt`\n",
" - This file contains the necessary Python packages for this tutorial.\n",
"4. `media/`\n",
" - This directory contains the images used in the following tutorial.\n",
"\n",
"These packages are necessary to run this tutorial:\n",
"`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n",
"\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
@@ -42,12 +40,34 @@
"\n",
"This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n",
"\n",
"</div>\n"
]
"</div>\n",
""
],
"id": "331f476a"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Install the required Python packages using the following command:"
],
"id": "b56526b3"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Uncomment and run this cell when running the tutorial for the first time\n",
"# %pip install -r requirements.txt"
],
"id": "099697e2",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "44abae4f",
"metadata": {},
"source": [
"## Table of contents\n",
@@ -61,11 +81,11 @@
" - Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"7. [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"8. Conclusion"
]
],
"id": "44abae4f"
},
{
"cell_type": "markdown",
"id": "e37e2cc1",
"metadata": {},
"source": [
"## From \"Transformer\" to \"Llama\" \n",
@@ -107,11 +127,11 @@
"<img src=\"media/transformer_vs_llama.svg\">\n",
" <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n",
"</figure>"
]
],
"id": "e37e2cc1"
},
{
"cell_type": "markdown",
"id": "a110de1a",
"metadata": {},
"source": [
"## Hugging Face's `LlamaModel`\n",
@@ -155,7 +175,7 @@
")\n",
"```\n",
"\n",
"#### Hugging Face's `LlamaDecoderLayer`\n",
"### Hugging Face's `LlamaDecoderLayer`\n",
"\n",
"Let's take a closer look at `LlamaDecoderLayer`. It is composed of `input_layernorm`, `self_attn`, `post_attention_layernorm` and `mlp` modules. Each module has associated weights as shown in the diagram.\n",
"\n",
@@ -164,10 +184,10 @@
" <figcaption> Fig 4: Causal Llama Model Block Diagram (with simplified illustration of the [LlamaDecoderLayer](https://github.com/huggingface/transformers/blob/e770f0316d2a9b787c9d1440f204fcb65e176682/src/transformers/models/llama/modeling_llama.py#L695)). </figcaption>\n",
"</figure>\n",
"\n",
"##### Self_Attn Layer\n",
"#### Self_Attn Layer\n",
"For simplicity in the block diagram illustration of the \"self_attn\" box, we omit the \"Grouped Query Attention\" operation and only showcase the modules which have associated weights.\n",
" \n",
"##### MLP Layer\n",
"#### MLP Layer\n",
"\n",
"SwiGLU is an activation defined as follows in the [modeling_llama.py](https://github.com/huggingface/transformers/blob/7c4995f93d8d24aae05e1e43279c96dce736e5c8/src/transformers/models/llama/modeling_llama.py#L236) file in the Hugging Face github repo:\n",
"```\n",
@@ -184,11 +204,11 @@
"<img src=\"media/swiglu.svg\">\n",
" <figcaption> Fig 5: A look inside the feedforward layer with <code>swiglu</code> activation function. </figcaption>\n",
"</figure>"
]
],
"id": "a110de1a"
},
{
"cell_type": "markdown",
"id": "c9529229",
"metadata": {},
"source": [
"## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
@@ -208,11 +228,11 @@
"The baseline implementation will be run in `BF16` precision.\n",
"\n",
"</div>"
]
],
"id": "c9529229"
},
{
"cell_type": "markdown",
"id": "b38eb3ac",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
@@ -224,23 +244,12 @@
"If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n",
"\n",
"</div>\n"
]
],
"id": "b38eb3ac"
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2e9d7a8c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 248 milliseconds\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
@@ -275,11 +284,22 @@
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 248 milliseconds\n"
]
}
],
"id": "2e9d7a8c"
},
{
"cell_type": "markdown",
"id": "4035ccb7",
"metadata": {},
"source": [
"Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
@@ -287,18 +307,18 @@
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |"
]
],
"id": "4035ccb7"
},
{
"cell_type": "markdown",
"id": "3db90dff",
"metadata": {},
"source": [
"## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
"\n",
"In addition to basic layers like `Linear` and `LayerNorm`, Transformer Engine offers larger modules like `MultiheadAttention` (combines \"LayerNorm\" and \"Self Attention\") and `LayerNormMLP` (combines \"LayerNorm\" and \"MLP\") that could replace their counterparts in the `LlamaDecoderLayer` and potentially provide a speedup. Transformer Engine also offers a full `TransformerLayer` (which further combines `MultiheadAttention` and `LayerNormMLP` layers) which could replace `LlamaDecoderLayer` and provide a speedup (with careful mapping of the weights since the name of the weights are different for those two layers). Let's take a closer look at Transformer Engine's `TransformerLayer`. \n",
"\n",
"#### Transformer Engine's `TransformerLayer`\n",
"### Transformer Engine's `TransformerLayer`\n",
"\n",
"At a higher level, TE's `TransformerLayer` could be visualized as an apt replacement for the `LlamaDecoderLayer`. But the internals of the `TransformerLayer` are organized a bit differently. \n",
"\n",
@@ -327,7 +347,7 @@
" <figcaption> Fig 8: Abstract illustration of the SwiGLU implementation in Transformer Engine. </figcaption>\n",
"</figure>\n",
"\n",
"#### `TransformerLayer` options explained\n",
"### `TransformerLayer` options explained\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
@@ -404,7 +424,7 @@
"A major portion of the Hugging Face model implementation (32 `LlamaDecoderLayer` layers) could be potentially replaced with Transformer Engine's `TransformerLayer` layers. Let's see how it is made possible.\n",
"\n",
"\n",
"#### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"### Mapping weights from HF's `LlamaDecoderLayer` to TE's `TransformerLayer`\n",
"\n",
"Refer the accompanying file `te_llama.py` which provides a reference to create a Llama 2 model with TE's `TransformerLayer` after replacing HF's `LlamaDecoderLayer`.\n",
"\n",
@@ -559,23 +579,12 @@
"\n",
"Let's first run this \"TELlama\" implementation in `BF16` precision.\n",
"</div>"
]
],
"id": "3db90dff"
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bdb34b91",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 185 milliseconds\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
@@ -610,11 +619,22 @@
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 185 milliseconds\n"
]
}
],
"id": "bdb34b91"
},
{
"cell_type": "markdown",
"id": "0c9fbd65",
"metadata": {},
"source": [
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n",
@@ -623,18 +643,18 @@
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |"
]
],
"id": "0c9fbd65"
},
{
"cell_type": "markdown",
"id": "98cd8efb",
"metadata": {},
"source": [
"## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
"\n",
"Now that most of the HF Llama model implementation (`LlamaDecoderLayer`s) has been swapped with Transformer Engine implementation (`TELlamaDecoderLayer` or `TransformerLayer`), let's see how finetuning in `FP8` precision helps improve performance.\n",
"\n",
"#### How to run the model in `FP8` precision\n",
"### How to run the model in `FP8` precision\n",
"\n",
"After the substitution, the model can be run in `FP8` precision by the following change over the previous BF16 runs. (For more information, refer the corresponding `wrap_with_accelerator` function in the accompanying `utils.py` file).\n",
"\n",
@@ -648,23 +668,12 @@
" kwargs_handlers=fp8_kwarg_handler\n",
")\n",
"```"
]
],
"id": "98cd8efb"
},
{
"cell_type": "code",
"execution_count": 1,
"id": "772c6f22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 160 milliseconds\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
@@ -699,11 +708,22 @@
"\n",
"# Finetune the model\n",
"finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"10 finetuning steps complete!\n",
"Average time taken per step: 160 milliseconds\n"
]
}
],
"id": "772c6f22"
},
{
"cell_type": "markdown",
"id": "e7cf9c3a",
"metadata": {},
"source": [
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
@@ -715,7 +735,7 @@
"\n",
"After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n",
"\n",
"#### Llama 3 performance results\n",
"### Llama 3 performance results\n",
"Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
@@ -726,17 +746,18 @@
"\n",
"For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n",
"\n"
]
],
"id": "e7cf9c3a"
},
{
"cell_type": "markdown",
"id": "95d6c42b",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
]
],
"id": "95d6c42b"
}
],
"metadata": {
......