Unverified commit a9767407 authored by Paweł Gadziński, committed by GitHub

[docs] Getting started refactor (#2534)



* docs: Add comprehensive Getting Started guide with benchmarks

- Add new Getting Started documentation with PyTorch and JAX tutorials
- Include benchmark scripts demonstrating TE performance benefits
- Add CSS styling for code output and tabs
- Replace old quickstart notebooks with improved documentation
- Add transformer layer diagram (SVG)
- Update docs configuration and workflow
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* 2026 in copyright
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c90a9214
......@@ -17,7 +17,7 @@ jobs:
uses: actions/checkout@v3
- name: 'Install dependencies'
run: |
pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2
pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 sphinx-tabs==3.4.7
pip install breathe==4.35.0 sphinx-autoapi==3.3.2
sudo apt-get install -y pandoc graphviz doxygen
export GIT_SHA=$(git show-ref --hash HEAD)
......
/* Custom styling for program output blocks */
.program-output {
background-color: #f8f9fa;
padding: 0; /* No padding at all */
margin: 0; /* No margins at all */
border-radius: 0; /* No rounded corners */
font-family: 'Courier New', monospace;
font-size: 14px;
line-height: 1.5;
width: 100%;
max-width: 100%;
}
.program-output pre {
margin: 0;
padding: 0;
background: transparent !important;
border: none !important;
color: #2c3e50;
width: 100%;
}
.program-output .highlight {
background: transparent !important;
margin: 0;
width: 100%;
}
/* Alternative lighter style */
.output-block {
background-color: #fafbfc;
border: 1px solid #e1e4e8;
padding: 10px 14px;
margin: 10px 0;
border-radius: 3px;
font-family: 'SF Mono', 'Consolas', monospace;
font-size: 13px;
color: #24292e;
}
/* Console-like output style */
.console-output {
background-color: #1e1e1e;
border-left: 3px solid #76b900;
padding: 14px 18px;
margin: 12px 0;
border-radius: 5px;
font-family: 'Fira Code', 'Consolas', monospace;
font-size: 13px;
color: #d4d4d4;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.console-output pre {
margin: 0;
color: #d4d4d4;
background: transparent !important;
}
/* Custom styling for sphinx-tabs */
.sphinx-tabs {
margin-bottom: 1rem;
}
.sphinx-tabs-tab {
background-color: #f4f4f4;
border: 1px solid #ccc;
border-bottom: none;
padding: 0.5rem 1rem;
margin-right: 0.5rem;
cursor: pointer;
font-weight: 500;
transition: background-color 0.2s;
}
.sphinx-tabs-tab:hover {
background-color: #e0e0e0;
}
.sphinx-tabs-tab[aria-selected="true"] {
background-color: #76b900; /* NVIDIA green */
color: white;
border-color: #76b900;
margin-right: 0.5rem;
}
.sphinx-tabs-panel {
border: 1px solid #ccc;
padding: 1rem;
background-color: #f9f9f9;
}
/* Dark mode support for RTD theme */
.rst-content .sphinx-tabs-tab {
color: #333;
}
.rst-content .sphinx-tabs-tab[aria-selected="true"] {
color: white;
}
......@@ -58,6 +58,7 @@ extensions = [
"nbsphinx",
"breathe",
"autoapi.extension",
"sphinx_tabs.tabs",
]
templates_path = ["_templates"]
......@@ -83,6 +84,8 @@ html_show_sphinx = False
html_css_files = [
"css/nvidia_font.css",
"css/nvidia_footer.css",
"css/rtabs.css",
"css/output-style.css",
]
html_theme_options = {
......
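For context, the newly enabled `sphinx_tabs.tabs` extension (styled by `rtabs.css` above) renders the framework tab switcher used in the Getting Started pages. A minimal usage sketch in RST, assuming the standard sphinx-tabs directive syntax rather than anything taken from this diff:

.. tabs::

   .. tab:: PyTorch

      PyTorch-specific instructions go here.

   .. tab:: JAX

      JAX-specific instructions go here.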
......@@ -13,7 +13,7 @@
"id": "6dcbf25a",
"metadata": {},
"source": [
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
"This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
......
{
"cells": [
{
"cell_type": "markdown",
"id": "da9fd6a8",
"metadata": {},
"source": [
"# Getting Started\n",
"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your PyTorch code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"## Let's build a Transformer layer!\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular PyTorch modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating a GPT encoder layer using plain PyTorch. Figure 1 shows the overall structure.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"transformer_layer.png\" width=\"20%\">\n",
"<figcaption> Figure 1: Structure of a GPT encoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `torch.nn.LayerNorm`\n",
"- `QKV Projection`: `torch.nn.Linear` (conceptually three `Linear` layers for Q, K, and V separately, but we fuse into a single `Linear` layer that is three times larger)\n",
"- `DotProductAttention`: `DotProductAttention` from [quickstart_utils.py](quickstart_utils.py)\n",
"- `Projection`: `torch.nn.Linear`\n",
"- `Dropout`: `torch.nn.Dropout`\n",
"- `MLP`: `BasicMLP` from [quickstart_utils.py](quickstart_utils.py)\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_utils.py](quickstart_utils.py). Putting it all together:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2be43d64",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import quickstart_utils as utils\n",
"\n",
"class BasicTransformerLayer(torch.nn.Module):\n",
" def __init__(\n",
" self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.mlp = utils.BasicMLP(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" ) \n",
" \n",
" def forward(\n",
" self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor\n",
" ) -> torch.Tensor:\n",
" res = x\n",
" x = self.ln1(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = self.qkv_projection(x)\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln2(x)\n",
" x = self.mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "40724d1d",
"metadata": {},
"source": [
"That's it! We now have a simple Transformer layer. We can test it:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a786f0ea",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = torch.float16\n",
"\n",
"# Synthetic data\n",
"x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n",
"dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ffdbfb7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BasicTransformerLayer(\n",
" (ln1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
" (qkv_projection): Linear(in_features=4096, out_features=12288, bias=True)\n",
" (attention): DotProductAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (projection): Linear(in_features=4096, out_features=4096, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): BasicMLP(\n",
" (linear1): Linear(in_features=4096, out_features=16384, bias=True)\n",
" (linear2): Linear(in_features=16384, out_features=4096, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"basic_transformer = BasicTransformerLayer(\n",
" hidden_size,\n",
" ffn_hidden_size,\n",
" num_attention_heads,\n",
")\n",
"basic_transformer.to(dtype=dtype).cuda()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0162ad40",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = basic_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "65ae6dd6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.0663916015625 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" basic_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "43717e36",
"metadata": {},
"source": [
"## Meet Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.\n",
"\n",
"</div>\n",
"\n",
"Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "004d3c92",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.pytorch as te"
]
},
{
"cell_type": "markdown",
"id": "1931f911",
"metadata": {},
"source": [
"TE provides a set of PyTorch modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `torch.nn.Linear` and `torch.nn.LayerNorm`. Let's modify `BasicTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1f44db50",
"metadata": {},
"outputs": [],
"source": [
"class BasicTEMLP(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int) -> None:\n",
" super().__init__()\n",
" self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)\n",
" self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear1(x)\n",
" x = torch.nn.functional.gelu(x, approximate='tanh')\n",
" x = self.linear2(x)\n",
" return x \n",
" \n",
"class BasicTETransformerLayer(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n",
" self.mlp = BasicTEMLP(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" )\n",
" \n",
" def forward(self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor):\n",
" res = x\n",
" x = self.ln1(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = self.qkv_projection(x)\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln2(x)\n",
" x = self.mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "916531e8",
"metadata": {},
"outputs": [],
"source": [
"basic_te_transformer = BasicTETransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
")\n",
"basic_te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_basic_te_model(basic_te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3643fa54",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = basic_te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "10b92894",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.1413232421875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" basic_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3f990226",
"metadata": {},
"source": [
"## Fused TE Modules\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormLinear`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormLinear` and `LayerNormMLP`:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c55eae1f",
"metadata": {},
"outputs": [],
"source": [
"class FusedTETransformerLayer(torch.nn.Module):\n",
" def __init__(self,\n",
" hidden_size: int,\n",
" ffn_hidden_size: int,\n",
" num_attention_heads: int,\n",
" layernorm_eps: int = 1e-5,\n",
" attention_dropout: float = 0.1,\n",
" hidden_dropout: float = 0.1):\n",
" super().__init__()\n",
" self.num_attention_heads = num_attention_heads\n",
" self.kv_channels = hidden_size // num_attention_heads\n",
" self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)\n",
" self.attention = utils.DotProductAttention(\n",
" num_attention_heads=num_attention_heads,\n",
" kv_channels=self.kv_channels,\n",
" attention_dropout=attention_dropout,\n",
" )\n",
" self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n",
" self.dropout = torch.nn.Dropout(hidden_dropout)\n",
" self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)\n",
" \n",
" \n",
" def forward(self, \n",
" x: torch.Tensor, \n",
" attention_mask: torch.Tensor):\n",
" res = x\n",
" qkv = self.ln_qkv(x)\n",
" \n",
" # Split qkv into query, key and value\n",
" qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n",
" \n",
" x = self.attention(q, k, v, attention_mask)\n",
" x = self.projection(x)\n",
" x = self.dropout(x)\n",
" x = res + x\n",
" res = x\n",
" x = self.ln_mlp(x)\n",
" \n",
" return x + res"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "85949421",
"metadata": {},
"outputs": [],
"source": [
"fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"fused_te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_fused_te_model(fused_te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2c263e71",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = fused_te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "24e101bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 43.1981201171875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" fused_te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "33f13c26",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ec8c3685",
"metadata": {},
"outputs": [],
"source": [
"te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e48cd590",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(1234)\n",
"y = te_transformer(x, attention_mask=None)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "3ec3707d-e63f-4899-8308-b11c55b5caa4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 39.99169921875 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
")"
]
},
{
"cell_type": "markdown",
"id": "4034c3eb-8958-49f2-85f6-30c94977d884",
"metadata": {},
"source": [
"## Enabling FP8\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "31256aa7-3d5e-425c-91ab-502b1326a748",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"\n",
"te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n",
"te_transformer.to(dtype=dtype).cuda()\n",
"utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)\n",
"\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
"torch.manual_seed(1234)\n",
"with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
" y = te_transformer(x, attention_mask=None)"
]
},
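{
"cell_type": "markdown",
"id": "fp8-backward-note",
"metadata": {},
"source": [
"To make the note above concrete, here is a minimal sketch of a single training step (not part of the original example): the forward pass runs inside `autocast`, and the backward pass runs only after the context has exited. The loss below is a placeholder chosen purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fp8-backward-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of one FP8 training step: forward inside autocast, backward outside.\n",
"# Uses te_transformer and fp8_recipe from the previous cell; the loss is a\n",
"# placeholder for illustration only.\n",
"with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
"    out = te_transformer(x, attention_mask=None)  # forward pass in FP8\n",
"loss = out.float().sum()  # placeholder loss\n",
"loss.backward()  # backward pass outside the autocast context"
]
},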
{
"cell_type": "code",
"execution_count": 19,
"id": "793ebd2d-b84b-47bc-811a-7991df8500aa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 28.61394775390625 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" te_transformer,\n",
" x,\n",
" dy,\n",
" forward_kwargs = { \"attention_mask\": None },\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "markdown",
"id": "962d87bb",
"metadata": {},
"source": [
"\n",
"\n",
"# Getting Started\n",
"\n",
"## Overview\n",
"\n",
"Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
"\n",
"This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
"We recommend you to try understanding the basics of JAX first, using these resources:\n",
"\n",
"- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
"- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n",
"- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n",
"- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
"\n",
"## Let's build a Transformer decoder layer!\n",
"<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
"\n",
"</div>\n",
"\n",
"Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"transformer_layer.png\" width=\"20%\">\n",
"<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
"</figure>\n",
"\n",
"We construct the components as follows:\n",
"\n",
"- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
"- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
"- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n",
"- `Projection`: `nn.Dense` (Flax)\n",
"- `Dropout`: `nn.Dropout` (Flax)\n",
"- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
"\n",
"Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5284a38",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import quickstart_jax_utils as utils\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
"source": [
"class FlaxMLP(nn.Module):\n",
" \"\"\"Feed-forward network in Transformer layer\n",
" Built with plain Flax modules.\n",
" \"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
" x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n",
" x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
" return x\n",
"\n",
"class FlaxTransformerLayer(nn.Module):\n",
" \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
" hidden_size: int\n",
" ffn_hidden_size: int\n",
" num_attention_heads: int\n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
" \n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray, \n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # Fused QKV projection\n",
" qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" # which is the correct format for dot_product_attention\n",
" \n",
" # Apply dot product attention\n",
" # Note: dot_product_attention expects mask to be broadcastable to \n",
" # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n",
" \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
"\n",
" # Output projection\n",
" x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
" \n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
" res = x\n",
" x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
" \n",
" # MLP\n",
" mlp = FlaxMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size,\n",
" )\n",
" x = mlp(x)\n",
" \n",
" return x + res\n"
]
},
{
"cell_type": "markdown",
"id": "fbc3510b",
"metadata": {},
"source": [
"## Testing Performance\n",
"\n",
"Now let's test the performance of our FlaxTransformerLayer:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},
"outputs": [],
"source": [
"# Layer configuration\n",
"hidden_size = 4096\n",
"sequence_length = 2048\n",
"batch_size = 4\n",
"ffn_hidden_size = 16384\n",
"num_attention_heads = 32\n",
"dtype = jnp.bfloat16\n",
"\n",
"# Synthetic data\n",
"key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
"x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
"dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e44ed26d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
]
}
],
"source": [
"# Initialize the FlaxTransformerLayer\n",
"flax_transformer = FlaxTransformerLayer(\n",
" hidden_size=hidden_size,\n",
" ffn_hidden_size=ffn_hidden_size,\n",
" num_attention_heads=num_attention_heads,\n",
")\n",
"\n",
"# Initialize parameters\n",
"params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
"print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de91af7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input shape: (4, 2048, 4096)\n",
"Output shape: (4, 2048, 4096)\n",
"Output dtype: float32\n",
"Forward pass completed successfully!\n"
]
}
],
"source": [
"# Example usage of forward pass\n",
"y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
"print(f\"Input shape: {x.shape}\")\n",
"print(f\"Output shape: {y.shape}\")\n",
"print(f\"Output dtype: {y.dtype}\")\n",
"print(\"Forward pass completed successfully!\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "037bc8d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 19.258604049682617 ms\n"
]
}
],
"source": [
"import importlib\n",
"import quickstart_jax_utils\n",
"importlib.reload(quickstart_jax_utils)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ccb16f31",
"metadata": {},
"source": [
"## Meet Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
"\n",
"</div>\n",
"\n",
"As a reminder, the FlaxTransformerLayer above used:\n",
"\n",
"- `nn.LayerNorm`: Flax LayerNorm\n",
"- `nn.Dense`: Flax Dense layer for QKV projection \n",
"- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n",
"- `nn.Dense`: Flax Dense layer for projection\n",
"- `nn.Dropout`: Flax Dropout\n",
"- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
"\n",
"Below we show how to use Transformer Engine Flax modules for better performance:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bed20d6b",
"metadata": {},
"outputs": [],
"source": [
"import transformer_engine.jax as te\n",
"import transformer_engine.jax.flax as te_flax"
]
},
{
"cell_type": "markdown",
"id": "f28cb444",
"metadata": {},
"source": [
"TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56105579",
"metadata": {},
"outputs": [],
"source": [
"class TEUnfusedMLP(nn.Module):\n",
" hidden_size : int\n",
" ffn_hidden_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n",
" x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n",
" x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n",
" x = te.activation.activation(x, activation_type=('gelu',))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n",
" return x\n",
"\n",
"class TEUnfusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1 \n",
" use_te_attention: bool = True # True for TE attention, False for Flax attention\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray: \n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # Fused QKV projection\n",
" qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention - either TE or Flax implementation\n",
" if self.use_te_attention:\n",
" # Use TE's DotProductAttention\n",
" attention = te_flax.DotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, # No GQA\n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" )\n",
" x = attention(\n",
" q, k, v,\n",
" # Causal mask does not need an explicit instatiated mask as specialized kernels exist to handle it\n",
" sequence_descriptor=None, \n",
" deterministic=deterministic\n",
" )\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
" else:\n",
" # Use Flax's MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
"\n",
" # Create causal mask if not provided\n",
" if attention_mask is None:\n",
" attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
" \n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" dropout_rate=self.attention_dropout,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
"\n",
" # MLP\n",
" mlp = TEUnfusedMLP(\n",
" hidden_size=self.hidden_size,\n",
" ffn_hidden_size=self.ffn_hidden_size\n",
" )\n",
"\n",
" x = mlp(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "a76911ac",
"metadata": {},
"source": [
"Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4b67511f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.003193855285645 ms\n"
]
}
],
"source": [
"te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
" use_te_attention=False\n",
")\n",
"\n",
"te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "0b230058",
"metadata": {},
"source": [
"Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5146cd99",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 8.897695541381836 ms\n"
]
}
],
"source": [
"te_unfused_transformer = TEUnfusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads,\n",
")\n",
"\n",
"te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c9a101d3",
"metadata": {},
"source": [
"## Enabling Quantization (FP8 or FP4)\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We configure a TE module to perform compute in FP8.\n",
"\n",
"</div>\n",
"\n",
"Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
"\n",
"<div class=\"alert alert-warning\">\n",
"\n",
"<b>Important: FP8 Metadata Initialization</b>\n",
"\n",
"When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
"\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c2eee376",
"metadata": {},
"outputs": [],
"source": [
"from transformer_engine.common.recipe import Format, DelayedScaling\n",
"fp8_format = Format.HYBRID\n",
"fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
]
},
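{
"cell_type": "markdown",
"id": "fp8-init-gotcha-note",
"metadata": {},
"source": [
"As a hedged sketch of the failure mode described in the warning above (kept commented out so the notebook still runs end to end):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fp8-init-gotcha-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Initializing outside autocast and then applying under FP8 would raise a\n",
"# ScopeCollectionNotFound error, because the fp8_metas collection is never\n",
"# created. Illustration only; the next cell shows the correct pattern.\n",
"#\n",
"# bad_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"# with te.autocast(enabled=True, recipe=fp8_recipe):\n",
"#     y = te_unfused_transformer.apply(bad_params, x, attention_mask=None, deterministic=True)  # fails"
]
},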
{
"cell_type": "code",
"execution_count": 12,
"id": "de96827c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 5.651178359985352 ms\n"
]
}
],
"source": [
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
"\n",
" # Example usage of forward \n",
" y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_unfused_transformer.apply,\n",
" variables=te_unfused_params, # Ensure the correct `params` is passed\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3801b201",
"metadata": {},
"source": [
"\n",
"## Fused TE Modules\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Summary</b>\n",
" \n",
"We optimize the example Transformer layer with TE modules for fused operations.\n",
"\n",
"</div>\n",
"\n",
"The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
"\n",
"Transformer Engine therefore provides coarser modules that span multiple layers:\n",
"\n",
"* `LayerNormDenseGeneral`\n",
"* `LayerNormMLP`\n",
"* `TransformerLayer`\n",
"\n",
"To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
"\n",
"Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "11203785",
"metadata": {},
"outputs": [],
"source": [
"class TEFusedTransformerLayer(nn.Module):\n",
" hidden_size: int\n",
" ffn_hidden_size: int \n",
" num_attention_heads: int \n",
" layernorm_eps: float = 1e-5\n",
" attention_dropout: float = 0.1\n",
"\n",
" def setup(self):\n",
" self.kv_channels = self.hidden_size // self.num_attention_heads\n",
"\n",
" @nn.compact\n",
" def __call__(\n",
" self, \n",
" x: jnp.ndarray,\n",
" attention_mask: Optional[jnp.ndarray] = None,\n",
" deterministic: bool = False\n",
" ) -> jnp.ndarray:\n",
" res = x\n",
"\n",
" # Fused QKV projection\n",
" qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n",
" epsilon=self.layernorm_eps, \n",
" use_bias=True, \n",
" return_layernorm_output=False)(x)\n",
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
"\n",
" # Attention using TE's DotProductAttention\n",
" attention = te_flax.DotProductAttention(\n",
" head_dim=self.kv_channels,\n",
" num_attention_heads=self.num_attention_heads,\n",
" num_gqa_groups=self.num_attention_heads, \n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
" x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
" x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
" x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
"\n",
" x = res + x\n",
"\n",
" # Second residual connection\n",
" res = x\n",
" x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n",
" epsilon=self.layernorm_eps,\n",
" use_bias=True,\n",
" activations=('gelu',),\n",
" intermediate_dropout_rate=0.0,\n",
" return_layernorm_output=False\n",
" )(x, deterministic=deterministic)\n",
"\n",
" return x + res"
]
},
{
"cell_type": "markdown",
"id": "334cff59",
"metadata": {},
"source": [
"Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6b0c705e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 5.493879318237305 ms\n"
]
}
],
"source": [
"te_fused_transformer = TEFusedTransformerLayer(\n",
" hidden_size, \n",
" ffn_hidden_size, \n",
" num_attention_heads\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
" # Example usage of forward \n",
" y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n",
"\n",
"utils.speedometer(\n",
" model_apply_fn=te_fused_transformer.apply,\n",
" variables=te_fused_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a45c12c8",
"metadata": {},
"source": [
"Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b2aaa8ef",
"metadata": {},
"outputs": [],
"source": [
"te_transformer = te_flax.TransformerLayer(\n",
" hidden_size=hidden_size,\n",
" mlp_hidden_size=ffn_hidden_size, \n",
" num_attention_heads=num_attention_heads,\n",
" mlp_activations=(\"gelu\",),\n",
" self_attn_mask_type='causal',\n",
" layernorm_epsilon=1e-5,\n",
" use_bias=True,\n",
" intermediate_dropout=0.0,\n",
" enable_relative_embedding=False,\n",
" self_attn_bias_type='no_bias',\n",
" hidden_dropout=0.0,\n",
")\n",
"\n",
"with te.autocast(enabled=True, recipe=fp8_recipe):\n",
" te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
" y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b9cdbf22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 5.334172248840332 ms\n"
]
}
],
"source": [
"utils.speedometer(\n",
" model_apply_fn=te_transformer.apply,\n",
" model_init_fn=te_transformer.init,\n",
" variables=te_transformer_params,\n",
" input=x,\n",
" output_grad=dy,\n",
" forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
" autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n",
" rngs={\"dropout\": dropout_key},\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -5,13 +5,9 @@
import jax
import jax.numpy as jnp
import time
import math
from typing import Callable, Any, Dict, Optional, Tuple
from flax import linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
def speedometer(
......
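For readers following along without the repository, here is a hypothetical, simplified sketch of what a speedometer-style timing helper could look like in JAX. The real speedometer in quickstart_jax_utils.py additionally runs the backward pass and supports output_grad, autocast_kwargs, and rngs; this forward-only version conveys only the measurement pattern:

import time

import jax


def simple_speedometer(apply_fn, variables, x, forward_kwargs=None,
                       warmup_iters=10, timing_iters=50):
    """Hypothetical forward-only timing helper (illustration only)."""
    forward_kwargs = forward_kwargs or {}
    # JIT-compile the forward pass once; kwargs are captured in the closure.
    fn = jax.jit(lambda v, inp: apply_fn(v, inp, **forward_kwargs))
    for _ in range(warmup_iters):  # trigger compilation and warm up
        fn(variables, x).block_until_ready()
    start = time.time()
    for _ in range(timing_iters):
        fn(variables, x).block_until_ready()
    mean_ms = (time.time() - start) / timing_iters * 1000
    print(f"Mean time: {mean_ms} ms")
    return mean_ms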
......@@ -264,7 +264,7 @@
"id": "5e9310c9",
"metadata": {},
"source": [
"# Transformer Engine"
"## Transformer Engine"
]
},
{
......
..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Getting Started
===============

Choose your framework to get started with Transformer Engine:

.. toctree::
   :maxdepth: 1

   PyTorch <examples/quickstart.ipynb>
   JAX <examples/quickstart_jax.ipynb>
# BENCHMARK_BASELINE_OUTPUT_START
Baseline Flax:
Mean time: 86.580 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 42.252 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.054 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 22.638 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 23.703 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 22.812 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_jax_summary.csv
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - JAX Example
======================================================
This example shows how to build a Transformer decoder layer using JAX/Flax
and how to optimize it with Transformer Engine.
"""
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Optional
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_jax import speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.bfloat16
# Create synthetic data
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)
mesh_resource = MeshResource()
# =============================================================================
# Baseline: Pure Flax Implementation
# =============================================================================
# BASELINE_MLP_START
class FlaxMLP(nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain Flax modules.
"""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class FlaxTransformerLayer(nn.Module):
"""Basic Transformer layer using plain Flax modules."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
# Fused QKV projection
qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = baseline.init(key, x, deterministic=False)
print("Baseline Flax:")
time_baseline = speedometer(
baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline"
)
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x)
x = x.reshape(*x.shape[:-1], 1, x.shape[-1])
x = te.activation.activation(x, activation_type=("gelu",))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(nn.Module):
"""Transformer layer using basic TE modules (without TE attention)."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = te_unfused.init(key, x, deterministic=False)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x)
qkv = te_flax.DenseGeneral(
features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16
)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=False, mesh_resource=mesh_resource):
params = te_unfused_attn.init(key, x, deterministic=False)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource},
label="te_unfused_attn",
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_unfused_fp8.init(key, x, deterministic=False)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(nn.Module):
"""Transformer layer using fused TE modules for better performance."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
# Fused LayerNorm + QKV projection
qkv, _ = te_flax.LayerNormDenseGeneral(
features=3 * self.hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
return_layernorm_output=False,
)(x)
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels)
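        # Split the packed projection into q, k, v, each [batch, seq, heads, head_dim].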
q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
qkv_layout="bshd_bshd_bshd",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
# Fused LayerNorm + MLP
x, _ = te_flax.LayerNormMLP(
intermediate_dim=self.ffn_hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
activations=("gelu",),
intermediate_dropout_rate=0.0,
return_layernorm_output=False,
)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_fused_fp8.init(key, x, deterministic=False)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = te_flax.TransformerLayer(
hidden_size=hidden_size,
mlp_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
mlp_activations=("gelu",),
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
use_bias=True,
attention_dropout=0.0,
intermediate_dropout=0.0,
hidden_dropout=0.0,
enable_relative_embedding=False,
self_attn_bias_type="no_bias",
dtype=jnp.bfloat16,
transpose_batch_sequence=False,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_transformer_layer.init(key, x, deterministic=False)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_jax_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_jax_summary.csv")
Implementation,Time (ms),Speedup
Baseline Flax,86.58,1.00x
TE Unfused,42.25,2.05x
TE Unfused + TE Attention,35.05,2.47x
TE Unfused + TE Attention + FP8,22.64,3.82x
TE Fused + TE Attention + FP8,23.70,3.65x
TE TransformerLayer + FP8,22.81,3.80x
pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926
dispatch key: ADInplaceOrView
previous kernel: no debug info
new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
self.m.impl(
# BENCHMARK_BASELINE_OUTPUT_START
Baseline PyTorch:
Mean time: 48.280 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 49.342 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.709 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 23.406 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 22.964 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 21.670 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_pytorch_summary.csv
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - PyTorch Example
==========================================================
This example shows how to build a Transformer layer using PyTorch
and how to optimize it with Transformer Engine.
"""
from typing import Optional
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_pytorch import DotProductAttention, speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16
# Create synthetic data
x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
# =============================================================================
# Baseline: Pure PyTorch Implementation
# =============================================================================
# BASELINE_MLP_START
class PyTorchMLP(torch.nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain PyTorch modules.
"""
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class PyTorchTransformerLayer(torch.nn.Module):
"""Basic Transformer layer using plain PyTorch modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = torch.nn.Dropout(hidden_dropout)
self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = (
PyTorchTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("Baseline PyTorch:")
time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline")
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(torch.nn.Module):
"""MLP using TE modules."""
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(torch.nn.Module):
"""Transformer layer using basic TE modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = (
TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(torch.nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn"
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(torch.nn.Module):
"""Transformer layer using fused TE modules for better performance."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
# Fused LayerNorm + QKV projection
self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
# Fused LayerNorm + MLP
self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
# Fused LayerNorm + QKV projection
qkv = self.ln_qkv(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Fused LayerNorm + MLP
res = x
x = self.ln_mlp(x)
x = self.dropout2(x)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = (
TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = (
te.TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
bias=True,
hidden_dropout=0.0,
attention_dropout=0.0,
)
.to(dtype=dtype)
.cuda()
)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_pytorch_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_pytorch_summary.csv")
Implementation,Time (ms),Speedup
Baseline PyTorch,48.28,1.00x
TE Unfused,49.34,0.98x
TE Unfused + TE Attention,35.71,1.35x
TE Unfused + TE Attention + FP8,23.41,2.06x
TE Fused + TE Attention + FP8,22.96,2.10x
TE TransformerLayer + FP8,21.67,2.23x
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for Getting Started with Transformer Engine - JAX
====================================================================
Helper classes and functions for the getting started examples.
"""
import time
from typing import Callable, Any, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
import transformer_engine.jax as te
from transformer_engine.jax.sharding import MeshResource
def speedometer(
apply_fn: Callable,
params: Any,
x: jnp.ndarray,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 100,
warmup_iters: int = 10,
label: str = "benchmark",
) -> float:
"""Measure average forward + backward pass time for a JAX module.
Args:
apply_fn: JIT-compiled apply function
params: Model parameters
x: Input tensor
forward_kwargs: Additional kwargs for forward pass
autocast_kwargs: Kwargs for te.autocast context
timing_iters: Number of timing iterations
warmup_iters: Number of warmup iterations
label: Optional label for logging
Returns:
Average time per iteration in milliseconds
"""
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
else:
autocast_kwargs = dict(autocast_kwargs)
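        # Fall back to a single-device MeshResource when the caller doesn't set one.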
autocast_kwargs.setdefault("mesh_resource", MeshResource())
def loss_fn(params, x):
y = apply_fn(params, x, **forward_kwargs)
return jnp.sum(y)
# JIT compile within autocast context
with te.autocast(**autocast_kwargs):
grad_fn = jax.jit(jax.value_and_grad(loss_fn))
# Warmup runs
for _ in range(warmup_iters):
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
# Timing runs
times = []
for _ in range(timing_iters):
start = time.perf_counter()
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
times.append(time.perf_counter() - start)
avg_time = sum(times) / len(times) * 1000
print(f"Mean time: {avg_time:.3f} ms")
return avg_time
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for Getting Started with Transformer Engine - PyTorch
========================================================================
Helper classes and functions for the getting started examples.
"""
import math
from typing import Optional
import torch
import transformer_engine.pytorch as te
def speedometer(
module: torch.nn.Module,
x: torch.Tensor,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 100,
warmup_iters: int = 10,
label: str = "benchmark",
) -> float:
"""Measure average forward + backward pass time for a PyTorch module.
Args:
module: PyTorch module to benchmark
x: Input tensor
forward_kwargs: Additional kwargs for forward pass
autocast_kwargs: Kwargs for te.autocast context
timing_iters: Number of timing iterations
warmup_iters: Number of warmup iterations
label: Optional label for logging
Returns:
Average time per iteration in milliseconds
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
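    # CUDA events record timestamps on the GPU stream, so no per-iteration host sync is needed.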
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
for _ in range(warmup_iters):
with te.autocast(**autocast_kwargs):
y = module(x, **forward_kwargs)
loss = y.sum()
loss.backward()
torch.cuda.synchronize()
# Timing runs
start.record()
for _ in range(timing_iters):
with te.autocast(**autocast_kwargs):
y = module(x, **forward_kwargs)
loss = y.sum()
loss.backward()
end.record()
torch.cuda.synchronize()
avg_time = start.elapsed_time(end) / timing_iters
print(f"Mean time: {avg_time:.3f} ms")
return avg_time
class DotProductAttention(torch.nn.Module):
"""Attention operation in Transformer layer.
Built with plain PyTorch modules.
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
) -> None:
super().__init__()
self.projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = kv_channels
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.dropout = torch.nn.Dropout(attention_dropout)
def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
if mask is not None:
inp.masked_fill_(mask, -10000.0)
return torch.nn.Softmax(dim=-1)(inp)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
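        # q, k, v arrive as [seq, batch, heads, head_dim]; batch and heads are flattened for bmm.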
b = query.size(1)
np = query.size(2)
sq = query.size(0)
sk = key.size(0)
hn = value.size(3)
query = query.view(sq, b * np, -1)
key = key.view(sk, b * np, -1)
bmm1 = (
torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
)
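        # Scaled QK^T scores, viewed as [batch, heads, seq_q, seq_k] before masking.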
attention_scores = bmm1.view(b, np, sq, sk)
attention_probs = self.masked_softmax(attention_scores, attention_mask)
attention_probs = self.dropout(attention_probs)
value = value.view(sk, b * np, -1)
attention_probs = attention_probs.view(b * np, sq, -1)
context = torch.bmm(attention_probs, value.transpose(0, 1))
context = context.view(b, np, sq, hn)
context = context.permute(2, 0, 1, 3).contiguous()
context = context.view(sq, b, self.projection_size)
return context
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Getting Started
===============
Overview
--------
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs,
providing better performance with lower memory utilization in both training and inference.
It supports 8-bit floating point (FP8) precision on Hopper and Ada GPUs, as well as FP8 and
4-bit floating point (NVFP4) precision on Blackwell GPUs.
TE implements a collection of highly optimized building blocks for popular Transformer
architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly
with your deep learning code.
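For a flavor of the API, here is a minimal PyTorch sketch (the JAX API is analogous). It assumes
an FP8-capable GPU and uses the same ``DelayedScaling`` recipe that appears later in this guide:

.. code-block:: python

   import torch
   import transformer_engine.pytorch as te
   from transformer_engine.common.recipe import DelayedScaling, Format

   # TE modules are drop-in replacements for their framework counterparts.
   linear = te.Linear(1024, 1024, bias=True).cuda()
   x = torch.randn(32, 1024, device="cuda")

   # Low-precision execution is opted into with a context manager.
   with te.autocast(enabled=True, recipe=DelayedScaling(fp8_format=Format.HYBRID)):
       y = linear(x)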
Currently, two frameworks are supported: PyTorch and JAX.
.. tabs::
.. tab:: PyTorch
Basic knowledge of PyTorch is recommended:
- `PyTorch Tutorials <https://pytorch.org/tutorials/>`_
- `PyTorch Documentation <https://pytorch.org/docs/stable/index.html>`_
.. tab:: JAX
We recommend understanding the basics of JAX first:
- `Thinking in JAX <https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html>`_
- `JAX 101 <https://docs.jax.dev/en/latest/jax-101.html>`_
- `Key concepts in JAX <https://docs.jax.dev/en/latest/key-concepts.html>`_
- `Flax 101 <https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html>`_
Baseline: Pure Framework Implementation
---------------------------------------
Let's build a Transformer decoder layer!
We'll create a basic GPT-style layer with causal masking,
which prevents each position from attending to future positions. This will be our baseline
for later comparisons with Transformer Engine.
.. raw:: html
:file: transformer_layer.svg
.. raw:: html
<p style="text-align: center; font-style: italic; color: #666;">Structure of a GPT decoder layer</p>
We construct the components as follows:
.. tabs::
.. tab:: PyTorch
* **LayerNorm**: ``torch.nn.LayerNorm``
* **QKV Projection**: ``torch.nn.Linear`` (Q, K, and V fused into a single layer with 3x the output width)
* **DotProductAttention**: Custom implementation using ``torch.bmm``
* **Projection**: ``torch.nn.Linear``
* **Dropout**: ``torch.nn.Dropout``
* **MLP**: Two ``torch.nn.Linear`` layers with ``torch.nn.functional.gelu`` activation
.. tab:: JAX
* **LayerNorm**: ``nn.LayerNorm``
* **QKV Projection**: ``nn.Dense`` (Q, K, and V fused into a single layer with 3x the output width)
* **DotProductAttention**: ``nn.dot_product_attention``
* **Projection**: ``nn.Dense``
* **Dropout**: ``nn.Dropout``
* **MLP**: Two ``nn.Dense`` layers with ``nn.gelu`` activation
Putting it all together:
.. tabs::
.. tab:: PyTorch
First, define the MLP block:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BASELINE_MLP_START
:end-before: # BASELINE_MLP_END
Then assemble the components into a GPT decoder layer:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BASELINE_LAYER_START
:end-before: # BASELINE_LAYER_END
Benchmark the baseline implementation:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_BASELINE_START
:end-before: # BENCHMARK_BASELINE_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_BASELINE_OUTPUT_START
:end-before: # BENCHMARK_BASELINE_OUTPUT_END
.. tab:: JAX
First, define the MLP block:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BASELINE_MLP_START
:end-before: # BASELINE_MLP_END
Then assemble the components into a GPT decoder layer:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BASELINE_LAYER_START
:end-before: # BASELINE_LAYER_END
Benchmark the baseline implementation:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_BASELINE_START
:end-before: # BENCHMARK_BASELINE_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_BASELINE_OUTPUT_START
:end-before: # BENCHMARK_BASELINE_OUTPUT_END
TE Unfused: Basic TE Modules
----------------------------
Now let's replace the standard framework modules with TE equivalents.
This is the simplest way to start using Transformer Engine.
.. tabs::
.. tab:: PyTorch
Replace PyTorch modules with TE equivalents:
.. code-block:: python
import transformer_engine.pytorch as te
Mapping:
* ``torch.nn.Linear`` → ``te.Linear``
* ``torch.nn.LayerNorm`` → ``te.LayerNorm``
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_MLP_START
:end-before: # TE_UNFUSED_MLP_END
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_LAYER_START
:end-before: # TE_UNFUSED_LAYER_END
Benchmark the TE unfused implementation:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_START
:end-before: # BENCHMARK_TE_UNFUSED_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END
.. tab:: JAX
Replace Flax modules with TE equivalents:
.. code-block:: python
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
Mapping:
* ``nn.Dense`` → ``te_flax.DenseGeneral``
* ``nn.LayerNorm`` → ``te_flax.LayerNorm``
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_MLP_START
:end-before: # TE_UNFUSED_MLP_END
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_LAYER_START
:end-before: # TE_UNFUSED_LAYER_END
Benchmark the TE unfused implementation:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_START
:end-before: # BENCHMARK_TE_UNFUSED_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END
TE Unfused + TE Attention
-------------------------
Now let's also replace the attention mechanism with TE's optimized ``DotProductAttention``.
TE's attention automatically selects the best available backend — for example, FlashAttention or cuDNN fused attention — based on your hardware and input configuration,
delivering optimal performance without manual tuning.
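As a minimal sketch (PyTorch shown; the JAX module is analogous, and the parameters mirror the
layer below), the swap is a one-line module replacement, with causal masking requested
declaratively instead of through an explicit mask tensor:

.. code-block:: python

   import transformer_engine.pytorch as te

   # Drop-in replacement for the hand-written attention; backend selection
   # (e.g. FlashAttention or cuDNN fused attention) happens automatically.
   attention = te.DotProductAttention(
       num_attention_heads=32,
       kv_channels=128,  # hidden_size // num_attention_heads
       attention_dropout=0.1,
       attn_mask_type="causal",
   )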
.. tabs::
.. tab:: PyTorch
Replace the custom attention with TE's optimized implementation:
* Custom ``DotProductAttention`` → ``te.DotProductAttention``
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_ATTN_LAYER_START
:end-before: # TE_UNFUSED_ATTN_LAYER_END
Benchmark TE Unfused with TE Attention:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
.. tab:: JAX
Replace Flax's attention with TE's optimized implementation:
* ``nn.dot_product_attention`` → ``te_flax.DotProductAttention``
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_ATTN_LAYER_START
:end-before: # TE_UNFUSED_ATTN_LAYER_END
Benchmark TE Unfused with TE Attention:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
TE Unfused + TE Attention + FP8
-------------------------------
Now let's enable FP8 precision on top of the TE modules and TE Attention.
FP8 execution is switched on by wrapping the forward pass in an ``autocast`` context manager,
which provides significant speedups on supported hardware (Hopper, Ada, and Blackwell GPUs).
.. tabs::
.. tab:: PyTorch
.. code-block:: python
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
y = te_unfused(x, attention_mask=None)
.. note::
The ``autocast`` should only wrap the forward pass and must exit before
starting a backward pass.
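In practice this is the pattern used by the ``speedometer`` helper in this guide:

.. code-block:: python

   with te.autocast(enabled=True, recipe=recipe):
       y = model(x)      # forward pass runs with FP8 enabled
       loss = y.sum()
   loss.backward()       # backward pass starts after the context exits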
Benchmark TE Unfused with FP8:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_FP8_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
.. tab:: JAX
.. code-block:: python
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
params = te_unfused.init(key, x, deterministic=False)
y = te_unfused.apply(params, x, deterministic=True)
.. important::
When using FP8 in JAX, the model **must be initialized within the autocast context**
to create the ``fp8_metas`` collection.
Benchmark TE Unfused with FP8:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_FP8_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
TE Fused + TE Attention + FP8: Optimized Modules
------------------------------------------------
Fused modules combine several operations, such as LayerNorm followed by a linear projection,
into a single optimized module. While the speedups are modest on a single GPU, fused modules
scale better in multi-GPU setups. Combined with TE Attention and FP8, they deliver peak performance.
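As a rough sketch in PyTorch (using this guide's ``hidden_size`` of 4096), the separate
LayerNorm and projection collapse into a single fused module:

.. code-block:: python

   import transformer_engine.pytorch as te

   hidden_size = 4096

   # Unfused: separate LayerNorm and QKV projection modules
   ln = te.LayerNorm(hidden_size, eps=1e-5)
   qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)

   # Fused: LayerNorm + projection handled by one module
   ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=1e-5, bias=True)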
.. tabs::
.. tab:: PyTorch
Fused modules available:
* ``te.LayerNormLinear`` - fuses LayerNorm + Linear
* ``te.LayerNormMLP`` - fuses LayerNorm + MLP
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_FUSED_LAYER_START
:end-before: # TE_FUSED_LAYER_END
Benchmark TE Fused with FP8:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_FUSED_FP8_START
:end-before: # BENCHMARK_TE_FUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END
.. tab:: JAX
Fused modules available:
* ``te_flax.LayerNormDenseGeneral`` - fuses LayerNorm + Dense
* ``te_flax.LayerNormMLP`` - fuses LayerNorm + MLP
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_FUSED_LAYER_START
:end-before: # TE_FUSED_LAYER_END
Benchmark TE Fused with FP8:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_FUSED_FP8_START
:end-before: # BENCHMARK_TE_FUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END
TE TransformerLayer + FP8: Ready-to-use Module
----------------------------------------------
For the simplest integration, Transformer Engine provides a ready-to-use ``TransformerLayer``
module that includes all optimizations out of the box.
.. tabs::
.. tab:: PyTorch
Just use ``te.TransformerLayer`` - it handles everything for you:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
.. tab:: JAX
Just use ``te_flax.TransformerLayer`` - it handles everything for you:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Benchmark Summary
-----------------
The table below summarizes the performance improvements achieved with Transformer Engine
on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this
tutorial focuses on a simple single-GPU scenario, features like fused layers can provide
additional benefits in more complex setups such as multi-GPU training.
.. tabs::
.. tab:: PyTorch
.. csv-table::
:header-rows: 1
:widths: 40, 20, 20
:file: getting_started_pytorch_summary.csv
.. tab:: JAX
.. csv-table::
:header-rows: 1
:widths: 40, 20, 20
:file: getting_started_jax_summary.csv
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 700" style="display: block; margin: 0 auto; max-width: 280px;">
<defs>
<style>
.box { fill: #a8c686; stroke: #7a9a5a; stroke-width: 2; }
.circle { fill: #b8d4a0; stroke: #7a9a5a; stroke-width: 2; }
.text { font-family: Arial, sans-serif; font-size: 16px; font-weight: 500; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.arrow { stroke: #6b8fb3; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.skip { stroke: #6b8fb3; stroke-width: 2; fill: none; }
</style>
<marker id="arrowhead" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#6b8fb3"/>
</marker>
</defs>
<!-- Input arrow -->
<line x1="160" y1="5" x2="160" y2="40" class="arrow"/>
<!-- Skip connection 1 (input to first +) -->
<path d="M 160 20 L 280 20 L 280 420" class="skip"/>
<line x1="280" y1="420" x2="185" y2="420" class="arrow"/>
<!-- LayerNorm 1 -->
<rect x="60" y="40" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="62" class="text">LayerNorm</text>
<line x1="160" y1="85" x2="160" y2="110" class="arrow"/>
<!-- QKV Projection -->
<rect x="60" y="110" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="132" class="text">QKV Projection</text>
<line x1="160" y1="155" x2="160" y2="180" class="arrow"/>
<!-- Dot Product Attention -->
<rect x="60" y="180" width="200" height="55" rx="10" ry="10" class="box"/>
<text x="160" y="200" class="text">Dot Product</text>
<text x="160" y="220" class="text">Attention</text>
<line x1="160" y1="235" x2="160" y2="260" class="arrow"/>
<!-- Projection -->
<rect x="60" y="260" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="282" class="text">Projection</text>
<line x1="160" y1="305" x2="160" y2="330" class="arrow"/>
<!-- Dropout -->
<rect x="60" y="330" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="352" class="text">Dropout</text>
<line x1="160" y1="375" x2="160" y2="395" class="arrow"/>
<!-- First + circle -->
<circle cx="160" cy="420" r="25" class="circle"/>
<text x="160" y="420" class="text" font-size="24">+</text>
<line x1="160" y1="445" x2="160" y2="480" class="arrow"/>
<!-- Skip connection 2 (first + to second +) -->
<path d="M 160 455 L 280 455 L 280 640" class="skip"/>
<line x1="280" y1="640" x2="185" y2="640" class="arrow"/>
<!-- LayerNorm 2 -->
<rect x="60" y="480" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="502" class="text">LayerNorm</text>
<line x1="160" y1="525" x2="160" y2="555" class="arrow"/>
<!-- MLP -->
<rect x="60" y="555" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="577" class="text">MLP</text>
<line x1="160" y1="600" x2="160" y2="615" class="arrow"/>
<!-- Second + circle -->
<circle cx="160" cy="640" r="25" class="circle"/>
<text x="160" y="640" class="text" font-size="24">+</text>
<!-- Output arrow -->
<line x1="160" y1="665" x2="160" y2="695" class="arrow"/>
</svg>