Unverified commit a9767407, authored by Paweł Gadziński and committed by GitHub

[docs] Getting started refactor (#2534)



* docs: Add comprehensive Getting Started guide with benchmarks

- Add new Getting Started documentation with PyTorch and JAX tutorials
- Include benchmark scripts demonstrating TE performance benefits
- Add CSS styling for code output and tabs
- Replace old quickstart notebooks with improved documentation
- Add transformer layer diagram (SVG)
- Update docs configuration and workflow
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* 2026 in copyright
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c90a9214
@@ -17,7 +17,7 @@ jobs:
uses: actions/checkout@v3
- name: 'Install dependencies'
run: |
pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2
pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 sphinx-tabs==3.4.7
pip install breathe==4.35.0 sphinx-autoapi==3.3.2
sudo apt-get install -y pandoc graphviz doxygen
export GIT_SHA=$(git show-ref --hash HEAD)
......
/* Custom styling for program output blocks */
.program-output {
background-color: #f8f9fa;
padding: 0; /* No padding at all */
margin: 0; /* No margins at all */
border-radius: 0; /* No rounded corners */
font-family: 'Courier New', monospace;
font-size: 14px;
line-height: 1.5;
width: 100%;
max-width: 100%;
}
.program-output pre {
margin: 0;
padding: 0;
background: transparent !important;
border: none !important;
color: #2c3e50;
width: 100%;
}
.program-output .highlight {
background: transparent !important;
margin: 0;
width: 100%;
}
/* Alternative lighter style */
.output-block {
background-color: #fafbfc;
border: 1px solid #e1e4e8;
padding: 10px 14px;
margin: 10px 0;
border-radius: 3px;
font-family: 'SF Mono', 'Consolas', monospace;
font-size: 13px;
color: #24292e;
}
/* Console-like output style */
.console-output {
background-color: #1e1e1e;
border-left: 3px solid #76b900;
padding: 14px 18px;
margin: 12px 0;
border-radius: 5px;
font-family: 'Fira Code', 'Consolas', monospace;
font-size: 13px;
color: #d4d4d4;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.console-output pre {
margin: 0;
color: #d4d4d4;
background: transparent !important;
}
/* Custom styling for sphinx-tabs */
.sphinx-tabs {
margin-bottom: 1rem;
}
.sphinx-tabs-tab {
background-color: #f4f4f4;
border: 1px solid #ccc;
border-bottom: none;
padding: 0.5rem 1rem;
margin-right: 0.5rem;
cursor: pointer;
font-weight: 500;
transition: background-color 0.2s;
}
.sphinx-tabs-tab:hover {
background-color: #e0e0e0;
}
.sphinx-tabs-tab[aria-selected="true"] {
background-color: #76b900; /* NVIDIA green */
color: white;
border-color: #76b900;
margin-right: 0.5rem;
}
.sphinx-tabs-panel {
border: 1px solid #ccc;
padding: 1rem;
background-color: #f9f9f9;
}
/* Dark mode support for RTD theme */
.rst-content .sphinx-tabs-tab {
color: #333;
}
.rst-content .sphinx-tabs-tab[aria-selected="true"] {
color: white;
}
@@ -58,6 +58,7 @@ extensions = [
"nbsphinx",
"breathe",
"autoapi.extension",
"sphinx_tabs.tabs",
]
templates_path = ["_templates"]
@@ -83,6 +84,8 @@ html_show_sphinx = False
html_css_files = [
"css/nvidia_font.css",
"css/nvidia_footer.css",
"css/rtabs.css",
"css/output-style.css",
]
html_theme_options = {
......
@@ -13,7 +13,7 @@
"id": "6dcbf25a",
"metadata": {},
"source": [
"This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
"This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). "
]
},
{
......
@@ -5,13 +5,9 @@
import jax
import jax.numpy as jnp
import time
import math
from typing import Callable, Any, Dict, Optional, Tuple
from flax import linen as nn
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
def speedometer(
......
@@ -264,7 +264,7 @@
"id": "5e9310c9",
"metadata": {},
"source": [
"# Transformer Engine"
"## Transformer Engine"
]
},
{
......
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Getting Started
===============
Choose your framework to get started with Transformer Engine:
.. toctree::
:maxdepth: 1
PyTorch <examples/quickstart.ipynb>
JAX <examples/quickstart_jax.ipynb>
pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
# BENCHMARK_BASELINE_OUTPUT_START
Baseline Flax:
Mean time: 86.580 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 42.252 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.054 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 22.638 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 23.703 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 22.812 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_jax_summary.csv
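The `# BENCHMARK_*_OUTPUT_START` / `_OUTPUT_END` markers in the log above make each benchmark's output machine-extractable, so the RST pages can embed the measured numbers. The actual extraction tooling is not shown in this diff; a minimal sketch of one plausible way to consume the markers (the regex and function name are illustrative assumptions):

```python
import re

def extract_sections(log_text: str) -> dict:
    """Map each marker name to the text between its START and END lines."""
    pattern = re.compile(
        r"# BENCHMARK_(\w+)_OUTPUT_START\n(.*?)# BENCHMARK_\1_OUTPUT_END",
        re.DOTALL,
    )
    # findall yields (name, body) tuples thanks to the two capture groups
    return {name: body.strip() for name, body in pattern.findall(log_text)}

log = """# BENCHMARK_BASELINE_OUTPUT_START
Baseline Flax:
Mean time: 86.580 ms
# BENCHMARK_BASELINE_OUTPUT_END"""

sections = extract_sections(log)
```

The `\1` backreference ties each END marker to its own START marker, so sections cannot bleed into each other even if the log interleaves other text between blocks.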
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - JAX Example
======================================================
This example shows how to build a Transformer decoder layer using JAX/Flax
and how to optimize it with Transformer Engine.
"""
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Optional
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_jax import speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.bfloat16
# Create synthetic data
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)
mesh_resource = MeshResource()
# =============================================================================
# Baseline: Pure Flax Implementation
# =============================================================================
# BASELINE_MLP_START
class FlaxMLP(nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain Flax modules.
"""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class FlaxTransformerLayer(nn.Module):
"""Basic Transformer layer using plain Flax modules."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
# Fused QKV projection
qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = baseline.init(key, x, deterministic=False)
print("Baseline Flax:")
time_baseline = speedometer(
baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline"
)
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x)
x = x.reshape(*x.shape[:-1], 1, x.shape[-1])
x = te.activation.activation(x, activation_type=("gelu",))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(nn.Module):
"""Transformer layer using basic TE modules (without TE attention)."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = te_unfused.init(key, x, deterministic=False)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x)
qkv = te_flax.DenseGeneral(
features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16
)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=False, mesh_resource=mesh_resource):
params = te_unfused_attn.init(key, x, deterministic=False)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource},
label="te_unfused_attn",
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_unfused_fp8.init(key, x, deterministic=False)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(nn.Module):
"""Transformer layer using fused TE modules for better performance."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
# Fused LayerNorm + QKV projection
qkv, _ = te_flax.LayerNormDenseGeneral(
features=3 * self.hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
return_layernorm_output=False,
)(x)
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels)
q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
qkv_layout="bshd_bshd_bshd",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
# Fused LayerNorm + MLP
x, _ = te_flax.LayerNormMLP(
intermediate_dim=self.ffn_hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
activations=("gelu",),
intermediate_dropout_rate=0.0,
return_layernorm_output=False,
)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_fused_fp8.init(key, x, deterministic=False)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = te_flax.TransformerLayer(
hidden_size=hidden_size,
mlp_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
mlp_activations=("gelu",),
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
use_bias=True,
attention_dropout=0.0,
intermediate_dropout=0.0,
hidden_dropout=0.0,
enable_relative_embedding=False,
self_attn_bias_type="no_bias",
dtype=jnp.bfloat16,
transpose_batch_sequence=False,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_transformer_layer.init(key, x, deterministic=False)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_jax_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_jax_summary.csv")
Implementation,Time (ms),Speedup
Baseline Flax,86.58,1.00x
TE Unfused,42.25,2.05x
TE Unfused + TE Attention,35.05,2.47x
TE Unfused + TE Attention + FP8,22.64,3.82x
TE Fused + TE Attention + FP8,23.70,3.65x
TE TransformerLayer + FP8,22.81,3.80x
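The Speedup column above is simply the baseline mean time divided by each variant's mean time, rounded to two decimals, as in the CSV-writing code in the script. A quick recomputation from the JAX mean times reported in the log:

```python
# Recompute the Speedup column of the summary table from the mean times
# printed in the JAX benchmark log above (values in milliseconds).
times_ms = {
    "Baseline Flax": 86.580,
    "TE Unfused": 42.252,
    "TE Unfused + TE Attention": 35.054,
    "TE Unfused + TE Attention + FP8": 22.638,
    "TE Fused + TE Attention + FP8": 23.703,
    "TE TransformerLayer + FP8": 22.812,
}
baseline = times_ms["Baseline Flax"]
speedups = {name: round(baseline / t, 2) for name, t in times_ms.items()}
```

Each value matches the table, e.g. 86.580 / 22.638 ≈ 3.82 for the unfused FP8 variant.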
pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926
dispatch key: ADInplaceOrView
previous kernel: no debug info
new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
self.m.impl(
# BENCHMARK_BASELINE_OUTPUT_START
Baseline PyTorch:
Mean time: 48.280 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 49.342 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.709 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 23.406 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 22.964 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 21.670 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_pytorch_summary.csv
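Both scripts time their layers with a `speedometer` helper imported from `getting_started_utils_jax.py` / `getting_started_utils_pytorch.py`, which this diff does not show. A framework-agnostic sketch of such a warmup-then-average timing loop (the signature, iteration counts, and defaults are assumptions, not the actual helper, which also synchronizes the device around each call):

```python
import time

def speedometer(fn, *args, warmup_iters=5, timing_iters=20, **kwargs):
    """Run fn a few times to warm up, then report the mean runtime in ms."""
    for _ in range(warmup_iters):
        fn(*args, **kwargs)          # warmup: JIT compilation, caches, etc.
    start = time.perf_counter()
    for _ in range(timing_iters):
        fn(*args, **kwargs)          # timed region
    mean_ms = (time.perf_counter() - start) / timing_iters * 1e3
    print(f"Mean time: {mean_ms:.3f} ms")
    return mean_ms

# Trivial usage on a CPU-only callable:
t = speedometer(sum, range(1000))
```

Returning the mean lets the callers above collect `time_baseline`, `time_te_unfused`, and so on for the summary CSV.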
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - PyTorch Example
==========================================================
This example shows how to build a Transformer layer using PyTorch
and how to optimize it with Transformer Engine.
"""
from typing import Optional
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_pytorch import DotProductAttention, speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16
# Create synthetic data
x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
# =============================================================================
# Baseline: Pure PyTorch Implementation
# =============================================================================
# BASELINE_MLP_START
class PyTorchMLP(torch.nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain PyTorch modules.
"""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class PyTorchTransformerLayer(torch.nn.Module):
"""Basic Transformer layer using plain PyTorch modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = torch.nn.Dropout(hidden_dropout)
self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = (
PyTorchTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("Baseline PyTorch:")
time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline")
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(torch.nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(torch.nn.Module):
"""Transformer layer using basic TE modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = (
TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(torch.nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn"
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(torch.nn.Module):
"""Transformer layer using fused TE modules for better performance."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
# Fused LayerNorm + QKV projection
self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
# Fused LayerNorm + MLP
self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
# Fused LayerNorm + QKV projection
qkv = self.ln_qkv(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Fused LayerNorm + MLP
res = x
x = self.ln_mlp(x)
x = self.dropout2(x)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = (
TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = (
te.TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
bias=True,
hidden_dropout=0.0,
attention_dropout=0.0,
)
.to(dtype=dtype)
.cuda()
)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_pytorch_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_pytorch_summary.csv")
Implementation,Time (ms),Speedup
Baseline PyTorch,48.28,1.00x
TE Unfused,49.34,0.98x
TE Unfused + TE Attention,35.71,1.35x
TE Unfused + TE Attention + FP8,23.41,2.06x
TE Fused + TE Attention + FP8,22.96,2.10x
TE TransformerLayer + FP8,21.67,2.23x
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for Getting Started with Transformer Engine - JAX
====================================================================
Helper classes and functions for the getting started examples.
"""
import time
from typing import Callable, Any, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
import transformer_engine.jax as te
from transformer_engine.jax.sharding import MeshResource
def speedometer(
apply_fn: Callable,
params: Any,
x: jnp.ndarray,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 100,
warmup_iters: int = 10,
label: str = "benchmark",
) -> float:
"""Measure average forward + backward pass time for a JAX module.
Args:
apply_fn: JIT-compiled apply function
params: Model parameters
x: Input tensor
forward_kwargs: Additional kwargs for forward pass
autocast_kwargs: Kwargs for te.autocast context
timing_iters: Number of timing iterations
warmup_iters: Number of warmup iterations
label: Optional label for logging
Returns:
Average time per iteration in milliseconds
"""
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
else:
autocast_kwargs = dict(autocast_kwargs)
autocast_kwargs.setdefault("mesh_resource", MeshResource())
def loss_fn(params, x):
y = apply_fn(params, x, **forward_kwargs)
return jnp.sum(y)
# JIT compile within autocast context
with te.autocast(**autocast_kwargs):
grad_fn = jax.jit(jax.value_and_grad(loss_fn))
# Warmup runs
for _ in range(warmup_iters):
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
# Timing runs
times = []
for _ in range(timing_iters):
start = time.perf_counter()
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
times.append(time.perf_counter() - start)
avg_time = sum(times) / len(times) * 1000
print(f"Mean time: {avg_time:.3f} ms")
return avg_time
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for Getting Started with Transformer Engine - PyTorch
========================================================================
Helper classes and functions for the getting started examples.
"""
import math
from typing import Optional
import torch
import transformer_engine.pytorch as te
def speedometer(
module: torch.nn.Module,
x: torch.Tensor,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 100,
warmup_iters: int = 10,
label: str = "benchmark",
) -> float:
"""Measure average forward + backward pass time for a PyTorch module.
Args:
module: PyTorch module to benchmark
x: Input tensor
forward_kwargs: Additional kwargs for forward pass
autocast_kwargs: Kwargs for te.autocast context
timing_iters: Number of timing iterations
warmup_iters: Number of warmup iterations
label: Optional label for logging
Returns:
Average time per iteration in milliseconds
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
for _ in range(warmup_iters):
with te.autocast(**autocast_kwargs):
y = module(x, **forward_kwargs)
loss = y.sum()
loss.backward()
torch.cuda.synchronize()
# Timing runs
start.record()
for _ in range(timing_iters):
with te.autocast(**autocast_kwargs):
y = module(x, **forward_kwargs)
loss = y.sum()
loss.backward()
end.record()
torch.cuda.synchronize()
avg_time = start.elapsed_time(end) / timing_iters
print(f"Mean time: {avg_time:.3f} ms")
return avg_time
class DotProductAttention(torch.nn.Module):
"""Attention operation in Transformer layer.
Built with plain PyTorch modules.
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
) -> None:
super().__init__()
self.projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = kv_channels
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.dropout = torch.nn.Dropout(attention_dropout)
def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
if mask is not None:
inp.masked_fill_(mask, -10000.0)
return torch.nn.functional.softmax(inp, dim=-1)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
b = query.size(1)
np = query.size(2)
sq = query.size(0)
sk = key.size(0)
hn = value.size(3)
query = query.view(sq, b * np, -1)
key = key.view(sk, b * np, -1)
bmm1 = (
torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
)
attention_scores = bmm1.view(b, np, sq, sk)
attention_probs = self.masked_softmax(attention_scores, attention_mask)
attention_probs = self.dropout(attention_probs)
value = value.view(sk, b * np, -1)
attention_probs = attention_probs.view(b * np, sq, -1)
context = torch.bmm(attention_probs, value.transpose(0, 1))
context = context.view(b, np, sq, hn)
context = context.permute(2, 0, 1, 3).contiguous()
context = context.view(sq, b, self.projection_size)
return context
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
Getting Started
===============
Overview
--------
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs,
providing better performance with lower memory utilization in both training and inference.
It provides support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, as well as
FP8 and 4-bit floating point (NVFP4) precision on Blackwell GPUs.
TE implements a collection of highly optimized building blocks for popular Transformer
architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly
with your deep learning code.
Currently two frameworks are supported: PyTorch and JAX.
.. tabs::
.. tab:: PyTorch
Basic knowledge of PyTorch is recommended:
- `PyTorch Tutorials <https://pytorch.org/tutorials/>`_
- `PyTorch Documentation <https://pytorch.org/docs/stable/index.html>`_
.. tab:: JAX
We recommend understanding the basics of JAX first:
- `Thinking in JAX <https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html>`_
- `JAX 101 <https://docs.jax.dev/en/latest/jax-101.html>`_
- `Key concepts in JAX <https://docs.jax.dev/en/latest/key-concepts.html>`_
- `Flax 101 <https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html>`_
Baseline: Pure Framework Implementation
---------------------------------------
Let's build a Transformer decoder layer!
We'll create a basic GPT-style layer with causal masking,
which prevents each position from attending to future positions. This will be our baseline
for later comparisons with Transformer Engine.
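To make the masking concrete, here is a tiny NumPy sketch (an illustration only, not part of the tutorial code) of the boolean causal mask: entries marked ``True`` are the future positions each query is forbidden to attend to.

```python
import numpy as np

seq_len = 4
# k=1 keeps the diagonal unmasked: each position may attend to itself
# and to everything before it, but not to later positions.
causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
print(causal_mask.astype(int))
```

A mask of this shape (``[seq_len, seq_len]``) is applied to the attention scores before the softmax.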
.. raw:: html
:file: transformer_layer.svg
.. raw:: html
<p style="text-align: center; font-style: italic; color: #666;">Structure of a GPT decoder layer</p>
We construct the components as follows:
.. tabs::
.. tab:: PyTorch
* **LayerNorm**: ``torch.nn.LayerNorm``
* **QKV Projection**: ``torch.nn.Linear`` (Q, K, and V fused into a single projection 3x larger)
* **DotProductAttention**: Custom implementation using ``torch.bmm``
* **Projection**: ``torch.nn.Linear``
* **Dropout**: ``torch.nn.Dropout``
* **MLP**: Two ``torch.nn.Linear`` layers with ``torch.nn.functional.gelu`` activation
.. tab:: JAX
* **LayerNorm**: ``nn.LayerNorm``
* **QKV Projection**: ``nn.Dense`` (Q, K, and V fused into a single projection 3x larger)
* **DotProductAttention**: ``nn.dot_product_attention``
* **Projection**: ``nn.Dense``
* **Dropout**: ``nn.Dropout``
* **MLP**: Two ``nn.Dense`` layers with ``nn.gelu`` activation
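The lists above map modules one-to-one, and the computation at the heart of both is scaled dot-product attention. As a framework-agnostic reference, here is a NumPy sketch of that computation (for illustration only; it is not the code benchmarked below):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, causal=True):
    """Reference attention: softmax(q @ k^T / sqrt(d)) @ v."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)          # [seq, seq]
    if causal:
        mask = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(mask, -1e9, scores)             # hide future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)        # row-wise softmax
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
out = scaled_dot_product_attention(q, k, v)
# With a causal mask, the first output row can only depend on the first value row.
print(np.allclose(out[0], v[0]))  # True
```

The per-head shapes, batching, and dropout handled by the real modules are omitted here to keep the core math visible.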
Putting it all together:
.. tabs::
.. tab:: PyTorch
First, define the MLP block:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BASELINE_MLP_START
:end-before: # BASELINE_MLP_END
Now, putting it all together into a GPT decoder layer:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BASELINE_LAYER_START
:end-before: # BASELINE_LAYER_END
Benchmark the baseline implementation:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_BASELINE_START
:end-before: # BENCHMARK_BASELINE_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_BASELINE_OUTPUT_START
:end-before: # BENCHMARK_BASELINE_OUTPUT_END
.. tab:: JAX
First, define the MLP block:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BASELINE_MLP_START
:end-before: # BASELINE_MLP_END
Now, putting it all together into a GPT decoder layer:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BASELINE_LAYER_START
:end-before: # BASELINE_LAYER_END
Benchmark the baseline implementation:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_BASELINE_START
:end-before: # BENCHMARK_BASELINE_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_BASELINE_OUTPUT_START
:end-before: # BENCHMARK_BASELINE_OUTPUT_END
TE Unfused: Basic TE Modules
----------------------------
Now let's replace the standard framework modules with TE equivalents.
This is the simplest way to start using Transformer Engine.
.. tabs::
.. tab:: PyTorch
Replace PyTorch modules with TE equivalents:
.. code-block:: python
import transformer_engine.pytorch as te
Mapping:
* ``torch.nn.Linear`` → ``te.Linear``
* ``torch.nn.LayerNorm`` → ``te.LayerNorm``
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_MLP_START
:end-before: # TE_UNFUSED_MLP_END
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_LAYER_START
:end-before: # TE_UNFUSED_LAYER_END
Benchmark the TE unfused implementation:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_START
:end-before: # BENCHMARK_TE_UNFUSED_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END
.. tab:: JAX
Replace Flax modules with TE equivalents:
.. code-block:: python
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
Mapping:
* ``nn.Dense`` → ``te_flax.DenseGeneral``
* ``nn.LayerNorm`` → ``te_flax.LayerNorm``
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_MLP_START
:end-before: # TE_UNFUSED_MLP_END
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_LAYER_START
:end-before: # TE_UNFUSED_LAYER_END
Benchmark the TE unfused implementation:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_START
:end-before: # BENCHMARK_TE_UNFUSED_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END
TE Unfused + TE Attention
-------------------------
Now let's also replace the attention mechanism with TE's optimized ``DotProductAttention``.
TE's attention automatically selects the best available backend (for example, FlashAttention or cuDNN fused attention) based on your hardware and input configuration,
so you get strong performance without manual tuning.
.. tabs::
.. tab:: PyTorch
Replace the custom attention with TE's optimized implementation:
* Custom ``DotProductAttention`` → ``te.DotProductAttention``
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_UNFUSED_ATTN_LAYER_START
:end-before: # TE_UNFUSED_ATTN_LAYER_END
Benchmark TE Unfused with TE Attention:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
.. tab:: JAX
Replace Flax's attention with TE's optimized implementation:
* ``nn.dot_product_attention`` → ``te_flax.DotProductAttention``
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_UNFUSED_ATTN_LAYER_START
:end-before: # TE_UNFUSED_ATTN_LAYER_END
Benchmark TE Unfused with TE Attention:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
TE Unfused + TE Attention + FP8
-------------------------------
Now let's keep the TE modules and TE Attention, and additionally enable FP8 precision by
wrapping the forward pass in an ``autocast`` context manager.
On supported hardware (Hopper, Ada, and Blackwell GPUs) this provides significant speedups.
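To build intuition for what the ``DelayedScaling`` recipe used below does, here is a simplified NumPy sketch of delayed scaling (an illustration of the idea, not TE's actual implementation): the FP8 scale factor is derived from a history of recent per-tensor absolute maxima (amax), rather than from the current tensor alone.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in the E4M3 format

def delayed_scale(amax_history):
    """Scale so the largest recently seen value maps near the FP8 maximum."""
    amax = max(amax_history)              # corresponds to amax_compute_algo="max"
    return FP8_E4M3_MAX / amax

history = [1.5, 3.0, 2.2]                 # amax values from recent iterations
scale = delayed_scale(history)
x = np.array([0.1, -2.5, 3.0])
x_scaled = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # cast side (rounding omitted)
x_restored = x_scaled / scale                               # dequantize side
print(np.allclose(x, x_restored))  # True (no FP8 rounding simulated here)
```

Using a history lets TE choose the scale before the tensor is produced, keeping amax reduction off the critical path; ``amax_history_len=16`` and ``amax_compute_algo="max"`` in the recipe below control the window length and the reduction over it.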
.. tabs::
.. tab:: PyTorch
.. code-block:: python
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
y = te_unfused(x, attention_mask=None)
.. note::
The ``autocast`` should only wrap the forward pass and must exit before
starting a backward pass.
Benchmark TE Unfused with FP8:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_FP8_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
.. tab:: JAX
.. code-block:: python
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
params = te_unfused.init(key, x, deterministic=False)
y = te_unfused.apply(params, x, deterministic=True)
.. important::
When using FP8 in JAX, the model **must be initialized within the autocast context**
to create the ``fp8_metas`` collection.
Benchmark TE Unfused with FP8:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_UNFUSED_FP8_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
TE Fused + TE Attention + FP8: Optimized Modules
------------------------------------------------
Fused modules combine multiple operations into a single GPU kernel, reducing memory traffic and kernel-launch overhead.
While the speedups are modest on a single GPU, fusion scales better in multi-GPU setups.
Combined with TE Attention and FP8, this delivers peak performance.
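To see what fusion does and does not change, consider this NumPy sketch: a fused LayerNorm + Linear computes exactly the same function as running the two operations separately; the benefit is executing them in one kernel. This is an illustration only (affine LayerNorm parameters and bias are omitted).

```python
import numpy as np

def layernorm(x, eps=1e-5):
    """Normalize each row to zero mean and unit variance."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))
w = rng.normal(size=(16, 32))

# Unfused: two steps, with the intermediate written out between them.
y_unfused = layernorm(x) @ w

# "Fused": mathematically identical single call; TE fuses this on the GPU.
def layernorm_linear(x, w):
    return layernorm(x) @ w

y_fused = layernorm_linear(x, w)
print(np.allclose(y_unfused, y_fused))  # True: same result, fewer memory round-trips in TE
```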
.. tabs::
.. tab:: PyTorch
Fused modules available:
* ``te.LayerNormLinear`` - fuses LayerNorm + Linear
* ``te.LayerNormMLP`` - fuses LayerNorm + MLP
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # TE_FUSED_LAYER_START
:end-before: # TE_FUSED_LAYER_END
Benchmark TE Fused with FP8:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_FUSED_FP8_START
:end-before: # BENCHMARK_TE_FUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END
.. tab:: JAX
Fused modules available:
* ``te_flax.LayerNormDenseGeneral`` - fuses LayerNorm + Dense
* ``te_flax.LayerNormMLP`` - fuses LayerNorm + MLP
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # TE_FUSED_LAYER_START
:end-before: # TE_FUSED_LAYER_END
Benchmark TE Fused with FP8:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_FUSED_FP8_START
:end-before: # BENCHMARK_TE_FUSED_FP8_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START
:end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END
TE TransformerLayer + FP8: Ready-to-use Module
----------------------------------------------
For the simplest integration, Transformer Engine provides a ready-to-use ``TransformerLayer``
module that includes all optimizations out of the box.
.. tabs::
.. tab:: PyTorch
Just use ``te.TransformerLayer``, which handles everything for you:
.. literalinclude:: getting_started_pytorch.py
:language: python
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_pytorch.out
:language: text
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
.. tab:: JAX
Just use ``te_flax.TransformerLayer``, which handles everything for you:
.. literalinclude:: getting_started_jax.py
:language: python
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END
.. raw:: html
<div style="background: #f5f5f5; border-left: 3px solid #9ca3af; padding: 4px 12px; font-size: 12px; color: #6b7280; margin-top: -16px;">
Output:
</div>
.. container:: program-output
.. literalinclude:: getting_started_jax.out
:language: text
:start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
:end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Benchmark Summary
-----------------
The table below summarizes the performance improvements achieved with Transformer Engine
on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this
tutorial focuses on a simple single-GPU scenario, features like fused layers can provide
additional benefits in more complex setups such as multi-GPU training.
.. tabs::
.. tab:: PyTorch
.. csv-table::
:header-rows: 1
:widths: 40, 20, 20
:file: getting_started_pytorch_summary.csv
.. tab:: JAX
.. csv-table::
:header-rows: 1
:widths: 40, 20, 20
:file: getting_started_jax_summary.csv
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 700" style="display: block; margin: 0 auto; max-width: 280px;">
<defs>
<style>
.box { fill: #a8c686; stroke: #7a9a5a; stroke-width: 2; }
.circle { fill: #b8d4a0; stroke: #7a9a5a; stroke-width: 2; }
.text { font-family: Arial, sans-serif; font-size: 16px; font-weight: 500; fill: #333; text-anchor: middle; dominant-baseline: middle; }
.arrow { stroke: #6b8fb3; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.skip { stroke: #6b8fb3; stroke-width: 2; fill: none; }
</style>
<marker id="arrowhead" markerWidth="8" markerHeight="6" refX="8" refY="3" orient="auto">
<polygon points="0 0, 8 3, 0 6" fill="#6b8fb3"/>
</marker>
</defs>
<!-- Input arrow -->
<line x1="160" y1="5" x2="160" y2="40" class="arrow"/>
<!-- Skip connection 1 (input to first +) -->
<path d="M 160 20 L 280 20 L 280 420" class="skip"/>
<line x1="280" y1="420" x2="185" y2="420" class="arrow"/>
<!-- LayerNorm 1 -->
<rect x="60" y="40" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="62" class="text">LayerNorm</text>
<line x1="160" y1="85" x2="160" y2="110" class="arrow"/>
<!-- QKV Projection -->
<rect x="60" y="110" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="132" class="text">QKV Projection</text>
<line x1="160" y1="155" x2="160" y2="180" class="arrow"/>
<!-- Dot Product Attention -->
<rect x="60" y="180" width="200" height="55" rx="10" ry="10" class="box"/>
<text x="160" y="200" class="text">Dot Product</text>
<text x="160" y="220" class="text">Attention</text>
<line x1="160" y1="235" x2="160" y2="260" class="arrow"/>
<!-- Projection -->
<rect x="60" y="260" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="282" class="text">Projection</text>
<line x1="160" y1="305" x2="160" y2="330" class="arrow"/>
<!-- Dropout -->
<rect x="60" y="330" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="352" class="text">Dropout</text>
<line x1="160" y1="375" x2="160" y2="395" class="arrow"/>
<!-- First + circle -->
<circle cx="160" cy="420" r="25" class="circle"/>
<text x="160" y="420" class="text" font-size="24">+</text>
<line x1="160" y1="445" x2="160" y2="480" class="arrow"/>
<!-- Skip connection 2 (first + to second +) -->
<path d="M 160 455 L 280 455 L 280 640" class="skip"/>
<line x1="280" y1="640" x2="185" y2="640" class="arrow"/>
<!-- LayerNorm 2 -->
<rect x="60" y="480" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="502" class="text">LayerNorm</text>
<line x1="160" y1="525" x2="160" y2="555" class="arrow"/>
<!-- MLP -->
<rect x="60" y="555" width="200" height="45" rx="10" ry="10" class="box"/>
<text x="160" y="577" class="text">MLP</text>
<line x1="160" y1="600" x2="160" y2="615" class="arrow"/>
<!-- Second + circle -->
<circle cx="160" cy="640" r="25" class="circle"/>
<text x="160" y="640" class="text" font-size="24">+</text>
<!-- Output arrow -->
<line x1="160" y1="665" x2="160" y2="695" class="arrow"/>
</svg>