Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import re
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for distributed/sharded execution of MoE permutation primitives.
Testing Strategy:
=================
MoE permutation is data-dependent - the destination index for each token depends
on how many tokens before it are routed to the same expert. This means:
1. We CANNOT compare sharded output against global reference directly
2. Instead, we verify that each GPU's LOCAL output is correct according to its
LOCAL routing (which produces LOCAL row_id_map with LOCAL indices)
For data-parallel MoE without expert parallelism:
- Each GPU has ALL experts replicated
- Each GPU processes a subset of tokens (sharded on token/batch dimension)
- Each GPU computes its own local row_id_map from its local routing_map slice
- Each GPU's output is local and doesn't need to match global output
These tests verify:
1. Local token_dispatch: sharded input -> local row_id_map -> local permute (forward + backward)
2. Local roundtrip: dispatch + combine recovers original input (forward + backward)
"""
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
)
# Reference implementations from test_permutation.py
from test_permutation import (
reference_make_row_id_map,
_reference_permute_impl,
_reference_unpermute_impl,
reference_token_combine,
)
# Dispatch/combine test cases: (num_tokens, num_experts, hidden_size, topk)
# topk = number of experts each token is routed to
# Includes small, medium, and large stress-test cases.
ALL_DISPATCH_COMBINE_CASES = [
(128, 4, 64, 2),
(4096, 32, 1280, 2),
(4096, 256, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:1],
"L2": ALL_DISPATCH_COMBINE_CASES,
}
# Dispatch/combine with padding test cases: (num_tokens, num_experts, hidden_size, topk, align_size)
ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 4, 64, 2, 8),
(4096, 32, 1280, 2, 128),
(4096, 256, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:1],
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}
# Dtypes for testing
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
"L0": [jnp.float32],
"L2": ALL_DTYPES,
}
class TestDistributedPermutation:
"""Test distributed/sharded execution of MoE permutation primitives.
These tests validate that custom partitioning produces correct LOCAL results
when inputs are sharded across multiple devices.
Key insight: With data-parallel MoE, each GPU independently processes its
local tokens. The row_id_map is generated locally and contains LOCAL indices.
We verify correctness by comparing each shard's output against the reference
implementation run on that shard's local data.
"""
@staticmethod
def compute_padded_output_size(
num_tokens: int,
num_experts: int,
topk: int,
align_size: int,
num_dp_devices: int,
) -> int:
"""Compute global_num_out_tokens for distributed padding tests.
Each device processes local_num_tokens tokens. We compute the worst-case
padded output size per device, then multiply by num_dp_devices to get
a global size that ensures global / num_dp >= local_worst.
"""
local_num_tokens = num_tokens // num_dp_devices
local_raw_out = local_num_tokens * topk
local_worst = ((local_raw_out + num_experts * (align_size - 1)) // align_size) * align_size
return local_worst * num_dp_devices
@staticmethod
def generate_routing_map(
num_tokens: int,
num_experts: int,
        topk: int = 2,  # Number of experts each token is routed to (exactly topk ones per row).
key: jax.Array = None,
):
if key is None:
key = jax.random.PRNGKey(0)
routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
for token_idx in range(num_tokens):
key, subkey = jax.random.split(key)
expert_indices = jax.random.choice(subkey, num_experts, shape=(topk,), replace=False)
routing_map = routing_map.at[token_idx, expert_indices].set(1)
return routing_map
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_token_dispatch(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_shardy,
):
"""
Test token_dispatch with sharded inputs.
Verifies that sharded execution produces the same result as chunk-wise
reference execution. The sharded primitive:
1. Receives global num_out_tokens (partition function divides it)
2. Each GPU operates on its local shard independently
3. Results are gathered (concatenated) across GPUs
Output ordering: [GPU0_expert0, GPU0_expert1, ... | GPU1_expert0, ...]
The reference processes each chunk independently and concatenates,
matching the sharded execution's output ordering.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Shard on token (batch) dimension
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
# Compute num_out_tokens as concrete values
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
# Local num_out_tokens is used for reference implementation
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
global_num_out_tokens = num_tokens * topk
local_num_tokens = num_tokens // num_dp_devices
local_num_out_tokens = local_num_tokens * topk
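        # e.g. num_tokens=128, topk=2, 2 DP devices -> global 256, local 128 per device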
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
probs_sharding = NamedSharding(mesh, sharded_pspec)
# Shard the inputs
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
probs_sharded = jax.device_put(probs, probs_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def target_dispatch(x, rm, p):
# Pass global num_out_tokens - partition function divides it
out, perm_probs, rid_map, _, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p
)
return out, perm_probs, rid_map
# Reference: process each GPU's shard independently, then concatenate
# This matches how the sharded primitive operates:
# - Each GPU processes its local shard
# - Results are gathered (concatenated) across GPUs
# Output ordering: [GPU0_exp0, GPU0_exp1, ... | GPU1_exp0, GPU1_exp1, ...]
inp_shards = jnp.reshape(inp, (num_dp_devices, local_num_tokens, hidden_size))
routing_shards = jnp.reshape(
routing_map, (num_dp_devices, local_num_tokens, num_experts)
)
probs_shards = jnp.reshape(probs, (num_dp_devices, local_num_tokens, num_experts))
ref_outputs = []
ref_perm_probs_list = []
ref_rid_maps = []
for i in range(num_dp_devices):
shard_rid_map = reference_make_row_id_map(routing_shards[i])
shard_out, shard_perm_probs = _reference_permute_impl(
inp_shards[i], shard_rid_map, probs_shards[i], local_num_out_tokens
)
ref_outputs.append(shard_out)
ref_perm_probs_list.append(shard_perm_probs)
ref_rid_maps.append(shard_rid_map)
# Concatenate like all_gather would
ref_out = jnp.concatenate(ref_outputs, axis=0)
ref_perm_probs = jnp.concatenate(ref_perm_probs_list, axis=0)
ref_rid_map = jnp.concatenate(ref_rid_maps, axis=0)
# Run target on sharded inputs
target_out, target_perm_probs, target_rid_map = target_dispatch(
inp_sharded, routing_sharded, probs_sharded
)
# Compare forward outputs
assert_allclose(jax.device_get(target_out), ref_out, dtype=dtype)
assert_allclose(jax.device_get(target_perm_probs), ref_perm_probs, dtype=dtype)
# Verify row_id_map n_routed column matches routing_map sum
target_rid_map_np = jax.device_get(target_rid_map)
assert jnp.array_equal(
target_rid_map_np[:, -1], ref_rid_map[:, -1]
), "n_routed column mismatch"
# Sanity checks
target_out_np = jax.device_get(target_out)
target_perm_probs_np = jax.device_get(target_perm_probs)
assert not np.any(np.isnan(target_out_np)), "Output contains NaN"
assert not np.any(np.isnan(target_perm_probs_np)), "Permuted probs contain NaN"
assert np.all(target_perm_probs_np >= 0), "Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def target_loss(x, rm, p):
out, perm_probs, _, _, _ = token_dispatch(x, rm, global_num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
# Reference loss: process chunks independently and sum
def ref_chunk_loss(inp_chunk, routing_chunk, probs_chunk):
rid_map = reference_make_row_id_map(routing_chunk)
out, perm_probs = _reference_permute_impl(
inp_chunk, rid_map, probs_chunk, local_num_out_tokens
)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
target_grad_fn = jax.jit(jax.grad(target_loss, argnums=(0, 2)))
ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss, argnums=(0, 2)))
target_inp_grad, target_probs_grad = target_grad_fn(
inp_sharded, routing_sharded, probs_sharded
)
# Compute reference gradients per chunk, then concatenate
ref_inp_grads = []
ref_probs_grads = []
for i in range(num_dp_devices):
chunk_inp_grad, chunk_probs_grad = ref_chunk_grad_fn(
inp_shards[i], routing_shards[i], probs_shards[i]
)
ref_inp_grads.append(chunk_inp_grad)
ref_probs_grads.append(chunk_probs_grad)
ref_inp_grad = jnp.concatenate(ref_inp_grads, axis=0)
ref_probs_grad = jnp.concatenate(ref_probs_grads, axis=0)
assert_allclose(jax.device_get(target_inp_grad), ref_inp_grad, dtype=dtype)
assert_allclose(jax.device_get(target_probs_grad), ref_probs_grad, dtype=dtype)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_roundtrip(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_shardy,
):
"""
Test roundtrip: token_dispatch followed by token_combine with sharded inputs.
Each GPU:
1. Gets a shard of the input and routing_map
2. Performs local dispatch (permute)
3. Performs local combine (unpermute)
4. With uniform merging probs, should recover original input
Tests both forward pass and backward pass (gradient should be 2*x).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key = jax.random.split(key, 2)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
# Uniform merging probs for perfect roundtrip
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
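        # Why the roundtrip is exact: generate_routing_map routes each token to
        # exactly topk experts, so every dispatched copy equals the original row x
        # and combine computes sum over topk experts of (1/topk) * x = x.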
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
# Compute num_out_tokens as concrete value
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
global_num_out_tokens = num_tokens * topk
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
merging_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def roundtrip(x, rm, mprobs):
dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens)
return token_combine(dispatched, rid_map, mprobs)
roundtrip_out = roundtrip(inp_sharded, routing_sharded, merging_sharded)
# Should recover original input
assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def roundtrip_loss(x, rm, mprobs):
dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens)
combined = token_combine(dispatched, rid_map, mprobs)
return jnp.sum(combined**2)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn = jax.jit(jax.grad(roundtrip_loss, argnums=0))
computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded)
expected_grad = 2.0 * inp_sharded
assert_allclose(
jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk,align_size",
DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_token_dispatch_with_padding(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
align_size,
dtype,
use_shardy,
):
"""
Test token_dispatch with padding using sharded inputs.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens = self.compute_padded_output_size(
num_tokens, num_experts, topk, align_size, num_dp_devices
)
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
probs_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
probs_sharded = jax.device_put(probs, probs_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def dispatch_with_padding(x, rm, p):
out, perm_probs, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p, align_size=align_size
)
return out, perm_probs, rid_map, pad_offsets
out, perm_probs, rid_map, pad_offsets = dispatch_with_padding(
inp_sharded, routing_sharded, probs_sharded
)
# Sanity checks
out_np = jax.device_get(out)
perm_probs_np = jax.device_get(perm_probs)
assert not np.any(np.isnan(out_np)), "Output contains NaN"
assert not np.any(np.isnan(perm_probs_np)), "Permuted probs contain NaN"
assert np.all(perm_probs_np >= 0), "Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def loss_with_padding(x, rm, p):
out, perm_probs, _, _, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p, align_size=align_size
)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
grad_fn = jax.jit(jax.grad(loss_with_padding, argnums=(0, 2)))
inp_grad, probs_grad = grad_fn(inp_sharded, routing_sharded, probs_sharded)
# Gradients should not contain NaN
assert not np.any(np.isnan(jax.device_get(inp_grad))), "Input gradient contains NaN"
assert not np.any(np.isnan(jax.device_get(probs_grad))), "Probs gradient contains NaN"
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk,align_size",
DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_roundtrip_with_padding(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
align_size,
dtype,
use_shardy,
):
"""
Test roundtrip with padding/alignment using sharded inputs.
With uniform merging probs, should recover original input.
Tests both forward pass and backward pass.
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate inputs
key, inp_key = jax.random.split(key, 2)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
# Uniform merging probs
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens = self.compute_padded_output_size(
num_tokens, num_experts, topk, align_size, num_dp_devices
)
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
merging_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def roundtrip_with_padding(x, rm, mprobs):
dispatched, _, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, align_size=align_size
)
return token_combine(dispatched, rid_map, mprobs, pad_offsets)
roundtrip_out = roundtrip_with_padding(inp_sharded, routing_sharded, merging_sharded)
# Should recover original input
assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def roundtrip_loss_with_padding(x, rm, mprobs):
dispatched, _, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, align_size=align_size
)
combined = token_combine(dispatched, rid_map, mprobs, pad_offsets)
return jnp.sum(combined**2)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn = jax.jit(jax.grad(roundtrip_loss_with_padding, argnums=0))
computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded)
expected_grad = 2.0 * inp_sharded
assert_allclose(
jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import autocast
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
DTYPES = [jnp.float16, jnp.bfloat16]
......@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(
self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
self, shape, mesh_resource, softmax_fusion_type, dtype, bad_sharding, broadcast_batch_mask
):
batch, _, sqelen, _ = shape
x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
if softmax_fusion_type == SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
mask = make_causal_mask(batch, sqelen)
else:
mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)
......@@ -56,8 +56,10 @@ class TestDistributedSoftmax:
return (x, mask), (x_pspec, mask_pspec)
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
def target_func(x, mask, scale_factor=1.0, softmax_fusion_type=SoftmaxFusionType.SCALED):
return jnp.mean(
softmax(x, mask, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type)
)
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
......@@ -80,24 +82,29 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
broadcast_batch_mask,
use_shardy,
):
if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED:
pytest.skip("Softmax type has no mask.")
jax.config.update("jax_use_shardy_partitioner", use_shardy)
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
data_shape,
mesh_resource,
softmax_fusion_type,
dtype,
bad_sharding,
broadcast_batch_mask,
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
......@@ -139,8 +146,12 @@ class TestDistributedSoftmax:
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]])
@pytest.mark.parametrize(
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
"softmax_fusion_type",
[
SoftmaxFusionType.SCALED,
SoftmaxFusionType.SCALED_MASKED,
SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -153,7 +164,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -165,7 +176,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
softmax_fusion_type,
scale_factor,
dtype,
bad_sharding,
......@@ -174,7 +185,9 @@ class TestDistributedSoftmax:
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
@pytest.mark.parametrize(
"softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED]
)
@pytest.mark.parametrize("bad_sharding", [False, True])
@pytest.mark.parametrize("broadcast_batch_mask", [False, True])
def test_softmax_gspmd(
......@@ -183,7 +196,7 @@ class TestDistributedSoftmax:
mesh_shape,
mesh_axes,
mesh_resource,
softmax_type,
softmax_fusion_type,
bad_sharding,
broadcast_batch_mask,
):
......@@ -193,7 +206,7 @@ class TestDistributedSoftmax:
mesh_axes,
mesh_resource,
data_shape=[32, 12, 128, 128],
softmax_type=softmax_type,
softmax_fusion_type=softmax_fusion_type,
scale_factor=1.0,
dtype=DTYPES[0],
bad_sharding=bad_sharding,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
......@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
AttnSoftmaxType,
QKVLayout,
QKVFormat,
reorder_causal_load_balancing,
......@@ -59,14 +60,16 @@ def init():
yield
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
softmax_offset: Optional[ArrayLike],
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
softmax_type: AttnSoftmaxType,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
......@@ -99,7 +102,25 @@ def general_dot_product_attention(
mask = jnp.expand_dims(mask, axis=-3)
logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
softmax_out = jax.nn.softmax(logits).astype(dtype)
match softmax_type:
case AttnSoftmaxType.VANILLA_SOFTMAX:
softmax_out = jax.nn.softmax(logits).astype(dtype)
case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
# Append a zero logit, apply standard softmax, then remove last column
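            # (exp(0) = 1, so the appended zero logit contributes exactly the
            # "+1" in the denominator)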
zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case AttnSoftmaxType.LEARNABLE_SOFTMAX:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
softmax_out = softmax_with_extra[..., :-1].astype(dtype)
case _:
raise NotImplementedError(f"Unknown {softmax_type=}")
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
......@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
return primitive_valid, primitive_invalid, reference_valid, reference_invalid
def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
......@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query,
key,
value,
softmax_offset,
bias,
mask,
deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"],
softmax_type=kwargs["softmax_type"],
dropout_rng=dropout_rng,
dtype=jnp.float32,
)
......@@ -262,6 +285,7 @@ def customcall_fused_dpa(
key,
value,
bias,
softmax_offset,
sequence_descriptor,
dropout_rng,
**kwargs,
......@@ -283,9 +307,9 @@ def customcall_fused_dpa(
qkv_args = (query, key, value)
case _:
raise ValueError(f"Unsupported {qkv_layout=}")
return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
query.dtype
)
return fused_attn(
qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
).astype(query.dtype)
class BiasShape(Enum):
......@@ -320,6 +344,7 @@ class FusedAttnRunner:
head_dim_v: int
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
softmax_type: AttnSoftmaxType
dropout_prob: float
dtype: DTypeLike
is_training: bool
......@@ -327,6 +352,8 @@ class FusedAttnRunner:
bias_shape: BiasShape
window_size: Tuple[int, int]
seq_desc_format: SeqDescFormat
stripe_size: int | None = None
num_segments_per_seq: int | None = None
# Specifies sharding resources for distributed tests
number_of_devices: int = 1
......@@ -341,6 +368,14 @@ class FusedAttnRunner:
# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None
def __post_init__(self):
# Reset defaults for num_segments_per_seq if not explicitly passed
if self.num_segments_per_seq is None:
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
else:
self.num_segments_per_seq = 1
        # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0
        # for a known issue with generating zero-length ragged tensors. This setting adjusts
        # the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
......@@ -402,6 +437,7 @@ class FusedAttnRunner:
self.qkv_layout,
self.attn_bias_type,
self.attn_mask_type,
self.softmax_type,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
......@@ -439,7 +475,7 @@ class FusedAttnRunner:
self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
......@@ -490,6 +526,13 @@ class FusedAttnRunner:
else:
pad_ratio = 0.0
if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
self.softmax_offset = jax.random.uniform(
softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
)
else:
self.softmax_offset = None
def gen_valid(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
......@@ -544,7 +587,6 @@ class FusedAttnRunner:
return segment_ids, segment_pos, segment_pad
if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
......@@ -570,7 +612,6 @@ class FusedAttnRunner:
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
self.batch_size, self.max_seqlen_q, pad_ratio
)
......@@ -602,12 +643,14 @@ class FusedAttnRunner:
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
self.cp_inverse_reorder_fn = partial(
inverse_reorder_causal_load_balancing,
strategy=reorder_strategy,
cp_size=self.cp_size,
seq_dim=seq_dim,
stripe_size=self.stripe_size,
)
else:
# no-ops for non cp or non load balanced
......@@ -625,14 +668,24 @@ class FusedAttnRunner:
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
                # Exercise the path that generates segment_pos inside
                # from_segment_ids_and_pos() when there is no CP load balancing;
                # otherwise pass segment_pos explicitly.
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
)
if self.cp_size > 1 and self.cp_load_balanced
else None
),
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=(
True if self.cp_size > 1 and self.cp_load_balanced else False
),
)
case _:
......@@ -661,6 +714,8 @@ class FusedAttnRunner:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=False,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
......@@ -713,6 +768,16 @@ class FusedAttnRunner:
self.bias_pspec = PartitionSpec()
self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource = (
self.mesh_resource.tpsp_resource
if self.mesh_resource.tpsp_resource is not None
else self.mesh_resource.tp_resource
)
self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
self.dropout_rng_pspec = PartitionSpec(
None,
)
......@@ -728,11 +793,11 @@ class FusedAttnRunner:
def test_forward(self):
"""
Test forward without JIT
Test forward with JITted primitive and unJITted reference
"""
self._setup_inputs()
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
# Put test data onto each GPU for distributed.
......@@ -742,12 +807,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -756,6 +823,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
customcall_fused_dpa_jit = jit(
......@@ -766,6 +834,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
],
......@@ -826,7 +895,7 @@ class FusedAttnRunner:
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)
args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
# THD params once we support those features on CP.
......@@ -834,12 +903,14 @@ class FusedAttnRunner:
jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
jax.device_put(self.bias, self.bias_sharding),
jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
]
kwargs = {
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"softmax_type": self.softmax_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
......@@ -848,6 +919,7 @@ class FusedAttnRunner:
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
"stripe_size": self.stripe_size,
}
# We can compute dBias only for the [1, h, s, s] layout
......@@ -866,8 +938,16 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
lambda q, k, v, bias, softmax_offset, *args: grad_func(
customcall_fused_dpa,
q,
k,
v,
bias,
softmax_offset,
*args,
cp_reverse_out=True,
**kwargs,
),
arg_nums,
),
......@@ -876,6 +956,7 @@ class FusedAttnRunner:
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.softmax_offset_sharding,
self.seq_desc_sharding,
self.dropout_rng_sharding,
),
......@@ -883,7 +964,9 @@ class FusedAttnRunner:
)
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
lambda q, k, v, bias, softmax_offset, *args: grad_func(
jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
),
arg_nums,
)
)
......@@ -977,41 +1060,78 @@ class FusedAttnRunner:
],
)
@pytest.mark.parametrize(
"qkv_layout",
"softmax_type",
[
pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + bf16 + qkv packed
pytest.param(
2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
2,
2048,
2048,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.BS3HD,
id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED",
),
pytest.param(
2,
512,
1024,
2048,
2048,
12,
12,
64,
64,
jnp.bfloat16,
id="2-512-1024-12-12-64-64-BF16-CROSS",
QKVLayout.T3HD,
id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED",
),
# mid data size + bf16 + cross attn + kv packed
pytest.param(
2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
2,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.BSHD_BS2HD,
id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED",
),
pytest.param(
4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
2,
512,
1024,
12,
12,
64,
64,
jnp.bfloat16,
QKVLayout.THD_T2HD,
id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
),
# large data size + bf16 + cross attn + diff hidden v dim + qkv separate
pytest.param(
4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
2,
2048,
1024,
12,
12,
64,
32,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE",
),
pytest.param(
2,
......@@ -1022,10 +1142,108 @@ class FusedAttnRunner:
64,
32,
jnp.bfloat16,
id="2-2048-1024-12-12-64-32-BF16-CROSS",
QKVLayout.THD_THD_THD,
id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE",
),
# large data size + bf16 + gqa + kv packed
pytest.param(
2,
2048,
2048,
12,
6,
64,
64,
jnp.bfloat16,
QKVLayout.BSHD_BS2HD,
id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED",
),
pytest.param(
2,
2048,
2048,
12,
6,
64,
64,
jnp.bfloat16,
QKVLayout.THD_T2HD,
id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED",
),
# small data size + fp16 + diff hidden v dim + qkv packed
pytest.param(
4,
128,
128,
16,
16,
64,
32,
jnp.float16,
QKVLayout.BS3HD,
id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED",
),
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
4,
128,
128,
16,
16,
64,
32,
jnp.float16,
QKVLayout.T3HD,
id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED",
),
# small data size + fp16 + kv packed
pytest.param(
4,
128,
128,
16,
16,
64,
64,
jnp.float16,
QKVLayout.BSHD_BS2HD,
id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED",
),
pytest.param(
4,
128,
128,
16,
16,
64,
64,
jnp.float16,
QKVLayout.THD_T2HD,
id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED",
),
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.float16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.float16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
......@@ -1084,6 +1302,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1110,6 +1329,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
......@@ -1138,6 +1358,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
......@@ -1161,6 +1382,7 @@ class TestFusedAttn:
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
True,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
......@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"
_KEY_OF_SOFTMAX_TYPE = "softmax_type"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -276,6 +277,14 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING: True,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias",
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE: "off_by_one",
},
    # attrs32
{
_KEY_OF_SOFTMAX_TYPE: "learnable",
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -418,6 +427,12 @@ class EncoderRunner(BaseRunner):
"attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/query/scale": "pre_attention_layer_norm/scale",
"attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"attention/DotProductAttention_0/softmax_offset"
),
"attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
"attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......@@ -463,10 +478,22 @@ class DecoderRunner(BaseRunner):
"encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"encoder_decoder_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
"self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/query/scale": "pre_self_attention_layer_norm/scale",
"self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset": (
"self_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset": (
"self_attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
......@@ -534,7 +561,7 @@ class BaseTester:
"""Test forward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
self.runner(attrs).test_forward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
......@@ -542,7 +569,7 @@ class BaseTester:
"""Test backward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()):
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
self.runner(attrs).test_backward(data_shape, dtype)
class TestEncoderLayer(BaseTester):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for permutation Triton kernels and high-level APIs"""
import functools
import jax
import jax.numpy as jnp
import pytest
# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
sort_chunks_by_index,
)
from utils import assert_allclose, pytest_parametrize_wrapper
ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
(4096, 32, 1280, 2),
(4096, 256, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:2],
"L2": ALL_DISPATCH_COMBINE_CASES,
}
ALL_SORT_CHUNKS_CASES = [
(8, 4096, 1280),
(64, 4096, 4096),
(256, 4096, 9216),
]
SORT_CHUNKS_CASES = {
"L0": ALL_SORT_CHUNKS_CASES[0:2],
"L2": ALL_SORT_CHUNKS_CASES,
}
ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
(4096, 32, 1280, 2, 128),
(4096, 256, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
"L0": ALL_DTYPES,
"L2": ALL_DTYPES,
}
ALL_WITH_PROBS = [True, False]
WITH_PROBS = {
"L0": [True],
"L2": ALL_WITH_PROBS,
}
def reference_make_row_id_map(
routing_map: jnp.ndarray,
) -> jnp.ndarray:
"""
Vectorized reference implementation of make_row_id_map using JAX primitives.
Parameters
----------
routing_map : jnp.ndarray
Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts
are routed to which tokens (1 = routed, 0 = not routed).
Returns
-------
row_id_map : jnp.ndarray
The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1].
"""
num_tokens, num_experts = routing_map.shape
# For each expert, compute cumulative sum to get destination indices
cumsum_per_expert = jnp.cumsum(routing_map, axis=0)
# Compute total tokens per expert and expert offsets
tokens_per_expert = jnp.sum(routing_map, axis=0)
expert_offsets = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
)
# Compute destination rows for all (token, expert) pairs
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * (
1 - routing_map
)
# Count routed experts per token
n_routed_per_token = jnp.sum(routing_map, axis=1)
    # For each token, sort by descending dest_row and pack into row_id_map.
    # Non-routed experts get a sort key of INT32_MAX so they land at the end of
    # the ascending argsort (routed experts use -dest_row as the key).
sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max)
sorted_expert_indices = jnp.argsort(sort_keys, axis=1)
# Gather the sorted destination rows and expert indices using advanced indexing
# Create indices for gathering
token_idx = jnp.broadcast_to(
jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
)
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
row_id_map = jnp.concatenate(
[
sorted_dest_rows.astype(jnp.int32),
sorted_expert_indices.astype(jnp.int32),
n_routed_per_token.astype(jnp.int32)[:, None],
],
axis=1,
)
return row_id_map
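# Worked example (a hand-checked sketch, not executed): for
#   routing_map = [[1, 0], [1, 1], [0, 1]]   (3 tokens, 2 experts)
# tokens_per_expert = [2, 2] and expert_offsets = [0, 2], so the per-token rows
# of row_id_map, laid out as [dest_rows | expert_ids | n_routed], come out as
#   token 0: [0, -1, 0, 1, 1]
#   token 1: [2,  1, 1, 0, 2]
#   token 2: [3, -1, 1, 0, 1]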
def _reference_permute_impl(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: jnp.ndarray,
num_out_tokens: int,
) -> tuple:
"""
Vectorized internal helper for reference permutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
probs : jnp.ndarray
The probabilities of the input tensor.
num_out_tokens : int
Number of tokens in the permuted tensor.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
permuted_probs : jnp.ndarray
Permuted probabilities if probs was provided, None otherwise.
"""
num_tokens, hidden_size = inp.shape
num_experts = (row_id_map.shape[1] - 1) // 2
# Extract destination rows, expert indices, and n_routed from row_id_map
dest_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts]
expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts]
n_routed = row_id_map[:, 2 * num_experts] # [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
# (slots beyond n_routed may contain garbage, not -1)
slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts]
valid_mask = slot_indices < n_routed[:, None] # [num_tokens, num_experts]
# Flatten for scatter operations
flat_dest_rows = dest_rows.flatten() # [num_tokens * num_experts]
flat_valid_mask = valid_mask.flatten()
flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts)
flat_expert_indices = expert_indices.flatten()
# Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
# This avoids overwriting valid entries at index 0 with zeros
flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens)
# Gather input tokens and scatter to output
output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
gathered_inp = inp[flat_token_indices] # [num_tokens * num_experts, hidden_size]
# Use segment_sum-like operation via scatter
# For each valid (token, expert) pair, write inp[token] to output[dest_row]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
output = output.at[flat_dest_rows_clamped].set(
gathered_inp,
mode="drop",
)
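    # JAX scatter semantics: with mode="drop", out-of-bounds update indices are
    # silently discarded, e.g.
    #   jnp.zeros(4).at[jnp.array([0, 4])].set(jnp.array([1.0, 9.0]), mode="drop")
    # -> [1., 0., 0., 0.]  (index 4 is one past the end, so that update is dropped)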
permuted_probs = None
if probs is not None:
permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype)
# Vectorized approach: gather probs and scatter to permuted_probs
if probs.ndim == 1:
flat_probs = probs[flat_token_indices]
else:
# Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
# The result for invalid entries will be ignored anyway since they target num_out_tokens
# Cast to int32 explicitly for consistent indexing behavior
flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype(
jnp.int32
)
flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set(
flat_probs,
mode="drop",
)
return output, permuted_probs
def _reference_unpermute_impl(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: jnp.ndarray,
permuted_probs: jnp.ndarray,
) -> tuple:
"""
Vectorized internal helper for reference unpermutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_out_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
merging_probs : jnp.ndarray
The merging probabilities for weighted reduction.
permuted_probs : jnp.ndarray
The permuted probabilities.
Returns
-------
output : jnp.ndarray
Unpermuted output tensor of shape [num_tokens, hidden_size].
unpermuted_probs : jnp.ndarray
Unpermuted probabilities if permuted_probs was provided, None otherwise.
"""
num_tokens = row_id_map.shape[0]
num_experts = (row_id_map.shape[1] - 1) // 2
# Extract source rows, expert indices, and n_routed from row_id_map
src_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts]
expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts]
n_routed = row_id_map[:, 2 * num_experts] # [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts]
valid_mask = slot_indices < n_routed[:, None] # [num_tokens, num_experts]
# Clamp invalid src_rows to 0 (they won't be used due to masking)
src_rows_clamped = jnp.where(valid_mask, src_rows, 0)
# Gather input from permuted positions
gathered_inp = inp[src_rows_clamped] # [num_tokens, num_experts, hidden_size]
# Apply merging probs if provided
if merging_probs is not None:
# Gather the merging weights for each (token, expert) pair using advanced indexing
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
weights = merging_probs[token_idx, expert_indices] # [num_tokens, num_experts]
gathered_inp = gathered_inp * weights[:, :, None]
# Mask out invalid entries and sum across experts
gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0)
output = jnp.sum(gathered_inp, axis=1) # [num_tokens, hidden_size]
unpermuted_probs = None
if permuted_probs is not None:
gathered_probs = permuted_probs[src_rows_clamped] # [num_tokens, num_experts]
unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set(
jnp.where(valid_mask, gathered_probs, 0.0)
)
return output, unpermuted_probs
def reference_token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
num_out_tokens: int,
probs: jnp.ndarray = None,
) -> tuple:
"""
Reference implementation of token_dispatch using JAX primitives.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_tokens, hidden_size].
routing_map : jnp.ndarray
Routing mask of shape [num_tokens, num_experts].
num_out_tokens : int
Number of tokens in the permuted tensor.
probs : jnp.ndarray, optional
The probabilities of shape [num_tokens, num_experts].
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
permuted_probs : jnp.ndarray or None
Permuted probabilities of shape [num_out_tokens], or None if probs not provided.
row_id_map : jnp.ndarray
The row_id_map for the permutation.
"""
row_id_map = reference_make_row_id_map(routing_map)
output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens)
return output, permuted_probs, row_id_map
def reference_token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: jnp.ndarray,
) -> jnp.ndarray:
"""
Reference implementation of token_combine using JAX primitives.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_out_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
merging_probs : jnp.ndarray
The merging probabilities for weighted reduction.
Returns
-------
output : jnp.ndarray
Unpermuted output tensor of shape [num_tokens, hidden_size].
"""
output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None)
return output
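# A minimal, self-contained usage sketch of the two reference APIs above
# (illustrative only; not collected or invoked by the test suite):
def _reference_roundtrip_example():
    """Dispatch then combine with uniform merging probs recovers the input."""
    routing_map = jnp.array([[1, 0, 1], [0, 1, 0]], dtype=jnp.int32)
    inp = jnp.arange(2 * 4, dtype=jnp.float32).reshape(2, 4)
    num_out_tokens = int(jnp.sum(routing_map))
    permuted, _, row_id_map = reference_token_dispatch(inp, routing_map, num_out_tokens)
    # Uniform weights over each token's routed experts turn the combine into an
    # average of identical copies, so the roundtrip reproduces `inp`.
    merging_probs = routing_map.astype(jnp.float32) / jnp.maximum(
        jnp.sum(routing_map, axis=1, keepdims=True), 1.0
    )
    return reference_token_combine(permuted, row_id_map, merging_probs)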
def reference_make_chunk_sort_map(
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
num_tokens: int,
) -> jnp.ndarray:
"""
Vectorized reference implementation of make_chunk_sort_map using JAX primitives.
Parameters
----------
split_sizes : jnp.ndarray
The sizes of the chunks of shape [num_splits,].
sorted_indices : jnp.ndarray
The indices of the sorted chunks of shape [num_splits,].
num_tokens : int
Number of tokens.
Returns
-------
row_id_map : jnp.ndarray
Row ID map for chunk sorting of shape [num_tokens,].
"""
# Compute source chunk boundaries (cumulative sum of original split_sizes)
src_cumsum = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
)
# Compute destination chunk boundaries based on sorted order
sorted_sizes = split_sizes[sorted_indices]
dest_cumsum = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
)
# For each source chunk, compute its destination offset
# inverse_indices[i] = position of chunk i in sorted order
inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32)
dest_offsets = dest_cumsum[inverse_indices]
# Create row_id_map: for each token position, compute its destination
# First, figure out which chunk each position belongs to
position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
# chunk_ids[i] = which chunk position i belongs to
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32)
# within_chunk_offset[i] = position i's offset within its chunk
within_chunk_offset = position_indices - src_cumsum[chunk_ids]
# destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i]
row_id_map = dest_offsets[chunk_ids] + within_chunk_offset
return row_id_map.astype(jnp.int32)
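# Worked example (illustrative): split_sizes = [2, 3], sorted_indices = [1, 0]
# gives sorted_sizes = [3, 2], dest_cumsum = [0, 3, 5], inverse_indices = [1, 0]
# and dest_offsets = [3, 0], so row_id_map = [3, 4, 0, 1, 2]: the first chunk's
# two tokens move to rows 3-4 and the second chunk's three tokens to rows 0-2.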
def reference_sort_chunks_by_map(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: jnp.ndarray,
is_forward: bool,
) -> tuple:
"""
Vectorized reference implementation of sort_chunks_by_map using JAX primitives.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to destination mapping of shape [num_tokens,].
probs : jnp.ndarray
The probabilities.
is_forward : bool
Whether this is forward or backward.
Returns
-------
output : jnp.ndarray
Sorted output tensor of shape [num_tokens, hidden_size].
permuted_probs : jnp.ndarray
Sorted probabilities if probs was provided, None otherwise.
"""
num_tokens = inp.shape[0]
hidden_size = inp.shape[1]
if is_forward:
# Forward: scatter inp[src] to output[dest] where dest = row_id_map[src]
output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
output = output.at[row_id_map].set(inp)
if probs is not None:
permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype)
permuted_probs = permuted_probs.at[row_id_map].set(probs)
else:
permuted_probs = None
else:
# Backward: gather output[dest] = inp[src] where src = row_id_map[dest]
output = inp[row_id_map]
if probs is not None:
permuted_probs = probs[row_id_map]
else:
permuted_probs = None
return output, permuted_probs
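# Note (illustrative): for a fixed row_id_map the forward scatter and the
# backward gather are exact inverses:
#   fwd, _ = reference_sort_chunks_by_map(x, m, None, is_forward=True)
#   out, _ = reference_sort_chunks_by_map(fwd, m, None, is_forward=False)
# recovers x, mirroring how the backward pass undoes the forward permutation.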
class TestHighLevelPermutationAPI:
"""Test high-level permutation APIs (token_dispatch, token_combine, etc.)
These tests compare the high-level APIs against reference implementations
to verify correctness of both forward and backward passes.
"""
@staticmethod
def generate_routing_map(
num_tokens: int,
num_experts: int,
tokens_per_expert: int = 2,
key: jax.Array = None,
):
"""Generate random routing map for testing"""
if key is None:
key = jax.random.PRNGKey(0)
routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
for token_idx in range(num_tokens):
key, subkey = jax.random.split(key)
expert_indices = jax.random.choice(
subkey, num_experts, shape=(tokens_per_expert,), replace=False
)
routing_map = routing_map.at[token_idx, expert_indices].set(1)
return routing_map
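# Example (illustrative; the exact pattern depends on the PRNG key):
# generate_routing_map(4, 3, tokens_per_expert=2) yields a 4x3 0/1 matrix with
# exactly two ones per row, e.g. [[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 1, 1]].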
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,tokens_per_expert",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_token_dispatch(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
):
"""
Individual test for token_dispatch forward and backward passes.
This test validates dispatch in isolation to catch errors that might be
masked when combined with token_combine in the roundtrip test.
Uses value_and_grad to validate both forward (via loss comparison) and
backward (via gradient comparison) passes against reference implementation.
"""
key = jax.random.PRNGKey(42)
# Generate routing map
routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
num_out_tokens = int(jnp.sum(routing_map))
# Generate input data
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
# Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
probs = None
if with_probs:
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
# Generate reference row_id_map for comparison
ref_row_id_map = reference_make_row_id_map(routing_map)
# =====================================================================
# Test forward and backward pass using value_and_grad
# (value validates forward, grad validates backward)
# =====================================================================
if with_probs:
@jax.jit
def dispatch_loss(x, p):
out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
@jax.jit
def ref_dispatch_loss(x, p):
out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
inp, probs
)
ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
ref_dispatch_loss, argnums=(0, 1)
)(inp, probs)
# Validate forward loss matches
assert_allclose(loss_val, ref_loss_val, dtype=dtype)
# Validate gradients
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
else:
@jax.jit
def dispatch_loss_no_probs(x):
out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(out**2)
@jax.jit
def ref_dispatch_loss_no_probs(x):
out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens)
return jnp.sum(out**2)
loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
# Validate forward loss matches
assert_allclose(loss_val, ref_loss_val, dtype=dtype)
# Validate gradients
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
# =========================================================================
# Consolidated dispatch + combine tests
# =========================================================================
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,tokens_per_expert",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_dispatch_and_combine(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
):
"""
Comprehensive test for token_dispatch and token_combine.
Tests:
1. Dispatch forward pass against reference (element-by-element)
2. Dispatch backward pass against reference
3. Combine forward pass against reference (element-by-element)
4. Combine backward pass against reference
5. Roundtrip: dispatch + combine recovers original input
6. row_id_map n_routed column validation
7. Probs permutation (when with_probs=True)
"""
key = jax.random.PRNGKey(42)
# Generate routing map
routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
num_out_tokens = int(jnp.sum(routing_map))
# Generate input data
key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
# Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
probs = None
if with_probs:
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
# Generate merging probs (normalized per token)
merging_probs = jax.random.uniform(
merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
merging_probs = merging_probs * routing_map.astype(dtype) # Zero out non-routed
merging_probs = merging_probs / jnp.maximum(
jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
)
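# Illustrative: a token routed to experts {1, 3} with raw weights 0.2 and 0.6
# ends up with merging_probs of 0.25 and 0.75 on those experts, 0 elsewhere.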
# =====================================================================
# Test 1: Dispatch forward pass
# =====================================================================
output, permuted_probs, row_id_map, _, _ = token_dispatch(
inp, routing_map, num_out_tokens, probs=probs
)
ref_output, ref_permuted_probs = _reference_permute_impl(
inp, row_id_map, probs, num_out_tokens
)
# Validate row_id_map structure: n_routed column should match routing_map sum
n_routed_actual = row_id_map[:, -1]
n_routed_expected = jnp.sum(routing_map, axis=1)
assert jnp.array_equal(
n_routed_actual, n_routed_expected
), "make_row_id_map n_routed column mismatch"
# Compare dispatch output
assert_allclose(output, ref_output, dtype=dtype)
if with_probs:
assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)
# =====================================================================
# Test 2: Dispatch backward pass
# =====================================================================
if with_probs:
@jax.jit
def dispatch_loss(x, p):
out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
@jax.jit
def ref_dispatch_loss(x, p):
out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
_, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
inp, probs
)
_, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
ref_dispatch_loss, argnums=(0, 1)
)(inp, probs)
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
else:
@jax.jit
def dispatch_loss_no_probs(x):
out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(out**2)
@jax.jit
def ref_dispatch_loss_no_probs(x):
out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
return jnp.sum(out**2)
_, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
_, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
# =====================================================================
# Test 3: Combine forward pass
# =====================================================================
combined = token_combine(output, row_id_map, merging_probs)
ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
assert_allclose(combined, ref_combined, dtype=dtype)
# =====================================================================
# Test 4: Combine backward pass
# =====================================================================
@jax.jit
def combine_loss(x):
return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)
@jax.jit
def ref_combine_loss(x):
return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)
_, combine_grad = jax.value_and_grad(combine_loss)(output)
_, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)
# =====================================================================
# Test 5: Roundtrip (dispatch + combine = original)
# =====================================================================
# Use uniform merging probs for perfect roundtrip
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
@jax.jit
def roundtrip(x):
dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return token_combine(dispatched, rid_map, uniform_merging_probs)
roundtrip_output = roundtrip(inp)
assert_allclose(roundtrip_output, inp, dtype=dtype)
# =========================================================================
# sort_chunks_by_index tests
# =========================================================================
@pytest_parametrize_wrapper(
"num_splits,total_tokens,hidden_size",
SORT_CHUNKS_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
"""Test sort_chunks_by_index forward and backward pass against reference"""
key = jax.random.PRNGKey(42)
# Generate random split sizes
key, size_key = jax.random.split(key)
split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits)
split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1]))
# Generate sorted indices
key, sort_key = jax.random.split(key)
sorted_indices = jax.random.permutation(sort_key, num_splits)
# Generate input data
key, inp_key = jax.random.split(key)
inp = jax.random.uniform(
inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
# Get reference row_id_map
row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)
# Define loss functions (JIT compiled for performance)
@jax.jit
def loss_fn(x):
output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
return jnp.sum(output**2)
@jax.jit
def ref_loss_fn(x):
output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
return jnp.sum(output**2)
# Test forward pass
output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)
# Test backward pass with JIT
loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
# Compare forward and backward
assert_allclose(output, ref_output)
assert_allclose(loss_val, ref_loss_val)
assert_allclose(computed_grad, ref_grad)
# =========================================================================
# Consolidated dispatch + combine with padding tests
# =========================================================================
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk,align_size",
DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_dispatch_and_combine_with_padding(
self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
):
"""
Comprehensive test for token_dispatch and token_combine with padding/unpadding.
Tests:
1. Dispatch with padding: output shape and alignment
2. Dispatch backward pass with padding
3. Combine with unpad: output shape
4. Combine backward pass with unpad
5. Roundtrip with padding: dispatch + combine recovers original
6. Probs permutation with padding (when with_probs=True)
"""
key = jax.random.PRNGKey(42)
# Generate routing map
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
num_out_tokens = int(jnp.sum(routing_map))
# Compute worst-case padded size
worst_case_size = (
(num_out_tokens + num_experts * (align_size - 1)) // align_size
) * align_size
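# Illustrative arithmetic: num_out_tokens = 10, num_experts = 4, align_size = 8
# gives ((10 + 4 * 7) // 8) * 8 = 32 padded rows, covering the worst case of
# align_size - 1 padding rows per expert.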
# Generate input data
key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
# Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
probs = None
if with_probs:
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
# Generate merging probs (normalized per token)
merging_probs = jax.random.uniform(
merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
merging_probs = merging_probs * routing_map.astype(dtype) # Zero out non-routed
merging_probs = merging_probs / jnp.maximum(
jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
)
# =====================================================================
# Test 1: Dispatch with padding - forward pass
# =====================================================================
output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
)
# Check output shape
assert output.shape == (worst_case_size, hidden_size)
if with_probs:
assert permuted_probs is not None
assert permuted_probs.shape == (worst_case_size,)
else:
assert permuted_probs is None
# Check alignment: each expert's tokens should be aligned
for expert_idx in range(num_experts):
expert_tokens = int(target_tokens_per_expert[expert_idx])
assert expert_tokens % align_size == 0 or expert_tokens == 0
# =====================================================================
# Test 2: Dispatch with padding - backward pass
# =====================================================================
if with_probs:
@jax.jit
def dispatch_loss(x, p):
out, perm_probs, _, _, _ = token_dispatch(
x, routing_map, num_out_tokens, probs=p, align_size=align_size
)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs)
assert inp_grad.shape == inp.shape
assert probs_grad.shape == probs.shape
assert not jnp.any(jnp.isnan(inp_grad))
assert not jnp.any(jnp.isnan(probs_grad))
else:
@jax.jit
def dispatch_loss_no_probs(x):
out, _, _, _, _ = token_dispatch(
x, routing_map, num_out_tokens, align_size=align_size
)
return jnp.sum(out**2)
inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
assert inp_grad.shape == inp.shape
assert not jnp.any(jnp.isnan(inp_grad))
# =====================================================================
# Test 3: Combine with unpad - forward pass
# =====================================================================
combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
assert combined.shape == (num_tokens, hidden_size)
# =====================================================================
# Test 4: Combine with unpad - backward pass
# =====================================================================
@jax.jit
def combine_loss(x):
return jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2)
combine_grad = jax.grad(combine_loss)(output)
assert combine_grad.shape == output.shape
assert not jnp.any(jnp.isnan(combine_grad))
# =====================================================================
# Test 5: Roundtrip with padding (dispatch + combine = original)
# =====================================================================
# Use uniform merging probs for perfect roundtrip
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
@jax.jit
def roundtrip(x):
dispatched, _, rid_map, p_offsets, _ = token_dispatch(
x, routing_map, num_out_tokens, align_size=align_size
)
return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)
roundtrip_output = roundtrip(inp)
assert_allclose(roundtrip_output, inp, dtype=dtype)
# Test roundtrip gradient
@jax.jit
def roundtrip_loss(x):
return jnp.sum(roundtrip(x) ** 2)
roundtrip_grad = jax.grad(roundtrip_loss)(inp)
assert roundtrip_grad.shape == inp.shape
assert not jnp.any(jnp.isnan(roundtrip_grad))
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
......@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.cpp_extensions.attention import AttnSoftmaxType
from transformer_engine.jax.softmax import SoftmaxFusionType, softmax
from transformer_engine.jax.flax.module import Softmax
......@@ -50,8 +51,9 @@ class SoftmaxRunner:
max_seqlen_kv: int
num_heads: int
scale_factor: float
softmax_type: SoftmaxType
softmax_fusion_type: SoftmaxFusionType
dtype: DTypeLike
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@staticmethod
def reference_softmax(logits, mask, scale_factor, **_):
......@@ -68,6 +70,7 @@ class SoftmaxRunner:
def _is_support(self):
return is_softmax_kernel_available(
self.softmax_fusion_type,
self.softmax_type,
self.batch_size,
self.num_heads,
......@@ -85,22 +88,22 @@ class SoftmaxRunner:
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type:
case SoftmaxType.SCALED:
match self.softmax_fusion_type:
case SoftmaxFusionType.SCALED:
self.mask = None
case SoftmaxType.SCALED_MASKED:
case SoftmaxFusionType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
case SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
raise ValueError(f"Unknown {self.softmax_fusion_type=}")
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_fusion_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
......@@ -117,7 +120,7 @@ class SoftmaxRunner:
args = [self.logits, self.mask]
kwargs = {
"scale_factor": self.scale_factor,
"softmax_type": self.softmax_type,
"softmax_fusion_type": self.softmax_fusion_type,
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
......@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
rng = jax.random.PRNGKey(0)
softmax_module = Softmax(
scale_factor=runner.scale_factor,
softmax_type=runner.softmax_type,
softmax_fusion_type=runner.softmax_fusion_type,
)
softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
......@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test backward with parameterized configs
"""
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
runner.test_backward()
......@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
"softmax_fusion_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
pytest.param(SoftmaxFusionType.SCALED, id="SCALED"),
pytest.param(SoftmaxFusionType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
......@@ -263,11 +266,11 @@ class TestSoftmaxModule:
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype):
"""
Test forward with parameterized configs
"""
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_fusion_type, dtype)
bias = None
runner = SoftmaxModuleRunner(module_runner, bias)
runner.test_forward()
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for Triton-based custom calls in TE JAX."""
import jax
import jax.numpy as jnp
import pytest
from utils import assert_allclose, pytest_parametrize_wrapper
import triton
import triton.language as tl
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
from transformer_engine.jax.triton_extensions import triton_call_lowering
@pytest.fixture(autouse=True, scope="module")
def init():
"""WAR for CUDA uninitialize error"""
_ = jnp.zeros(0)
yield
class TestTritonBinding:
"""Test Triton binding primitive."""
# Define autotuned Triton kernel
@staticmethod
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 256}), # Uses defaults: num_warps=4, num_stages=3
triton.Config({"BLOCK_SIZE": 512}, num_warps=8), # Custom num_warps
],
key=["n_elements"], # Autotune based on input size
)
@triton.jit
def amax_kernel(
x_ptr,
amax_ptr,
n_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Compute amax using Triton with autotuning."""
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
abs_x = tl.abs(x)
block_max = tl.max(abs_x)
tl.atomic_max(amax_ptr, block_max)
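# A minimal direct-launch sketch for the autotuned kernel above (illustrative
# only, outside the JAX binding; assumes torch and a CUDA device):
#   x = torch.randn(4096, device="cuda")
#   amax = torch.zeros(1, device="cuda")
#   grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
#   TestTritonBinding.amax_kernel[grid](x, amax, n_elements=x.numel())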
# Define test primitive
class AmaxTritonPrimitive(BasePrimitive):
"""Test primitive using Triton kernel."""
name = "te_amax_triton_test"
multiple_results = False
impl_static_args = ()
@staticmethod
def abstract(x_aval):
return jax.core.ShapedArray((1,), jnp.float32)
@staticmethod
def impl(x):
assert TestTritonBinding.AmaxTritonPrimitive.inner_primitive is not None
return TestTritonBinding.AmaxTritonPrimitive.inner_primitive.bind(x)
@staticmethod
def lowering(ctx, x):
"""MLIR lowering using Triton kernel."""
n_elements = 1
for dim in ctx.avals_in[0].shape:
n_elements *= dim
# For autotuned kernels, use the minimum BLOCK_SIZE from configs
# to ensure all elements are processed by all configs
block_size = min(
config.kwargs.get("BLOCK_SIZE") for config in TestTritonBinding.amax_kernel.configs
)
grid = (triton.cdiv(n_elements, block_size),)
return triton_call_lowering(
ctx,
TestTritonBinding.amax_kernel, # Autotuned kernel
x,
grid=grid,
constexprs={"n_elements": n_elements},
# BLOCK_SIZE comes from autotuner config, not passed here
)
register_primitive(AmaxTritonPrimitive)
@staticmethod
def _triton_amax(x: jnp.ndarray) -> jnp.ndarray:
"""Compute amax using Triton kernel."""
return TestTritonBinding.AmaxTritonPrimitive.outer_primitive.bind(x)
@pytest_parametrize_wrapper("shape", [(1024, 1024)])
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16])
def test_triton_amax(self, shape, dtype):
"""Test Triton amax with JIT."""
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, dtype)
expected = jnp.max(jnp.abs(x), keepdims=False).astype(jnp.float32)
jitted_amax = jax.jit(self._triton_amax)
result = jitted_amax(x)
assert_allclose(result, expected, dtype=jnp.float32)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""
......@@ -21,6 +21,7 @@ from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnSoftmaxType,
canonicalize_attn_mask_type,
make_swa_mask,
)
......@@ -46,6 +47,13 @@ def is_devices_enough(required):
return len(jax.devices()) >= required
def is_devices_equal(required):
"""
Check whether the number of available GPUs is exactly equal to `required`
"""
return len(jax.devices()) == required
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
drop_path_shape = list(range(0, len(shape)))
......@@ -162,6 +170,7 @@ class DotProductAttention(nn.Module):
dropout_rate: float = 0.0
dtype: DType = jnp.float32
float32_logits: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
......@@ -211,6 +220,24 @@ class DotProductAttention(nn.Module):
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads = query.shape[-2]
# Initialize softmax_offset for off-by-one or learnable softmax
softmax_offset = None
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
# For off-by-one softmax, use zeros with shape (1, h, 1, 1)
softmax_offset = jnp.zeros((1, num_attention_heads, 1, 1), dtype=input_dtype)
elif self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
# For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
softmax_offset = self.param(
"softmax_offset",
nn.initializers.zeros,
(1, num_attention_heads, 1, 1),
jnp.float32,
)
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
......@@ -241,9 +268,23 @@ class DotProductAttention(nn.Module):
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Add attention sink to the last column if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
# Add extra column with softmax_offset
# softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
extra_col = jnp.broadcast_to(
softmax_offset,
(attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2], 1),
)
attn_weights = jnp.concatenate([attn_weights, extra_col], axis=-1)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Remove the extra column after softmax if not vanilla softmax
if self.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
attn_weights = attn_weights[..., :-1]
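# Numeric sketch of the attention-sink trick above: appending a column of
# offsets z before softmax yields p_i = exp(x_i) / (sum_j exp(x_j) + exp(z)),
# so after dropping the extra column the weights sum to less than one
# (the classic "off-by-one" softmax when z = 0).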
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
......@@ -535,6 +576,7 @@ class MultiHeadAttention(nn.Module):
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
def __post_init__(self):
if self.kernel_init is None:
......@@ -801,6 +843,7 @@ class MultiHeadAttention(nn.Module):
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits,
softmax_type=self.softmax_type,
)(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
......@@ -1058,6 +1101,7 @@ class EncoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1111,6 +1155,9 @@ class EncoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1126,6 +1173,7 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1222,6 +1270,7 @@ class DecoderLayer(nn.Module):
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
softmax_type: str = "vanilla"
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1290,6 +1339,9 @@ class DecoderLayer(nn.Module):
else:
x = inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type = AttnSoftmaxType.from_str(self.softmax_type)
# Self-attention block
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
......@@ -1305,6 +1357,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......@@ -1343,6 +1396,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
softmax_type=attn_softmax_type,
name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -89,40 +89,47 @@ def generate_input_shapes(
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
# Since FlashAttention doesn't support padding between sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
total_tokens = cu_seqlens_q_padded[-1]
q_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads,
config.head_dim_qk,
)
k_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_qk,
)
v_input_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_gqa_groups,
config.head_dim_v,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
total_tokens,
config.num_heads * config.head_dim_v,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
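# Illustrative: with world_size = 2 each length is rounded up to a multiple
# of 2 * world_size = 4, e.g. seqlens_q = [3, 5] -> seqlens_q_padded = [4, 8]
# and cu_seqlens_q_padded = [0, 4, 12, total_tokens].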
else:
assert False, f"{qkv_format=} is not supported!"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import logging
......@@ -119,7 +119,14 @@ model_configs_base = {
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
dtype,
model_configs,
model,
ckpt_attn,
workspace_opt,
qkv_layout,
swa,
pad_between_seqs,
):
"""Test DotProductAttention module"""
......@@ -310,6 +317,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
model_configs_num_splits = {
# test: ModelConfig(b, sq, hq, dqk)
"num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
"num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
@pytest.mark.parametrize("model", model_configs_num_splits.keys())
def test_dpa_num_splits(dtype, model_configs, model):
"""Test DotProductAttention with FlashAttention-3 num_splits enabled"""
test_dot_product_attention(
dtype,
model_configs,
model,
False,
True,
None,
False,
False,
)
model_configs_softmax = {
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
......@@ -1155,6 +1187,8 @@ def _run_dot_product_attention(
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
# Only pass num_splits when exercising the FlashAttention path
num_splits=config.num_splits if backend == "FlashAttention" else 1,
)
max_logit = None
if config.return_max_logit:
......@@ -1789,9 +1823,10 @@ def test_mha_fp8_vs_f16(
fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
pytest.skip("No FP8 attention backend available.")
fused_attn_supported_f16 = False
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
......@@ -1799,8 +1834,8 @@ def test_mha_fp8_vs_f16(
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
pytest.skip("No attention backend available.")
if flash_attn_supported:
......@@ -1812,23 +1847,28 @@ def test_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
if flash_attn_supported:
if flash_attn_supported and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
......@@ -1841,32 +1881,33 @@ def test_mha_fp8_vs_f16(
rmse_tol,
True,
)
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
True,
)
if fused_attn_supported_fp8 and fused_attn_supported_f16:
logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
logging.debug("========== {:^25s} ==========".format("forward output"))
compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
True,
)
def _run_mha_fp8_vs_f16(
......@@ -2492,7 +2533,6 @@ class _custom_mha_fp8(torch.autograd.Function):
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
mask_type: str,
quantizers: list[Quantizer],
......@@ -2521,7 +2561,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv, *_ = ext.general_gemm(
qkv_weight_fp8,
inp_fp8,
workspace,
bias=qkv_bias,
out_dtype=qkv_weight_fp8.dtype,
quantization_params=qkv_quantizer,
......@@ -2563,9 +2602,7 @@ class _custom_mha_fp8(torch.autograd.Function):
s_quantizer=s_quantizer,
)
tensors_to_save, tensor_objects = prepare_for_saving(
q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
)
tensors_to_save, tensor_objects = prepare_for_saving(q, k, v, inp_fp8, qkv_weight_fp8, out)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
......@@ -2595,7 +2632,7 @@ class _custom_mha_fp8(torch.autograd.Function):
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
(q, k, v, inp_fp8, qkv_weight_fp8, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors
)
......@@ -2651,7 +2688,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_dgrad, *_ = ext.general_gemm(
qkv_weight_fp8,
dqkv_c,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_DGRAD,
layout="NN",
......@@ -2661,7 +2697,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_wgrad, *_ = ext.general_gemm(
inp_fp8,
dqkv,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_WGRAD,
layout="NT",
......@@ -2712,9 +2747,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self,
......@@ -2733,7 +2765,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training,
self.mask_type,
self.quantizers,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -7,7 +7,7 @@ import subprocess
import sys
import pathlib
import logging
import copy
import pytest
import torch
from transformer_engine.pytorch import (
......@@ -74,7 +74,7 @@ dtypes = ["bf16", "fp16"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
configs = ["cp_1_0", "cp_1_2", "cp_2_1", "cp_3_2", "cp_3_3"]
model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
dtypes = ["bf16"]
qkv_formats = ["sbhd", "thd"]
......@@ -97,12 +97,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
......@@ -184,7 +188,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......@@ -225,10 +229,14 @@ def test_cp_with_fused_attention(
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
......@@ -282,6 +290,14 @@ def test_cp_with_fused_attention(
)
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
if qkv_format == "thd":
config = copy.deepcopy(config)
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
fp8_meta = {}
fp8_meta["recipe"] = None
fp8_meta["local_recipes"] = []
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......