Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

 ..
-    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
     See LICENSE for license information.
pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
# BENCHMARK_BASELINE_OUTPUT_START
Baseline Flax:
Mean time: 86.580 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 42.252 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.054 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 22.638 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 23.703 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 22.812 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_jax_summary.csv
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - JAX Example
======================================================
This example shows how to build a Transformer decoder layer using JAX/Flax
and how to optimize it with Transformer Engine.
"""
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Optional
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_jax import speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.bfloat16
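# With hidden_size = 4096 and 32 attention heads, each head operates on
# 4096 // 32 = 128 channels (kv_channels) in every layer variant below.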
# Create synthetic data
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)
mesh_resource = MeshResource()
# =============================================================================
# Baseline: Pure Flax Implementation
# =============================================================================
# BASELINE_MLP_START
class FlaxMLP(nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain Flax modules.
"""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class FlaxTransformerLayer(nn.Module):
"""Basic Transformer layer using plain Flax modules."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
# Fused QKV projection
qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = baseline.init(key, x, deterministic=False)
print("Baseline Flax:")
time_baseline = speedometer(
baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline"
)
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x)
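        # te.activation.activation expects its inputs stacked along an extra axis
        # (size 1 here, since plain GeLU is not gated); the op consumes that axis.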
x = x.reshape(*x.shape[:-1], 1, x.shape[-1])
x = te.activation.activation(x, activation_type=("gelu",))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(nn.Module):
"""Transformer layer using basic TE modules (without TE attention)."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = te_unfused.init(key, x, deterministic=False)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x)
qkv = te_flax.DenseGeneral(
features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16
)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=False, mesh_resource=mesh_resource):
params = te_unfused_attn.init(key, x, deterministic=False)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource},
label="te_unfused_attn",
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
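# Format.HYBRID uses E4M3 for forward-pass tensors and E5M2 for gradients in the
# backward pass; scaling factors are derived from the largest amax recorded over the
# last 16 iterations (amax_history_len=16 with amax_compute_algo="max").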
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_unfused_fp8.init(key, x, deterministic=False)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(nn.Module):
"""Transformer layer using fused TE modules for better performance."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
# Fused LayerNorm + QKV projection
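        # LayerNormDenseGeneral returns a (output, layernorm_output) pair; the second
        # element is discarded here because return_layernorm_output=False.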
qkv, _ = te_flax.LayerNormDenseGeneral(
features=3 * self.hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
return_layernorm_output=False,
)(x)
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels)
q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
qkv_layout="bshd_bshd_bshd",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
# Fused LayerNorm + MLP
x, _ = te_flax.LayerNormMLP(
intermediate_dim=self.ffn_hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
activations=("gelu",),
intermediate_dropout_rate=0.0,
return_layernorm_output=False,
)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_fused_fp8.init(key, x, deterministic=False)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = te_flax.TransformerLayer(
hidden_size=hidden_size,
mlp_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
mlp_activations=("gelu",),
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
use_bias=True,
attention_dropout=0.0,
intermediate_dropout=0.0,
hidden_dropout=0.0,
enable_relative_embedding=False,
self_attn_bias_type="no_bias",
dtype=jnp.bfloat16,
transpose_batch_sequence=False,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_transformer_layer.init(key, x, deterministic=False)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_jax_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_jax_summary.csv")
Implementation,Time (ms),Speedup
Baseline Flax,86.58,1.00x
TE Unfused,42.25,2.05x
TE Unfused + TE Attention,35.05,2.47x
TE Unfused + TE Attention + FP8,22.64,3.82x
TE Fused + TE Attention + FP8,23.70,3.65x
TE TransformerLayer + FP8,22.81,3.80x
pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64
/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor
registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926
dispatch key: ADInplaceOrView
previous kernel: no debug info
new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.)
self.m.impl(
# BENCHMARK_BASELINE_OUTPUT_START
Baseline PyTorch:
Mean time: 48.280 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused:
Mean time: 49.342 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention:
Mean time: 35.709 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8:
Mean time: 23.406 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8:
Mean time: 22.964 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8:
Mean time: 21.670 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_pytorch_summary.csv
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Getting Started with Transformer Engine - PyTorch Example
==========================================================
This example shows how to build a Transformer layer using PyTorch
and how to optimize it with Transformer Engine.
"""
from typing import Optional
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from getting_started_utils_pytorch import DotProductAttention, speedometer
# Configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 8
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.bfloat16
# Create synthetic data
x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
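# Note the [sequence, batch, hidden] layout: TE's PyTorch modules default to
# sequence-first ("sbhd") tensors, whereas the JAX example above is batch-first
# (transpose_batch_sequence=False).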
# =============================================================================
# Baseline: Pure PyTorch Implementation
# =============================================================================
# BASELINE_MLP_START
class PyTorchMLP(torch.nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain PyTorch modules.
"""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# BASELINE_MLP_END
# BASELINE_LAYER_START
class PyTorchTransformerLayer(torch.nn.Module):
"""Basic Transformer layer using plain PyTorch modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = torch.nn.Dropout(hidden_dropout)
self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
return x + res
# BASELINE_LAYER_END
print("# BENCHMARK_BASELINE_OUTPUT_START")
# BENCHMARK_BASELINE_START
baseline = (
PyTorchTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("Baseline PyTorch:")
time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline")
# BENCHMARK_BASELINE_END
print("# BENCHMARK_BASELINE_OUTPUT_END\n")
# =============================================================================
# TE Unfused: Basic TE Modules
# =============================================================================
# TE_UNFUSED_MLP_START
class TEUnfusedMLP(torch.nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
# TE_UNFUSED_MLP_END
# TE_UNFUSED_LAYER_START
class TEUnfusedTransformerLayer(torch.nn.Module):
"""Transformer layer using basic TE modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_LAYER_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_START
te_unfused = (
TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused"
)
# BENCHMARK_TE_UNFUSED_END
print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n")
# =============================================================================
# TE Unfused + TE Attention
# =============================================================================
# TE_UNFUSED_ATTN_LAYER_START
class TEUnfusedAttnTransformerLayer(torch.nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
# TE_UNFUSED_ATTN_LAYER_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_ATTN_START
te_unfused_attn = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn"
)
# BENCHMARK_TE_UNFUSED_ATTN_END
print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n")
# =============================================================================
# TE Unfused + FP8
# =============================================================================
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_UNFUSED_FP8_START
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_unfused_fp8",
)
# BENCHMARK_TE_UNFUSED_FP8_END
print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE Fused + FP8: Optimized Modules with FP8
# =============================================================================
# TE_FUSED_LAYER_START
class TEFusedTransformerLayer(torch.nn.Module):
"""Transformer layer using fused TE modules for better performance."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
# Fused LayerNorm + QKV projection
self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
# Fused LayerNorm + MLP
self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
# Fused LayerNorm + QKV projection
qkv = self.ln_qkv(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Fused LayerNorm + MLP
res = x
x = self.ln_mlp(x)
x = self.dropout2(x)
return x + res
# TE_FUSED_LAYER_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START")
# BENCHMARK_TE_FUSED_FP8_START
te_fused_fp8 = (
TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_fused_fp8",
)
# BENCHMARK_TE_FUSED_FP8_END
print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n")
# =============================================================================
# TE TransformerLayer + FP8: Ready-to-use Module
# =============================================================================
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START")
# BENCHMARK_TE_TRANSFORMER_LAYER_START
te_transformer_layer = (
te.TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
bias=True,
hidden_dropout=0.0,
attention_dropout=0.0,
)
.to(dtype=dtype)
.cuda()
)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_transformer_layer",
)
# BENCHMARK_TE_TRANSFORMER_LAYER_END
print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n")
# Write summary CSV for RST documentation
with open("getting_started_pytorch_summary.csv", "w") as f:
f.write("Implementation,Time (ms),Speedup\n")
f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n")
f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n")
f.write(
"TE Unfused + TE"
f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n"
)
f.write(
"TE Unfused + TE Attention +"
f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n"
)
f.write(
"TE Fused + TE Attention +"
f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n"
)
f.write(
"TE TransformerLayer +"
f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n"
)
print("\nSummary written to getting_started_pytorch_summary.csv")
Implementation,Time (ms),Speedup
Baseline PyTorch,48.28,1.00x
TE Unfused,49.34,0.98x
TE Unfused + TE Attention,35.71,1.35x
TE Unfused + TE Attention + FP8,23.41,2.06x
TE Fused + TE Attention + FP8,22.96,2.10x
TE TransformerLayer + FP8,21.67,2.23x
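The PyTorch example imports DotProductAttention and speedometer from getting_started_utils_pytorch, whose diff is collapsed on this page. As a rough sketch only (the class name mirrors the import, but the body below is an assumption rather than the actual file contents), an unfused scaled-dot-product attention helper compatible with the [seq, batch, heads, channels] tensors used above could look like this:

import math
from typing import Optional
import torch

class DotProductAttention(torch.nn.Module):
    """Unfused scaled dot-product attention over [seq, batch, heads, channels] inputs (sketch)."""
    def __init__(self, num_attention_heads: int, kv_channels: int, attention_dropout: float) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.norm_factor = math.sqrt(kv_channels)
        self.dropout = torch.nn.Dropout(attention_dropout)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # [s, b, h, d] -> [b, h, s, d] for batched matmul
        q, k, v = (t.permute(1, 2, 0, 3) for t in (query, key, value))
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.norm_factor  # [b, h, s, s]
        if attention_mask is not None:
            # True entries are masked out
            scores = scores.masked_fill(attention_mask, float("-inf"))
        probs = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(probs, v)       # [b, h, s, d]
        context = context.permute(2, 0, 1, 3)  # back to [s, b, h, d]
        return context.reshape(context.size(0), context.size(1), -1)  # merge heads -> [s, b, h*d]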
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for Getting Started with Transformer Engine - JAX
====================================================================
Helper classes and functions for the getting started examples.
"""
import time
from typing import Callable, Any, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
import transformer_engine.jax as te
from transformer_engine.jax.sharding import MeshResource
def speedometer(
apply_fn: Callable,
params: Any,
x: jnp.ndarray,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 100,
warmup_iters: int = 10,
label: str = "benchmark",
) -> float:
"""Measure average forward + backward pass time for a JAX module.
Args:
        apply_fn: Flax apply function (wrapped in jax.jit together with its gradient inside this helper)
params: Model parameters
x: Input tensor
forward_kwargs: Additional kwargs for forward pass
autocast_kwargs: Kwargs for te.autocast context
timing_iters: Number of timing iterations
warmup_iters: Number of warmup iterations
label: Optional label for logging
Returns:
Average time per iteration in milliseconds
"""
if autocast_kwargs is None:
autocast_kwargs = {"enabled": False}
else:
autocast_kwargs = dict(autocast_kwargs)
autocast_kwargs.setdefault("mesh_resource", MeshResource())
def loss_fn(params, x):
y = apply_fn(params, x, **forward_kwargs)
return jnp.sum(y)
# JIT compile within autocast context
with te.autocast(**autocast_kwargs):
grad_fn = jax.jit(jax.value_and_grad(loss_fn))
# Warmup runs
for _ in range(warmup_iters):
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
# Timing runs
times = []
for _ in range(timing_iters):
start = time.perf_counter()
loss, grads = grad_fn(params, x)
jax.block_until_ready((loss, grads))
times.append(time.perf_counter() - start)
avg_time = sum(times) / len(times) * 1000
print(f"Mean time: {avg_time:.3f} ms")
return avg_time
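For reference, a minimal standalone use of speedometer might look like the following; the toy DenseGeneral layer and shapes are illustrative assumptions, not part of the tutorial files:

import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax
# speedometer is the helper defined above (getting_started_utils_jax)

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (8, 512, 1024)).astype(jnp.bfloat16)
layer = te_flax.DenseGeneral(features=1024, use_bias=True)  # toy layer for illustration
variables = layer.init(key, data)
# Times forward + backward for 100 iterations after 10 warmup runs, without FP8
speedometer(layer.apply, variables, data, label="toy_dense")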
This diff is collapsed.
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 """Shared functions for the comm_overlap tests"""

-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
This diff is collapsed.