Unverified Commit 42d22740 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

[JAX] Quickstart documentation (#2310)



* jax quickstart guide first commit
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* edit the syntax errors and remove unnecessary comments in utils. Add some footnotes in the quick start notebook
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Fix greptiles comments on spelling, deepcopy, vjp function signature comaptibility with speedometer
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Add Copyright to utils and fix some more greptiles complaints
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Add comments to alternative of layers
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Remove weight sharing between different iterations of the transformerLayer
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Add enum for attention implementations. Fix inconsistency between fuse and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix bug in TransformerLayer expected input shape being [sequence, batch, ...] instead of [batch, sequence,...]
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Changing structure of notebook to  bring fp8 ahead of fuse, to allow for fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* add option to choose between different attention implementation in call of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Fix mistake in lacking attention_implementation in FuseTETransformerLayer
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Removing AttentionWrapper and custom built DPA, using flax and TE's impl only, removing last mention of Pytorch
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* More changing to markdowns to remove pytorch
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* cosmetics fixes
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* changing names of all implementations
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* change fp8_autocast to autocast, make causal mask, and some wording changes
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
Co-authored-by: default avatartdophung <tdophung@dc2-container-xterm-034.prd.it.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
parent 66aed3ae
{
"cells": [
{
"cell_type": "markdown",
"id": "962d87bb",
"metadata": {},
"source": [
"\n",
"\n",
"# Getting Started\n",
"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
"We recommend you to try understanding the basics of JAX first, using these resources:\n",
"\n",
"- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
"- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n",
"- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n",
"- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
"\n",
"## Let's build a Transformer decoder layer!\n",
"<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"transformer_layer.png\" width=\"20%\">\n",
"<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
"- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
"- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n",
"- `Projection`: `nn.Dense` (Flax)\n",
"- `Dropout`: `nn.Dropout` (Flax)\n",
"- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" Built with plain Flax modules.\n",
" \"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n",
" x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
" return x\n",
"\n",
"class FlaxTransformerLayer(nn.Module):\n",
" \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" num_attention_heads: int\n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
" \n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray, \n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # Reshape to [batch, seq_len, num_heads * head_dim] for Flax MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n",
" # Attention using Flax's MultiHeadDotProductAttention\n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" dropout_rate=self.attention_dropout,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = jnp.bfloat16\n",
"\n",
"# Synthetic data\n",
"key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
"x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
"dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e44ed26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}, 'MultiHeadDotProductAttention_0': {'key': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'out': {'bias': (4096,), 'kernel': (32, 4, 4096)}, 'query': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'value': {'bias': (32, 4), 'kernel': (4096, 32, 4)}}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
")\n",
"\n",
"# Initialize parameters\n",
"params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de91af7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "037bc8d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 17.708301544189453 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ccb16f31",
"metadata": {},
"source": [
"## Meet Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
"\n",
"</div>\n",
"\n",
"As a reminder, the FlaxTransformerLayer above used:\n",
"\n",
"- `nn.LayerNorm`: Flax LayerNorm\n",
"- `nn.Dense`: Flax Dense layer for QKV projection \n",
"- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n",
"- `nn.Dense`: Flax Dense layer for projection\n",
"- `nn.Dropout`: Flax Dropout\n",
"- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
"\n",
"Below we show how to use Transformer Engine Flax modules for better performance:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bed20d6b",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.jax as te\n",
"import transformer_engine.jax.flax as te_flax"
]
},
{
"cell_type": "markdown",
"id": "f28cb444",
"metadata": {},
"source": [
"TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56105579",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
"\n",
"\n",
"class TEUnfusedMLP(nn.Module):\n",
" hidden_size : int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n",
" x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n",
" x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n",
" x = te.activation.activation(x, activation_type=('gelu',))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n",
" return x\n",
"\n",
"class TEUnfusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1 \n",
" use_te_attention: bool = True # True for TE attention, False for Flax attention\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # Fused QKV projection\n",
" qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention - either TE or Flax implementation\n",
" if self.use_te_attention:\n",
" # Use TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, # No GQA\n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
" else:\n",
" # Use Flax's MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" dropout_rate=self.attention_dropout,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # MLP\n",
" mlp = TEUnfusedMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size\n",
" )\n",
"\n",
" x = mlp(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "a76911ac",
"metadata": {},
"source": [
"Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b67511f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.505107879638672 ms\n"
]
}
],
"source": [
"te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
" use_te_attention=False\n",
")\n",
"\n",
"te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "0b230058",
"metadata": {},
"source": [
"Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5146cd99",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 12.80329704284668 ms\n"
]
}
],
"source": [
"te_unfused_transformer = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
")\n",
"\n",
"te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c9a101d3",
"metadata": {},
"source": [
"## Enabling Quantization (FP8 or FP4)\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](.../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: FP8 Metadata Initialization</b>\n",
"\n",
"When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c2eee376",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "de96827c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.615030288696289 ms\n"
]
}
],
"source": [
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
" # Example usage of forward \n",
" y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_unfused_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3801b201",
"metadata": {},
"source": [
"\n",
"## Fused TE Modules\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormDenseGeneral`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "11203785",
"metadata": {},
"outputs": [],
"source": [
"class TEFusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" res = x\n",
"\n",
" # Fused QKV projection\n",
" qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n",
" epsilon=self.layernorm_eps, \n",
" use_bias=True, \n",
" return_layernorm_output=False)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention using TE's DotProductAttention\n",
" attention = TEDotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, \n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n",
" epsilon=self.layernorm_eps,\n",
" use_bias=True,\n",
" activations=('gelu',),\n",
" intermediate_dropout_rate=0.0,\n",
" return_layernorm_output=False\n",
" )(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "334cff59",
"metadata": {},
"source": [
"Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "6b0c705e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.331779479980469 ms\n"
]
}
],
"source": [
"te_fused_transformer = TEFusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
" # Example usage of forward \n",
" y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_fused_transformer.apply,\n",
" variables=te_fused_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a45c12c8",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b2aaa8ef",
"metadata": {},
"outputs": [],
"source": [
"\n",
"te_transformer = te_flax.TransformerLayer(\n",
" hidden_size=hidden_size,\n",
" mlp_hidden_size=ffn_hidden_size, \n",
" num_attention_heads=num_attention_heads,\n",
" mlp_activations=(\"gelu\",),\n",
" self_attn_mask_type='causal',\n",
" layernorm_epsilon=1e-5,\n",
" use_bias=True,\n",
" intermediate_dropout=0.0,\n",
" enable_relative_embedding=False,\n",
" self_attn_bias_type='no_bias',\n",
" hidden_dropout=0.0\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
" y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b9cdbf22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.23741340637207 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" model_apply_fn=te_transformer.apply,\n",
" model_init_fn=te_transformer.init,\n",
" variables=te_transformer_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" dropout_key=dropout_key,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import jax.numpy as jnp
import time
import math
from typing import Callable, Any, Dict, Optional, Tuple
from flax import linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
def speedometer(
model_apply_fn: Callable,
variables: Any,
input: jnp.ndarray,
output_grad: jnp.ndarray,
dropout_key: jax.random.PRNGKey,
model_init_fn: Callable = None,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
) -> None:
"""Measure average runtime for a JAX module
Perform forward and backward passes .
"""
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
model_init_fn = None
train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)
# Warm up runs
key = dropout_key
for _ in range(warmup_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
# Timing runs
start = time.time()
for _ in range(timing_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
end = time.time()
print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
def create_train_step_fn(
model_apply_fn: Callable,
autocast_kwargs: Dict[str, Any],
forward_kwargs: Dict[str, Any] = None,
) -> Callable:
"""
Creates a JIT-compiled function that performs one forward/backward pass.
"""
if forward_kwargs is None:
forward_kwargs = {}
def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
rngs = {"dropout": dropout_key}
with te.autocast(**autocast_kwargs):
# Forward Pass: Apply the model using current parameters and variables
call_kwargs = {**forward_kwargs, "rngs": rngs}
out = model_apply_fn(variables, inp, **call_kwargs)
# grad_target = derivative of L (loss fn) over y (output) = signma(L)/sigma(y)
# where grad_w(L) = gradient of loss over params = sigma(L)/sigma(y) * sigma(y)/sigma(w) --> chain rule
# sigma(y)/sigma(w) = J_model(w)
return jnp.vdot(out, grad_target)
def fwd_bwd_fn(*args, **kwargs):
return jax.value_and_grad(loss_fn, argnums=(0, 1))(*args, **kwargs)
# Use jax.value_and_grad to get the loss value and gradients simultaneously. (forward + backward pass)
# ∇_params[output^T · grad_target] = grad_target^T · J_output(params) = VJP
# fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
# JIT-compile the fwd_bwd_fn
return jax.jit(fwd_bwd_fn)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment