Unverified Commit 8b210490 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

Added comments about Llama3 weights to Llama tutorial (#830)



* Llama 3 update
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Times update
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* Times update
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* utils.py fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* utils.py fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* utils.py fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* update te llama tutorial to allow running with llama 3 weights
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add llama 3 vs llama 2 distinctions
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* paraphrasing and corrected facts
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
parent 4478b044
...@@ -100,13 +100,21 @@ class TELlamaForCausalLM: ...@@ -100,13 +100,21 @@ class TELlamaForCausalLM:
subfolder = "" subfolder = ""
variant = None variant = None
if os.path.isfile( if os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant))
): ):
# Load from a sharded PyTorch checkpoint # Load from a sharded PyTorch checkpoint
archive_file = os.path.join( archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)
) )
is_sharded = True is_sharded = True
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
)
is_sharded = True
else: else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment") raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
......
...@@ -2,23 +2,23 @@ ...@@ -2,23 +2,23 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "2cac9d39", "id": "6a5b2993",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Accelerating a Hugging Face Llama 2 model with Transformer Engine\n", "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n",
"\n", "\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"\n", "\n",
"<b>Goal</b>\n", "<b>Goal</b>\n",
"\n", "\n",
"This tutorial showcases how to accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n", "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) models from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
"\n", "\n",
"</div>\n" "</div>\n"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "401f7fb1", "id": "331f476a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Dependencies for this tutorial\n", "## Dependencies for this tutorial\n",
...@@ -26,16 +26,28 @@ ...@@ -26,16 +26,28 @@
"Following files and media are necessary to effectively run this tutorial:\n", "Following files and media are necessary to effectively run this tutorial:\n",
"\n", "\n",
"1. `te_llama.py`\n", "1. `te_llama.py`\n",
" - This file contains the code to load a Hugging Face Llama 2 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
"2. `utils.py`\n", "2. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
"3. `media/`\n", "3. `media/`\n",
" - This directory contains the images used in the following tutorial." " - This directory contains the images used in the following tutorial.\n",
"\n",
"These packages are necessary to run this tutorial:\n",
"`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n",
"\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note on running the tutorial with Llama 3 weights</b>\n",
"\n",
"This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by providing the directory with those weights (in Hugging Face format) instead of Llama 2 7B weights. These two models are almost identical, the biggest difference being the model dimension (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 has 7B), which enables this tutorial to work for both of them.\n",
"\n",
"</div>\n"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "33bdb5fe", "id": "44abae4f",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Table of contents\n", "## Table of contents\n",
...@@ -53,7 +65,7 @@ ...@@ -53,7 +65,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "7645f176", "id": "e37e2cc1",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## From \"Transformer\" to \"Llama\" \n", "## From \"Transformer\" to \"Llama\" \n",
...@@ -67,10 +79,13 @@ ...@@ -67,10 +79,13 @@
"\n", "\n",
"- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n", "- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n",
"- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n", "- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n",
"- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases. \n", "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n",
"- One of the latest in this line of pretrained models which is also open source is Meta's [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n", "- February 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
" - These models range from 7B to 65B parameters.\n", " - These models range from 7B to 70B parameters.\n",
" - LLaMA 2 was pretrained on 2 trillion tokens.\n", " - LLaMA 2 was pretrained on 2 trillion tokens.\n",
"- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n",
" - These models range from 8B to 70B parameters.\n",
" - LLaMA 3 was pretrained on 15 trillion tokens.\n",
"\n", "\n",
"For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n", "For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n",
"\n", "\n",
...@@ -78,9 +93,16 @@ ...@@ -78,9 +93,16 @@
"2. RMSNorm in place of the LayerNorm\n", "2. RMSNorm in place of the LayerNorm\n",
"3. SwiGLU activation function\n", "3. SwiGLU activation function\n",
"4. RoPE as positional embeddings \n", "4. RoPE as positional embeddings \n",
"5. Grouped Query Attention\n", "5. Grouped Query Attention for the 70B model\n",
"6. Trained on 4K context length\n", "6. Trained on 4K context length\n",
"\n", "\n",
"Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n",
"\n",
"1. Use of bigger tokenizer - 128256 vs 32K.\n",
"2. Grouped Query Attention is used also by smaller 8B model.\n",
"3. The context length increased to 8K for all models.\n",
"3. Llama 3 was trained on 8x more data than Llama 2.\n",
"\n",
"<figure align=\"center\">\n", "<figure align=\"center\">\n",
"<img src=\"media/transformer_vs_llama.svg\">\n", "<img src=\"media/transformer_vs_llama.svg\">\n",
" <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n", " <figcaption> Fig 2: Comparing GPT and Llama architectures. </figcaption>\n",
...@@ -89,7 +111,7 @@ ...@@ -89,7 +111,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "d0cfa787", "id": "a110de1a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Hugging Face's `LlamaModel`\n", "## Hugging Face's `LlamaModel`\n",
...@@ -166,7 +188,7 @@ ...@@ -166,7 +188,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "f4f21369", "id": "c9529229",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", "## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n",
...@@ -190,14 +212,14 @@ ...@@ -190,14 +212,14 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "24a8d0a5", "id": "b38eb3ac",
"metadata": {}, "metadata": {},
"source": [ "source": [
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"\n", "\n",
"<b>Note</b>\n", "<b>Note</b>\n",
" \n", " \n",
"This tutorial loads and trains a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", "This tutorial loads and trains a Llama 3 8B or a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n",
"\n", "\n",
"If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n",
"\n", "\n",
...@@ -207,7 +229,7 @@ ...@@ -207,7 +229,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"id": "e36ff380", "id": "2e9d7a8c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -215,7 +237,7 @@ ...@@ -215,7 +237,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"10 finetuning steps complete!\n", "10 finetuning steps complete!\n",
"Average time taken per step: 315 milliseconds\n" "Average time taken per step: 248 milliseconds\n"
] ]
} }
], ],
...@@ -231,8 +253,8 @@ ...@@ -231,8 +253,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
...@@ -248,19 +270,19 @@ ...@@ -248,19 +270,19 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "a64f0f33", "id": "4035ccb7",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n",
"\n", "\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 315 | 1 |" "| HF (baseline) | BF16 | 248 | 1 |"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "d9898383", "id": "3db90dff",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", "## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n",
...@@ -532,8 +554,8 @@ ...@@ -532,8 +554,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"id": "4974b738", "id": "bdb34b91",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -541,7 +563,7 @@ ...@@ -541,7 +563,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"10 finetuning steps complete!\n", "10 finetuning steps complete!\n",
"Average time taken per step: 252 milliseconds\n" "Average time taken per step: 185 milliseconds\n"
] ]
} }
], ],
...@@ -557,8 +579,8 @@ ...@@ -557,8 +579,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.mixed_precision = \"bf16\"\n",
"\n", "\n",
...@@ -574,20 +596,20 @@ ...@@ -574,20 +596,20 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "85c78c7f", "id": "0c9fbd65",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **25%** even when using only BF16 precision!\n", "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **34%** even when using only BF16 precision!\n",
"\n", "\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 315 | 1 |\n", "| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |" "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e2fb88e9", "id": "98cd8efb",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", "## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n",
...@@ -613,7 +635,7 @@ ...@@ -613,7 +635,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"id": "8f2b752e", "id": "772c6f22",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -621,7 +643,7 @@ ...@@ -621,7 +643,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"10 finetuning steps complete!\n", "10 finetuning steps complete!\n",
"Average time taken per step: 226 milliseconds\n" "Average time taken per step: 160 milliseconds\n"
] ]
} }
], ],
...@@ -637,8 +659,8 @@ ...@@ -637,8 +659,8 @@
"\n", "\n",
"# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
"## !!! `model_name` attr must point to the location of the model weights !!!\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n",
"## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", "# For Llama 2, download weights from https://huggingface.co/meta-llama/Llama-2-7b-hf (Hugging Face weight format).\n",
"## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "# For Llama 3, download weights from https://huggingface.co/meta-llama/Meta-Llama-3-8B (Hugging Face weight format).\n",
"hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
"hyperparams.mixed_precision = \"fp8\"\n", "hyperparams.mixed_precision = \"fp8\"\n",
"\n", "\n",
...@@ -654,27 +676,39 @@ ...@@ -654,27 +676,39 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "67ec126c", "id": "e7cf9c3a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 315 | 1 |\n", "| HF (baseline) | BF16 | 248 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |\n", "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 185 | 1.34 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 226 | 1.39 |\n", "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 160 | 1.55 |\n",
"\n",
"\n", "\n",
"After turning on FP8 precision, we get even more speedup of **55%** (with Llama 2 7B)!\n",
"\n",
"#### Llama 3 performance results\n",
"Running the same tutorial with **Llama 3 8B** yields the following performance numbers:\n",
"\n",
"| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n",
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n",
"| HF (baseline) | BF16 | 270 | 1 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 217 | 1.24 |\n",
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 185 | 1.46 |\n",
"\n", "\n",
"After turning on FP8 precision, we get even more speedup of almost **40%**!" "For Llama 3 8B, we get the most speedup of **46%** with FP8 precision!\n",
"\n"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "41b80b0f", "id": "95d6c42b",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Conclusion\n", "## Conclusion\n",
"\n", "\n",
"Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 implementation. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 and Llama 3 implementations. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!"
] ]
} }
], ],
......
...@@ -82,6 +82,7 @@ def init_baseline_model(hyperparams): ...@@ -82,6 +82,7 @@ def init_baseline_model(hyperparams):
config=config, config=config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
model = model.cuda()
# Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison # Needed for the cases when using TELlamaForCausalLM. So adding here for 1:1 comparison
model.config.use_cache=False model.config.use_cache=False
...@@ -97,6 +98,7 @@ def init_te_llama_model(hyperparams): ...@@ -97,6 +98,7 @@ def init_te_llama_model(hyperparams):
config=config, config=config,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
) )
model = model.cuda()
# Needed for the cases when using TELlamaForCausalLM # Needed for the cases when using TELlamaForCausalLM
model.config.use_cache=False model.config.use_cache=False
...@@ -117,7 +119,7 @@ def wrap_with_accelerator(model, hyperparams): ...@@ -117,7 +119,7 @@ def wrap_with_accelerator(model, hyperparams):
train_dataloader = get_dataloaders(accelerator, hyperparams) train_dataloader = get_dataloaders(accelerator, hyperparams)
# Wrap model, optimizer/scheduler, dataloaders in accelerate # Wrap model, optimizer/scheduler, dataloaders in accelerate
optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate) optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True)
lr_scheduler = get_linear_schedule_with_warmup( lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=100, num_warmup_steps=100,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment