Commit 27ddce40 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import re
import gc
import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
"""
This file contains logic of mapping the HuggingFace GemmaModel parameters
with TransformerEngine TransformerLayer. When we have initialized Transformer models
both with HF and with TE, we can copy parameters from the first to the second.
"""
def _load_weights_for_fp8_model(vanilla_model, hyperparams):
"""
Loads weights and FP8 metadata from a calibrated weights file.
The weights are in BF16 precision, but the state dict also contains
fp8 metadata computed by the calibration procedure.
"""
fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename)
# A hack to remove the extra state from the fp8_metadata_sd
# that contains the extra state from the core_attention module.
fp8_metadata_sd = {
k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k
}
vanilla_model.load_state_dict(
fp8_metadata_sd,
strict=False,
# Because some parameters have multiple pointers to the same weight
# vanilla_model._model_context_phase.model and
# vanilla_model._model_generation_phase.model we need to load the
# weights in a non-strict manner.
)
def _load_weights_for_standard_model(vanilla_model, config):
"""
Loads weights from the HuggingFace checkpoint.
"""
archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json")
resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file)
total_dict = {}
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
total_dict.update(state_dict)
replace_params(
total_dict,
vanilla_model.state_dict(),
config,
qkv_fused_and_interleaved=config.fuse_qkv_params,
)
# Copy remaining parameters like embedding.
vanilla_model.load_state_dict(total_dict, strict=False)
# Force mem release. Taken from huggingface code.
del total_dict
gc.collect()
def load_te_model(cls, config):
"""
Loads the TE model with proper weights.
"""
# Force the dtype to bfloat16 while loading the model.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
"""
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo:
https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
config.use_cache = False # To make TransformerLayer compatible with GemmaModel
# Loading model with FP8 only weights needs both the following context managers.
# 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights.
# 2. torch.no_grad() during TE modules' initilization so that they respect
# the `fp8_model_init` context manager.
with torch.no_grad(), fp8_model_init(config.fp8_model_init):
# Just create a model with random weights.
vanilla_model = cls(config).cuda()
# Copy proper weights into the model. If loading weights with FP8 metadata,
# then the source weights are basically the same as the weights in the model.
# If not, then we need to load the weights from the HuggingFace checkpoint
# and do mapping of the weight names from HF to the TE model.
if config.fp8_model_weights_filename is not None:
_load_weights_for_fp8_model(vanilla_model, config)
else:
_load_weights_for_standard_model(vanilla_model, config)
# Restore the original dtype.
torch.set_default_dtype(old_dtype)
return vanilla_model
def _get_all_layer_prefixes_to_update(hf_state_dict):
"""
There are many parameters in hf_state_dict, whose name start with "model.layers.[number]."
This function extracts all strings like "model.layers.[number]."
that are starting strings of keys in hf_state_dict.
"""
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = "model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
return all_layer_prefixes
def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
"""
Replaces params from TE TransformerLayer state_dict with corresponding parameters
from HuggingFace GemmaModel state_dict.
"""
all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
for layer_prefix in all_layer_prefixes:
def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
te_state_dict[layer_prefix + te_name].data[start:end].copy_(
hf_state_dict[layer_prefix + hf_name]
)
copy_from_ht_to_te(
"self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
)
copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
)
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
)
if qkv_fused_and_interleaved:
"""
When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor
in TE TransformerLayer. Moreover they are interleaved within each head.
Let q_i, k_i and v_i be query, key and value layers for i-th head respectively.
Then TE stores weight tensor in the form:
[q1 k1 v1 q2 k2 v2 ...]
This is done to maximally optimize performance time.
"""
te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
def copy_interleave(hf_name, idx):
src = hf_state_dict[layer_prefix + hf_name]
for head_nr in range(config.num_attention_heads):
dst_offset = head_nr * config.head_dim * 3
dst_slice = slice(
dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
)
src_slice = slice(
head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
)
te_qkv_layer[dst_slice, :] = src[src_slice, :]
copy_interleave("self_attn.q_proj.weight", 0)
copy_interleave("self_attn.k_proj.weight", 1)
copy_interleave("self_attn.v_proj.weight", 2)
else:
copy_from_ht_to_te(
"self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
)
copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
copy_from_ht_to_te(
"self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
)
return all_layer_prefixes
{
"cells": [
{
"cell_type": "markdown",
"id": "87e8360b-8d08-44bc-9333-79ba949afe8c",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Accelerating Hugging Face Gemma Inference with Transformer Engine"
]
},
{
"cell_type": "markdown",
"id": "2da33092-eef5-46a4-b222-0188cc6e5079",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"## Introduction\n",
"\n",
"Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/generation_animation.gif\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
"<figcaption>\n",
"Animation 1: Hugging Face Gemma model token generation.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"\n",
"### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n",
"\n",
"The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n",
"\n",
"\n",
"Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n",
"\n",
"1. Internal fragmentation\n",
"2. External Fragmentation\n",
"\n",
"More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n",
"\n",
"\n",
"Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n",
"\n",
"\n",
"### 2. CUDA Graphs API\n",
"\n",
"The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n",
"\n",
"One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n",
"\n",
"PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/graphs.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
"<figcaption>\n",
"Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### 3. FP8 Scaling Factors Calibration\n",
"\n",
"This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n",
"\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"\n",
"It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
"\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 2:\n",
"Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### 4. FP8 Model Weights\n",
"\n",
"The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n",
"\n",
"The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 3: Model under <b>fp8_autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>fp8_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### Benchmarking\n",
"\n",
"We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
" \n",
"This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n",
"</div>"
]
},
{
"cell_type": "markdown",
"id": "b18f91a9",
"metadata": {},
"source": [
"## Dependencies for this tutorial"
]
},
{
"cell_type": "markdown",
"id": "e5201d77",
"metadata": {},
"source": [
"The following files and media are necessary to effectively run this tutorial:\n",
"\n",
"1. `te_gemma.py`\n",
" - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n",
"2. `te_gemma_loading_weights.py`\n",
" - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n",
"3. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n",
"4. `requirements.txt`\n",
" - This file contains the necessary Python packages for this tutorial.\n",
"5. `media/`\n",
" - This directory contains the images and other artefacts used in this tutorial."
]
},
{
"cell_type": "markdown",
"id": "36767694-a1c5-4a00-a075-7addc55d8307",
"metadata": {},
"source": [
"### Setup and checks"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment and run this cell when running the tutorial for the first time\n",
"# %pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c756ebbd-24c9-4a54-a381-e7c02c555206",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import torch\n",
"cudnn_version = torch.backends.cudnn.version()\n",
"assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\""
]
},
{
"cell_type": "markdown",
"id": "e8dfabbf",
"metadata": {},
"source": [
"## [Baseline] Running Hugging Face generation with Gemma model"
]
},
{
"cell_type": "markdown",
"id": "59560bff",
"metadata": {},
"source": [
"HuggingFace Transformers library offers generation API. \n",
"HuggingFace generation for the Gemma model will be used as a baseline."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2803e0ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why GPUs are so good at graphics. The\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n",
"* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 46.60 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.batch_size = 64\n",
"run_config.max_seq_length = 512\n",
"\n",
"model = init_baseline_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "b3698dc6",
"metadata": {},
"source": [
"Let's put this time into the table for later comparison.\n",
"\n",
"| Models | Time | Speedup | \n",
"|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n",
"| HF (baseline) | 46.6 s | - |"
]
},
{
"cell_type": "markdown",
"id": "8bb40f45",
"metadata": {},
"source": [
"## [Optimization 1] Accelerating generation with Transformer Engine "
]
},
{
"cell_type": "markdown",
"id": "263b40f2",
"metadata": {},
"source": [
"Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9dceef93",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why they are so good at graphics. The second\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
"* NVIDIA is the world leader in AI computing.\n",
"* NVIDIA is the world leader in graphics processing units (GP\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 12.25 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.batch_size = 64\n",
"run_config.max_seq_length = 512\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "b5d40836",
"metadata": {},
"source": [
"With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!"
]
},
{
"cell_type": "markdown",
"id": "006d18e8",
"metadata": {},
"source": [
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |"
]
},
{
"cell_type": "markdown",
"id": "21a89d9c",
"metadata": {},
"source": [
"## [Optimization 2] More acceleration with CUDA Graphs"
]
},
{
"cell_type": "markdown",
"id": "e2d53e7b",
"metadata": {},
"source": [
"Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n",
"```python\n",
" def __init__(self, config : GemmaConfig):\n",
" \"\"\"\n",
" Here \"the trick\" happens. `_model_context_phase` and\n",
" `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n",
" their recorded version. Once the graphs are recorded, they can be\n",
" replayed with minimal usage of CPU and that leads to speedup.\n",
" \"\"\"\n",
" (...)\n",
" # Record the graph for context/prefill phase.\n",
" self._model_context_phase = \n",
" self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n",
"\n",
" (...) \n",
" # Record the graph for generation phase.\n",
" self._model_generation_phase = \n",
" self.record_graph(self._model_generation_phase, self.generation_buffer)\n",
"\n",
" @torch.no_grad()\n",
" def record_graph(self, function, input_tensor):\n",
" \"\"\"\n",
" Records the graph for the given function. The function is invoked on\n",
" argument (self.hidden_states,) and all kernels are recorded.\n",
" It then returns the captured callable, which can be run later while\n",
" minimizing CPU usage.\n",
" \"\"\"\n",
" fp8_recipe = get_default_fp8_recipe()\n",
"\n",
" # We need both autocasts: FP8 for operations that can run in lower\n",
" # precision and BF16 for those that cannot.\n",
" with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n",
" graphed_function = te.pytorch.make_graphed_callables(\n",
" function,\n",
" (input_tensor,),\n",
" fp8_enabled=self.config.fp8,\n",
" fp8_recipe=fp8_recipe,\n",
" allow_unused_input=True,\n",
" num_warmup_iters=5,\n",
" sample_kwargs=sample_kwargs,\n",
" )\n",
" return graphed_function\n",
"```\n",
"\n",
"It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n",
"\n",
"*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "31a3a8a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why they are so good at graphics. The second\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
"* NVIDIA is the world leader in AI computing.\n",
"* NVIDIA is the world leader in graphics processing units (GP\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 6.39 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.max_seq_length = 512\n",
"run_config.batch_size = 64\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# It is necessary to preallocate a static buffer.\n",
"# CUDA graphs require static input tensors for every kernel.\n",
"# This approach may result in a slight increase in memory consumption;\n",
"# however, the substantial speedup achieved makes it worthwhile.\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "53bb430f",
"metadata": {},
"source": [
"A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n",
"\n",
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |"
]
},
{
"cell_type": "markdown",
"id": "0a11b75c",
"metadata": {},
"source": [
"Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n",
"\n",
"1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n",
"2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n",
"3. So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n",
"4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n",
"\n",
"(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/transformer_cuda_graphed.png\" width=\"80%\" \">\n",
"<figcaption>\n",
" \n",
"Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n",
"</figcaption>\n",
"</figure>\n"
]
},
{
"cell_type": "markdown",
"id": "e6b171a0",
"metadata": {},
"source": [
"## [Optimization 3] Even more acceleration with FP8 precision "
]
},
{
"cell_type": "markdown",
"id": "1a80288b",
"metadata": {},
"source": [
"### Calibrating FP8 scaling factors for correctness\n",
"\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"\n",
"1. Model weight tensors\n",
"2. Input tensors\n",
"\n",
"If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n",
"\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"\n",
"*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n",
" \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"\n",
"The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "aecee0e1",
"metadata": {},
"outputs": [],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"import transformer_engine.pytorch as te\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"run_config.fuse_qkv_params = True\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"# Calibration\n",
"with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" model.train()\n",
" run_forward_pass(model, run_config, num_iters=64)\n",
"\n",
"# Compute scale_fwd with enabled fp8 autocast\n",
"with te.fp8_autocast(enabled=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" run_forward_pass(model, run_config, 1)\n",
"\n",
"# Some parameters are in pointing to the same tensors, double save is avoided here.\n",
"dict_to_save = {\n",
" k: v\n",
" for k, v in model.state_dict().items()\n",
" if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n",
"}\n",
"torch.save(\n",
" dict_to_save, \"calibrated_weights.pth\"\n",
") # <-- Add path to save calibrated weights."
]
},
{
"cell_type": "markdown",
"id": "b6dcd135",
"metadata": {},
"source": [
"### Generation with better FP8 scaling factors\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a913f54d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing the same thing over and over again.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
"* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
"* NVIDIA is a key player\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 8.73 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.fuse_qkv_params = True # This is needed by the last improvement.\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# CUDA Graphs related config\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"# Enable FP8\n",
"run_config.fp8 = True\n",
"# Calibrated fp8 weights are loaded directly from the file.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "8cdbb56c",
"metadata": {},
"source": [
"One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n",
"\n",
"### Use of FP8-only model weights\n",
"\n",
"Running the model in FP8 precision does not imply that the weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors before GEMM operations (matrix multiplications).\n",
"\n",
"This approach is appropriate during training since gradients during the backward pass are produced in higher precision, and therefore, having higher precision copies of model weights helps, as they have enough dynamic range to encompass incoming information from the gradients. During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs occur in FP8 precision, which helps save training time overall if the time saved from running GEMM in FP8 precision (than in higher precision) is more than the extra time spent during the cast operation.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n",
"</figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d",
"metadata": {},
"source": [
"However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n",
"\n",
"1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n",
"2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n",
"\n",
"\n",
"Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4562ee82-8c95-4736-8815-cd386078a485",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Memory required for 16384x16384 linear layer: \n",
"FP32 - 1024.0 MB, \n",
"BF16 - 512.0 MB, \n",
"FP8 - 256.0 MB, \n",
"\n",
"Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n",
"Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n",
"Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n"
]
}
],
"source": [
"import torch\n",
"import transformer_engine.pytorch as te\n",
"\n",
"H = 2**14\n",
"D = 2**14\n",
"print(f\"Memory required for {H}x{D} linear layer: \\n\"\n",
" f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n",
" f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n",
" f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n",
"\n",
"linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n",
"print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp32\n",
"\n",
"linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n",
"print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_bf16\n",
"\n",
"# Initialize model weights in FP8 precision\n",
"with torch.no_grad(), te.fp8_model_init(enabled=True):\n",
" linear_fp8 = te.Linear(H, D)\n",
"print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp8"
]
},
{
"cell_type": "markdown",
"id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b",
"metadata": {},
"source": [
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Let's run the code with `fp8_model_init`:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "96264b9c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing the same thing over and over again.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
"* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
"* NVIDIA is a key player\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 4.99 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"# Import necessary packages and methods\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.fuse_qkv_params = True # This is needed by the last improvement.\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# CUDA Graphs related config\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"# Enable FP8 math and FP8 model weights\n",
"run_config.fp8 = True\n",
"run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "3e30ca5a",
"metadata": {},
"source": [
"The final speedup is **9.3x**. \n",
"\n",
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
]
},
{
"cell_type": "markdown",
"id": "c6e87275",
"metadata": {},
"source": [
"## Conclusions"
]
},
{
"cell_type": "markdown",
"id": "7bb2452d",
"metadata": {},
"source": [
"This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n",
"\n",
"1. Support for KV Caching (both non-paged and paged),\n",
"2. Integration with CUDA Graphs,\n",
"3. FP8 scaling factors calibration,\n",
"4. Keeping model parameters in FP8 precision.\n",
"\n",
"It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n",
"\n",
"1. Longer context lengths (with paged KV cache) \n",
"2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n",
"\n",
"Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import sys
import IPython
import random
import string
from te_gemma_loading_weights import load_te_model
import torch
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs
random.seed(42)
torch.manual_seed(42)
class RunConfiguration:
def __init__(self):
self.mixed_precision = "bf16"
self.model_name = None
# FP8 precision settings
self.fp8 = False
self.fp8_model_weights_filename = None
self.fp8_model_init = False
# Cuda graphs
self.generation_cuda_graphs = False
self.cuda_graphs_static_batch_size = 64
self.cuda_graphs_static_max_seq_len = 512
self.cuda_graphs_static_max_context_len = 512
# Finetuning/calibration/generation settings
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
self.batch_size = 64
self.max_seq_length = 512
self.gradient_accumulation_steps = 1
self.num_warmup_steps = 5
self.num_training_steps = 10
# Coalesced QKV params or not
self.fuse_qkv_params = False
# Attention
self.is_paged = False
# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""
# Global variable for the run configuration so that it can be easily accessed
# throughout the jupyter notebook with an `import * from utils` statement
run_config = RunConfiguration()
def get_dataloaders(run_config):
"""
Returns a basic dataloader for the dataset which contains tokenized batches
of text.
"""
dataset = load_dataset(run_config.dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize(element):
outputs = tokenizer(
element["text"],
truncation=True,
padding=False,
max_length=run_config.max_seq_length,
return_overflowing_tokens=False,
return_length=False,
)
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
# Tokenize the dataset
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
# Simply pad to the multiple of 16 for both FP8 and BF16 precision
pad_to_multiple_of = 16
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=pad_to_multiple_of,
)
dataloader_params = {
"batch_size": run_config.batch_size,
"collate_fn": data_collator,
"drop_last": True,
}
train_dataloader = DataLoader(dataset, **dataloader_params)
return train_dataloader
def ensure_model_is_downloaded(run_config):
"""
Downloads and caches the model weights if not already downloaded. A valid
Huggingface Access Token is required to download the model weights.
"""
assert run_config.model_name in [
"google/gemma-7b",
], "Only Gemma 7B model is supported!"
# Login using Huggingface Hub API
from huggingface_hub import login
try:
login(run_config.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")
# Download the model if it doesn't exist
from huggingface_hub import snapshot_download
supplied_cache_dir = (
run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None
)
run_config.weights_cache_dir = snapshot_download(
repo_id=run_config.model_name, cache_dir=supplied_cache_dir
)
def init_baseline_model(run_config):
"""
Initializes a baseline HF Gemma model with the model name provided in
the run_config.
"""
# Download and cache the weights if not already downloaded
ensure_model_is_downloaded(run_config)
# Init the model
config = AutoConfig.from_pretrained(run_config.model_name)
# Make sure to use flash_attention to do iso comparison with TEGemmaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
run_config.model_name,
config=config,
torch_dtype=torch.bfloat16,
).cuda()
return model
def init_te_gemma_model(run_config):
"""
Initializes a Gemma model with `GemmaDecoderLayer`s swapped with
`TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled,
the model is initialized from `TEGemmaForCausalLMCudaGraphs` class.
"""
# Download and cache the weights if not already downloaded
ensure_model_is_downloaded(run_config)
cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM
config = AutoConfig.from_pretrained(run_config.model_name)
# Inject all fields from the `run_config` to the model `config` to make the
# code simpler.
for key, value in run_config.__dict__.items():
setattr(config, key, value)
# Initialize the model and move it to the GPU.
model = load_te_model(cls, config).cuda()
# Record the model if CUDA Graphs are enabled.
if run_config.generation_cuda_graphs:
model.record()
return model
def restart_jupyter_notebook():
# Try restarting the Jupyter kernel
IPython.Application.instance().kernel.do_shutdown(True)
# Check whether the device memory has been flushed
if torch.cuda.memory_allocated() != 0:
import warnings
warnings.warn("The device memory hasn't been flushed, trying with a second method!")
# Try restarting the Jupyter kernel another way
# Restart the kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
if torch.cuda.memory_allocated() != 0:
print(
"The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
)
# Suppress the warnings
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
torch.set_warn_always(False)
@torch.no_grad()
def run_forward_pass(model, run_config, num_iters):
"""
Runs the forward pass of the model with sample data. Intended to use for
warmup and/or calibration.
"""
train_dataloader = get_dataloaders(run_config)
model.train()
train_dataloader = enumerate(train_dataloader)
for _ in range(num_iters):
_, batch = next(train_dataloader)
batch["input_ids"] = batch["input_ids"].cuda()
batch["attention_mask"] = batch["attention_mask"].cuda()
model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
###############################################################################
# Benchmarking and example generation functions.
###############################################################################
def print_sample_of_generated_texts(model, run_config):
"""
Prints a sample of generated texts from the input model.
"""
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
prompts = [
"Here are the two facts about GPUs:",
"Some facts about NVIDIA:",
"The fundamental theorem of calculus for the layman:",
"A fact about AI:",
]
# Repeat prompts to match batch size
prompts *= run_config.batch_size // len(prompts)
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
max_total_tokens = (
run_config.max_seq_length
if not run_config.generation_cuda_graphs
else run_config.cuda_graphs_static_max_seq_len
)
max_length = inputs["input_ids"].size(1)
new_length = ((max_length + 63) // 64) * max_total_tokens
# Add padding to the left
inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id
)
# Add padding to the left (only intended for baseline generation with HF
# which expects padding to the left)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (new_length - max_length, 0), value=0
)
inputs["input_ids"] = inputs["input_ids"].cuda()
inputs["attention_mask"] = inputs["attention_mask"].cuda()
outputs = model.generate(**inputs, max_new_tokens=50)
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
def print_output(prompts, generated_texts, idx):
print("=" * 30 + f" Generation example {idx+1} " + "=" * 30)
print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"')
print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"')
# Print the output from first two prompts
for i in range(2):
print_output(prompts, generated_texts, i)
def _generate_random_words(num_words, max_word_length):
"""
Generates random words for the benchmark.
"""
words = []
for _ in range(num_words):
word_length = random.randint(1, max_word_length)
word = "".join(random.choices(string.ascii_lowercase, k=word_length))
words.append(word)
return words
def benchmark_generation(model, run_config, context_length=20):
"""
Benchmarks the generation time for a random input to the model.
"""
batch_size = run_config.batch_size
max_total_tokens = (
run_config.max_seq_length
if not run_config.generation_cuda_graphs
else run_config.cuda_graphs_static_max_seq_len
)
max_new_tokens = max_total_tokens - context_length
print("\n" + "=" * 80)
print(
f"Benchmarking for batch_size = {batch_size}, prefill tokens ="
f" {context_length} and max new tokens = {max_new_tokens}"
)
input_str = _generate_random_words(batch_size, context_length)
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
inputs = tokenizer(input_str, return_tensors="pt", padding=True)
max_context_tokens = inputs["input_ids"].size(1)
# Add padding to the left
inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"],
(max_total_tokens - max_context_tokens, 0),
value=tokenizer.pad_token_id,
)
# Add padding to the left (only intended for baseline generation with HF
# which expects padding to the left)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0
)
inputs["input_ids"] = inputs["input_ids"].cuda()
inputs["attention_mask"] = inputs["attention_mask"].cuda()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens)
torch.cuda.synchronize()
end.record()
print(f"Time: {start.elapsed_time(end)/1000:.2f} s.")
......@@ -5,7 +5,7 @@
"id": "6a5b2993",
"metadata": {},
"source": [
"# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n",
"# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
......
......@@ -46,6 +46,7 @@ Transformer Engine documentation
examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb
.. toctree::
......
......@@ -267,7 +267,10 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
......
......@@ -264,7 +264,7 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS),
):
rng = jax.random.PRNGKey(args.seed)
......
......@@ -382,7 +382,10 @@ def train_and_evaluate(args):
) as mesh, te.fp8_autocast(
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
mesh_resource=te.MeshResource(
dp_resource=DEVICE_DP_AXIS,
tpsp_resource=DEVICE_TP_AXIS,
),
):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
......
......@@ -219,7 +219,9 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......
......@@ -193,7 +193,9 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
with te.fp8_autocast(
enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
......
......@@ -263,7 +263,13 @@ def _train(opts):
te.module.base.initialize_ub(
[batched_size, hidden_size],
tp_size,
use_fp8=opts.fp8,
quantization_modes=[
(
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
],
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
)
......
......@@ -23,38 +23,33 @@ set -x
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime"
pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
# channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then
cd $TE_PATH/tests/cpp_distributed
cmake -GNinja -S. -Bbuild
cmake --build build
mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm
fi
......@@ -9,3 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
......@@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pip3 install onnxruntime==1.20.1
pip3 install onnxruntime_extensions==0.13.0
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -6,6 +6,7 @@
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
# NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
from importlib import metadata
import os
import time
from pathlib import Path
......@@ -19,6 +20,7 @@ from build_tools.te_version import te_version
from build_tools.utils import (
rocm_build,
cuda_archs,
cuda_version,
get_frameworks,
remove_dups,
)
......@@ -82,6 +84,18 @@ def setup_common_extension() -> CMakeExtension:
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))):
cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON")
cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution(
f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
f"nvidia-nvshmem-cu{cuda_version()[0]}"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])
# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
......
......@@ -77,6 +77,7 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(../../transformer_engine/common)
include_directories(../../transformer_engine)
include_directories(${CMAKE_SOURCE_DIR})
if(USE_CUDA)
......
......@@ -28,10 +28,19 @@ namespace {
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
const NormType norm_type, const bool use_cudnn,
const bool zero_centered_gamma_in_weight_dtype, const bool fused_bwd_add) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}
if (norm_type == LayerNorm && fused_bwd_add) {
GTEST_SKIP() << "Fused LN backward+add not currently supported";
}
if (fused_bwd_add && zero_centered_gamma_in_weight_dtype) {
GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported "
<< "in fused norm backward+add";
}
if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
......@@ -46,7 +55,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
return;
}
Tensor input("input", std::vector<size_t>{ N, H }, itype);
......@@ -56,6 +64,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
Tensor bwd_add("bwd_add", std::vector<size_t>{ N, H }, wtype);
Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
......@@ -66,6 +75,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&beta);
setRandomScale(&z);
fillUniform(&dz);
if (fused_bwd_add) {
fillUniform(&bwd_add);
} else {
fillCase<WeightType>(&bwd_add, zeros);
}
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
......@@ -86,7 +100,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
......@@ -126,15 +139,23 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
if (fused_bwd_add) {
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
} else {
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
}
}
if (use_cudnn){
......@@ -168,6 +189,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
bwd_add.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
......@@ -215,30 +237,40 @@ std::vector<std::pair<size_t, size_t>> test_cases = {
} // namespace
class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool>> {};
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool,
bool>> {};
TEST_P(NormTestSuite, TestNorm) {
using namespace transformer_engine;
using namespace test;
using namespace transformer_engine;
using namespace test;
const bool use_cudnn = std::get<0>(GetParam());
const NormType norm_type = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype);
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
const bool fused_bwd_add = std::get<7>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(
size.first,
size.second,
zero_centered_gamma,
norm_type,
use_cudnn,
cudnn_zero_centered_gamma_in_weight_dtype,
fused_bwd_add
);
);
);
}
INSTANTIATE_TEST_SUITE_P(
......@@ -251,6 +283,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
......@@ -262,6 +295,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
std::to_string(std::get<6>(info.param)) + "X" +
std::to_string(std::get<7>(info.param));
return name;
});
......@@ -140,7 +140,8 @@ void compute_ref_output(NormType norm_type,
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad,
const OutputType *add, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
......@@ -179,7 +180,8 @@ void compute_ref_backward(const NormType norm_type, const OutputType *output_gra
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
const compute_t a = static_cast<compute_t>(add[i * H + j]);
const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment