Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2188 additions
and
177 deletions
+2188
-177
tests/jax/test_distributed_layernorm.py
tests/jax/test_distributed_layernorm.py
+1
-1
tests/jax/test_distributed_layernorm_mlp.py
tests/jax/test_distributed_layernorm_mlp.py
+1
-1
tests/jax/test_distributed_permutation.py
tests/jax/test_distributed_permutation.py
+597
-0
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+30
-17
tests/jax/test_functions.py
tests/jax/test_functions.py
+1
-1
tests/jax/test_fused_attn.py
tests/jax/test_fused_attn.py
+257
-35
tests/jax/test_layer.py
tests/jax/test_layer.py
+30
-3
tests/jax/test_misc.py
tests/jax/test_misc.py
+1
-1
tests/jax/test_multi_process_distributed_grouped_gemm.py
tests/jax/test_multi_process_distributed_grouped_gemm.py
+1
-1
tests/jax/test_permutation.py
tests/jax/test_permutation.py
+926
-0
tests/jax/test_recipe_characteristics.py
tests/jax/test_recipe_characteristics.py
+1
-1
tests/jax/test_sanity_import.py
tests/jax/test_sanity_import.py
+1
-1
tests/jax/test_softmax.py
tests/jax/test_softmax.py
+28
-25
tests/jax/test_triton_custom_calls.py
tests/jax/test_triton_custom_calls.py
+115
-0
tests/jax/utils.py
tests/jax/utils.py
+55
-1
tests/pytorch/attention/run_attention_with_cp.py
tests/pytorch/attention/run_attention_with_cp.py
+27
-20
tests/pytorch/attention/test_attention.py
tests/pytorch/attention/test_attention.py
+86
-55
tests/pytorch/attention/test_attention_with_cp.py
tests/pytorch/attention/test_attention_with_cp.py
+28
-12
tests/pytorch/attention/test_cp_utils.py
tests/pytorch/attention/test_cp_utils.py
+1
-1
tests/pytorch/attention/test_kv_cache.py
tests/pytorch/attention/test_kv_cache.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
tests/jax/test_distributed_layernorm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_distributed_layernorm_mlp.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
re
...
...
tests/jax/test_distributed_permutation.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for distributed/sharded execution of MoE permutation primitives.
Testing Strategy:
=================
MoE permutation is data-dependent - the destination index for each token depends
on how many tokens before it are routed to the same expert. This means:
1. We CANNOT compare sharded output against global reference directly
2. Instead, we verify that each GPU's LOCAL output is correct according to its
LOCAL routing (which produces LOCAL row_id_map with LOCAL indices)
For data-parallel MoE without expert parallelism:
- Each GPU has ALL experts replicated
- Each GPU processes a subset of tokens (sharded on token/batch dimension)
- Each GPU computes its own local row_id_map from its local routing_map slice
- Each GPU's output is local and doesn't need to match global output
These tests verify:
1. Local token_dispatch: sharded input -> local row_id_map -> local permute (forward + backward)
2. Local roundtrip: dispatch + combine recovers original input (forward + backward)
"""
import
pytest
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
from
jax.sharding
import
Mesh
,
NamedSharding
,
PartitionSpec
from
distributed_test_base
import
generate_configs
from
utils
import
assert_allclose
,
pytest_parametrize_wrapper
# High-level API with VJP support
from
transformer_engine.jax.permutation
import
(
token_dispatch
,
token_combine
,
)
# Reference implementations from test_permutation.py
from
test_permutation
import
(
reference_make_row_id_map
,
_reference_permute_impl
,
_reference_unpermute_impl
,
reference_token_combine
,
)
# Dispatch/combine test cases: (num_tokens, num_experts, hidden_size, topk)
# topk = number of experts each token is routed to
# Includes small, medium-large, and largest stress test cases.
ALL_DISPATCH_COMBINE_CASES
=
[
(
128
,
4
,
64
,
2
),
(
4096
,
32
,
1280
,
2
),
(
4096
,
256
,
4096
,
6
),
]
DISPATCH_COMBINE_CASES
=
{
"L0"
:
ALL_DISPATCH_COMBINE_CASES
[
0
:
1
],
"L2"
:
ALL_DISPATCH_COMBINE_CASES
,
}
# Dispatch/combine with padding test cases: (num_tokens, num_experts, hidden_size, topk, align_size)
ALL_DISPATCH_COMBINE_PADDING_CASES
=
[
(
128
,
4
,
64
,
2
,
8
),
(
4096
,
32
,
1280
,
2
,
128
),
(
4096
,
256
,
4096
,
6
,
16
),
]
DISPATCH_COMBINE_PADDING_CASES
=
{
"L0"
:
ALL_DISPATCH_COMBINE_PADDING_CASES
[
0
:
1
],
"L2"
:
ALL_DISPATCH_COMBINE_PADDING_CASES
,
}
# Dtypes for testing
ALL_DTYPES
=
[
jnp
.
float32
,
jnp
.
bfloat16
]
DTYPES
=
{
"L0"
:
[
jnp
.
float32
],
"L2"
:
ALL_DTYPES
,
}
class
TestDistributedPermutation
:
"""Test distributed/sharded execution of MoE permutation primitives.
These tests validate that custom partitioning produces correct LOCAL results
when inputs are sharded across multiple devices.
Key insight: With data-parallel MoE, each GPU independently processes its
local tokens. The row_id_map is generated locally and contains LOCAL indices.
We verify correctness by comparing each shard's output against the reference
implementation run on that shard's local data.
"""
@
staticmethod
def
compute_padded_output_size
(
num_tokens
:
int
,
num_experts
:
int
,
topk
:
int
,
align_size
:
int
,
num_dp_devices
:
int
,
)
->
int
:
"""Compute global_num_out_tokens for distributed padding tests.
Each device processes local_num_tokens tokens. We compute the worst-case
padded output size per device, then multiply by num_dp_devices to get
a global size that ensures global / num_dp >= local_worst.
"""
local_num_tokens
=
num_tokens
//
num_dp_devices
local_raw_out
=
local_num_tokens
*
topk
local_worst
=
((
local_raw_out
+
num_experts
*
(
align_size
-
1
))
//
align_size
)
*
align_size
return
local_worst
*
num_dp_devices
@
staticmethod
def
generate_routing_map
(
num_tokens
:
int
,
num_experts
:
int
,
topk
:
int
=
2
,
# Number of experts each token is routed to (max 1s per row).
key
:
jax
.
Array
=
None
,
):
if
key
is
None
:
key
=
jax
.
random
.
PRNGKey
(
0
)
routing_map
=
jnp
.
zeros
((
num_tokens
,
num_experts
),
dtype
=
jnp
.
int32
)
for
token_idx
in
range
(
num_tokens
):
key
,
subkey
=
jax
.
random
.
split
(
key
)
expert_indices
=
jax
.
random
.
choice
(
subkey
,
num_experts
,
shape
=
(
topk
,),
replace
=
False
)
routing_map
=
routing_map
.
at
[
token_idx
,
expert_indices
].
set
(
1
)
return
routing_map
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest_parametrize_wrapper
(
"num_tokens,num_experts,hidden_size,topk"
,
DISPATCH_COMBINE_CASES
,
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_local_token_dispatch
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
num_tokens
,
num_experts
,
hidden_size
,
topk
,
dtype
,
use_shardy
,
):
"""
Test token_dispatch with sharded inputs.
Verifies that sharded execution produces the same result as chunk-wise
reference execution. The sharded primitive:
1. Receives global num_out_tokens (partition function divides it)
2. Each GPU operates on its local shard independently
3. Results are gathered (concatenated) across GPUs
Output ordering: [GPU0_expert0, GPU0_expert1, ... | GPU1_expert0, ...]
The reference processes each chunk independently and concatenates,
matching the sharded execution's output ordering.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
key
=
jax
.
random
.
PRNGKey
(
42
)
# Generate global inputs
key
,
inp_key
,
prob_key
=
jax
.
random
.
split
(
key
,
3
)
inp
=
jax
.
random
.
uniform
(
inp_key
,
(
num_tokens
,
hidden_size
),
dtype
=
dtype
,
minval
=-
1.0
,
maxval
=
1.0
)
routing_map
=
self
.
generate_routing_map
(
num_tokens
,
num_experts
,
topk
,
key
)
probs
=
jax
.
random
.
uniform
(
prob_key
,
(
num_tokens
,
num_experts
),
dtype
=
dtype
,
minval
=
0.1
,
maxval
=
1.0
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
# Shard on token (batch) dimension
dp_axis
=
mesh_resource
.
dp_resource
sharded_pspec
=
PartitionSpec
(
dp_axis
,
None
)
# Compute num_out_tokens as concrete values
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
# Local num_out_tokens is used for reference implementation
num_dp_devices
=
mesh
.
shape
[
dp_axis
]
if
dp_axis
else
1
global_num_out_tokens
=
num_tokens
*
topk
local_num_tokens
=
num_tokens
//
num_dp_devices
local_num_out_tokens
=
local_num_tokens
*
topk
with
mesh
:
inp_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
routing_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
probs_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
# Shard the inputs
inp_sharded
=
jax
.
device_put
(
inp
,
inp_sharding
)
routing_sharded
=
jax
.
device_put
(
routing_map
,
routing_sharding
)
probs_sharded
=
jax
.
device_put
(
probs
,
probs_sharding
)
# ================================================================
# Forward pass test
# ================================================================
@
jax
.
jit
def
target_dispatch
(
x
,
rm
,
p
):
# Pass global num_out_tokens - partition function divides it
out
,
perm_probs
,
rid_map
,
_
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
probs
=
p
)
return
out
,
perm_probs
,
rid_map
# Reference: process each GPU's shard independently, then concatenate
# This matches how the sharded primitive operates:
# - Each GPU processes its local shard
# - Results are gathered (concatenated) across GPUs
# Output ordering: [GPU0_exp0, GPU0_exp1, ... | GPU1_exp0, GPU1_exp1, ...]
inp_shards
=
jnp
.
reshape
(
inp
,
(
num_dp_devices
,
local_num_tokens
,
hidden_size
))
routing_shards
=
jnp
.
reshape
(
routing_map
,
(
num_dp_devices
,
local_num_tokens
,
num_experts
)
)
probs_shards
=
jnp
.
reshape
(
probs
,
(
num_dp_devices
,
local_num_tokens
,
num_experts
))
ref_outputs
=
[]
ref_perm_probs_list
=
[]
ref_rid_maps
=
[]
for
i
in
range
(
num_dp_devices
):
shard_rid_map
=
reference_make_row_id_map
(
routing_shards
[
i
])
shard_out
,
shard_perm_probs
=
_reference_permute_impl
(
inp_shards
[
i
],
shard_rid_map
,
probs_shards
[
i
],
local_num_out_tokens
)
ref_outputs
.
append
(
shard_out
)
ref_perm_probs_list
.
append
(
shard_perm_probs
)
ref_rid_maps
.
append
(
shard_rid_map
)
# Concatenate like all_gather would
ref_out
=
jnp
.
concatenate
(
ref_outputs
,
axis
=
0
)
ref_perm_probs
=
jnp
.
concatenate
(
ref_perm_probs_list
,
axis
=
0
)
ref_rid_map
=
jnp
.
concatenate
(
ref_rid_maps
,
axis
=
0
)
# Run target on sharded inputs
target_out
,
target_perm_probs
,
target_rid_map
=
target_dispatch
(
inp_sharded
,
routing_sharded
,
probs_sharded
)
# Compare forward outputs
assert_allclose
(
jax
.
device_get
(
target_out
),
ref_out
,
dtype
=
dtype
)
assert_allclose
(
jax
.
device_get
(
target_perm_probs
),
ref_perm_probs
,
dtype
=
dtype
)
# Verify row_id_map n_routed column matches routing_map sum
target_rid_map_np
=
jax
.
device_get
(
target_rid_map
)
assert
jnp
.
array_equal
(
target_rid_map_np
[:,
-
1
],
ref_rid_map
[:,
-
1
]
),
"n_routed column mismatch"
# Sanity checks
target_out_np
=
jax
.
device_get
(
target_out
)
target_perm_probs_np
=
jax
.
device_get
(
target_perm_probs
)
assert
not
np
.
any
(
np
.
isnan
(
target_out_np
)),
"Output contains NaN"
assert
not
np
.
any
(
np
.
isnan
(
target_perm_probs_np
)),
"Permuted probs contain NaN"
assert
np
.
all
(
target_perm_probs_np
>=
0
),
"Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def
target_loss
(
x
,
rm
,
p
):
out
,
perm_probs
,
_
,
_
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
probs
=
p
)
return
jnp
.
sum
(
out
**
2
)
+
jnp
.
sum
(
perm_probs
**
2
)
# Reference loss: process chunks independently and sum
def
ref_chunk_loss
(
inp_chunk
,
routing_chunk
,
probs_chunk
):
rid_map
=
reference_make_row_id_map
(
routing_chunk
)
out
,
perm_probs
=
_reference_permute_impl
(
inp_chunk
,
rid_map
,
probs_chunk
,
local_num_out_tokens
)
return
jnp
.
sum
(
out
**
2
)
+
jnp
.
sum
(
perm_probs
**
2
)
target_grad_fn
=
jax
.
jit
(
jax
.
grad
(
target_loss
,
argnums
=
(
0
,
2
)))
ref_chunk_grad_fn
=
jax
.
jit
(
jax
.
grad
(
ref_chunk_loss
,
argnums
=
(
0
,
2
)))
target_inp_grad
,
target_probs_grad
=
target_grad_fn
(
inp_sharded
,
routing_sharded
,
probs_sharded
)
# Compute reference gradients per chunk, then concatenate
ref_inp_grads
=
[]
ref_probs_grads
=
[]
for
i
in
range
(
num_dp_devices
):
chunk_inp_grad
,
chunk_probs_grad
=
ref_chunk_grad_fn
(
inp_shards
[
i
],
routing_shards
[
i
],
probs_shards
[
i
]
)
ref_inp_grads
.
append
(
chunk_inp_grad
)
ref_probs_grads
.
append
(
chunk_probs_grad
)
ref_inp_grad
=
jnp
.
concatenate
(
ref_inp_grads
,
axis
=
0
)
ref_probs_grad
=
jnp
.
concatenate
(
ref_probs_grads
,
axis
=
0
)
assert_allclose
(
jax
.
device_get
(
target_inp_grad
),
ref_inp_grad
,
dtype
=
dtype
)
assert_allclose
(
jax
.
device_get
(
target_probs_grad
),
ref_probs_grad
,
dtype
=
dtype
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest_parametrize_wrapper
(
"num_tokens,num_experts,hidden_size,topk"
,
DISPATCH_COMBINE_CASES
,
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_local_roundtrip
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
num_tokens
,
num_experts
,
hidden_size
,
topk
,
dtype
,
use_shardy
,
):
"""
Test roundtrip: token_dispatch followed by token_combine with sharded inputs.
Each GPU:
1. Gets a shard of the input and routing_map
2. Performs local dispatch (permute)
3. Performs local combine (unpermute)
4. With uniform merging probs, should recover original input
Tests both forward pass and backward pass (gradient should be 2*x).
"""
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
key
=
jax
.
random
.
PRNGKey
(
42
)
# Generate global inputs
key
,
inp_key
=
jax
.
random
.
split
(
key
,
2
)
inp
=
jax
.
random
.
uniform
(
inp_key
,
(
num_tokens
,
hidden_size
),
dtype
=
dtype
,
minval
=-
1.0
,
maxval
=
1.0
)
routing_map
=
self
.
generate_routing_map
(
num_tokens
,
num_experts
,
topk
,
key
)
# Uniform merging probs for perfect roundtrip
uniform_merging_probs
=
routing_map
.
astype
(
dtype
)
/
jnp
.
maximum
(
jnp
.
sum
(
routing_map
,
axis
=
1
,
keepdims
=
True
),
1.0
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
dp_axis
=
mesh_resource
.
dp_resource
sharded_pspec
=
PartitionSpec
(
dp_axis
,
None
)
# Compute num_out_tokens as concrete value
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
global_num_out_tokens
=
num_tokens
*
topk
with
mesh
:
inp_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
routing_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
merging_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
inp_sharded
=
jax
.
device_put
(
inp
,
inp_sharding
)
routing_sharded
=
jax
.
device_put
(
routing_map
,
routing_sharding
)
merging_sharded
=
jax
.
device_put
(
uniform_merging_probs
,
merging_sharding
)
# ================================================================
# Forward pass test
# ================================================================
@
jax
.
jit
def
roundtrip
(
x
,
rm
,
mprobs
):
dispatched
,
_
,
rid_map
,
_
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
)
return
token_combine
(
dispatched
,
rid_map
,
mprobs
)
roundtrip_out
=
roundtrip
(
inp_sharded
,
routing_sharded
,
merging_sharded
)
# Should recover original input
assert_allclose
(
jax
.
device_get
(
roundtrip_out
),
jax
.
device_get
(
inp_sharded
),
dtype
=
dtype
)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def
roundtrip_loss
(
x
,
rm
,
mprobs
):
dispatched
,
_
,
rid_map
,
_
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
)
combined
=
token_combine
(
dispatched
,
rid_map
,
mprobs
)
return
jnp
.
sum
(
combined
**
2
)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn
=
jax
.
jit
(
jax
.
grad
(
roundtrip_loss
,
argnums
=
0
))
computed_grad
=
grad_fn
(
inp_sharded
,
routing_sharded
,
merging_sharded
)
expected_grad
=
2.0
*
inp_sharded
assert_allclose
(
jax
.
device_get
(
computed_grad
),
jax
.
device_get
(
expected_grad
),
dtype
=
dtype
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest_parametrize_wrapper
(
"num_tokens,num_experts,hidden_size,topk,align_size"
,
DISPATCH_COMBINE_PADDING_CASES
,
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_local_token_dispatch_with_padding
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
num_tokens
,
num_experts
,
hidden_size
,
topk
,
align_size
,
dtype
,
use_shardy
,
):
"""
Test token_dispatch with padding using sharded inputs.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
key
=
jax
.
random
.
PRNGKey
(
42
)
# Generate global inputs
key
,
inp_key
,
prob_key
=
jax
.
random
.
split
(
key
,
3
)
inp
=
jax
.
random
.
uniform
(
inp_key
,
(
num_tokens
,
hidden_size
),
dtype
=
dtype
,
minval
=-
1.0
,
maxval
=
1.0
)
routing_map
=
self
.
generate_routing_map
(
num_tokens
,
num_experts
,
topk
,
key
)
probs
=
jax
.
random
.
uniform
(
prob_key
,
(
num_tokens
,
num_experts
),
dtype
=
dtype
,
minval
=
0.1
,
maxval
=
1.0
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
dp_axis
=
mesh_resource
.
dp_resource
sharded_pspec
=
PartitionSpec
(
dp_axis
,
None
)
num_dp_devices
=
mesh
.
shape
[
dp_axis
]
if
dp_axis
else
1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens
=
self
.
compute_padded_output_size
(
num_tokens
,
num_experts
,
topk
,
align_size
,
num_dp_devices
)
with
mesh
:
inp_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
routing_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
probs_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
inp_sharded
=
jax
.
device_put
(
inp
,
inp_sharding
)
routing_sharded
=
jax
.
device_put
(
routing_map
,
routing_sharding
)
probs_sharded
=
jax
.
device_put
(
probs
,
probs_sharding
)
# ================================================================
# Forward pass test
# ================================================================
@
jax
.
jit
def
dispatch_with_padding
(
x
,
rm
,
p
):
out
,
perm_probs
,
rid_map
,
pad_offsets
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
probs
=
p
,
align_size
=
align_size
)
return
out
,
perm_probs
,
rid_map
,
pad_offsets
out
,
perm_probs
,
rid_map
,
pad_offsets
=
dispatch_with_padding
(
inp_sharded
,
routing_sharded
,
probs_sharded
)
# Sanity checks
out_np
=
jax
.
device_get
(
out
)
perm_probs_np
=
jax
.
device_get
(
perm_probs
)
assert
not
np
.
any
(
np
.
isnan
(
out_np
)),
"Output contains NaN"
assert
not
np
.
any
(
np
.
isnan
(
perm_probs_np
)),
"Permuted probs contain NaN"
assert
np
.
all
(
perm_probs_np
>=
0
),
"Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def
loss_with_padding
(
x
,
rm
,
p
):
out
,
perm_probs
,
_
,
_
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
probs
=
p
,
align_size
=
align_size
)
return
jnp
.
sum
(
out
**
2
)
+
jnp
.
sum
(
perm_probs
**
2
)
grad_fn
=
jax
.
jit
(
jax
.
grad
(
loss_with_padding
,
argnums
=
(
0
,
2
)))
inp_grad
,
probs_grad
=
grad_fn
(
inp_sharded
,
routing_sharded
,
probs_sharded
)
# Gradients should not contain NaN
assert
not
np
.
any
(
np
.
isnan
(
jax
.
device_get
(
inp_grad
))),
"Input gradient contains NaN"
assert
not
np
.
any
(
np
.
isnan
(
jax
.
device_get
(
probs_grad
))),
"Probs gradient contains NaN"
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest_parametrize_wrapper
(
"num_tokens,num_experts,hidden_size,topk,align_size"
,
DISPATCH_COMBINE_PADDING_CASES
,
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_local_roundtrip_with_padding
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
num_tokens
,
num_experts
,
hidden_size
,
topk
,
align_size
,
dtype
,
use_shardy
,
):
"""
Test roundtrip with padding/alignment using sharded inputs.
With uniform merging probs, should recover original input.
Tests both forward pass and backward pass.
"""
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
key
=
jax
.
random
.
PRNGKey
(
42
)
# Generate inputs
key
,
inp_key
=
jax
.
random
.
split
(
key
,
2
)
inp
=
jax
.
random
.
uniform
(
inp_key
,
(
num_tokens
,
hidden_size
),
dtype
=
dtype
,
minval
=-
1.0
,
maxval
=
1.0
)
routing_map
=
self
.
generate_routing_map
(
num_tokens
,
num_experts
,
topk
,
key
)
# Uniform merging probs
uniform_merging_probs
=
routing_map
.
astype
(
dtype
)
/
jnp
.
maximum
(
jnp
.
sum
(
routing_map
,
axis
=
1
,
keepdims
=
True
),
1.0
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
dp_axis
=
mesh_resource
.
dp_resource
sharded_pspec
=
PartitionSpec
(
dp_axis
,
None
)
num_dp_devices
=
mesh
.
shape
[
dp_axis
]
if
dp_axis
else
1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens
=
self
.
compute_padded_output_size
(
num_tokens
,
num_experts
,
topk
,
align_size
,
num_dp_devices
)
with
mesh
:
inp_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
routing_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
merging_sharding
=
NamedSharding
(
mesh
,
sharded_pspec
)
inp_sharded
=
jax
.
device_put
(
inp
,
inp_sharding
)
routing_sharded
=
jax
.
device_put
(
routing_map
,
routing_sharding
)
merging_sharded
=
jax
.
device_put
(
uniform_merging_probs
,
merging_sharding
)
# ================================================================
# Forward pass test
# ================================================================
@
jax
.
jit
def
roundtrip_with_padding
(
x
,
rm
,
mprobs
):
dispatched
,
_
,
rid_map
,
pad_offsets
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
align_size
=
align_size
)
return
token_combine
(
dispatched
,
rid_map
,
mprobs
,
pad_offsets
)
roundtrip_out
=
roundtrip_with_padding
(
inp_sharded
,
routing_sharded
,
merging_sharded
)
# Should recover original input
assert_allclose
(
jax
.
device_get
(
roundtrip_out
),
jax
.
device_get
(
inp_sharded
),
dtype
=
dtype
)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def
roundtrip_loss_with_padding
(
x
,
rm
,
mprobs
):
dispatched
,
_
,
rid_map
,
pad_offsets
,
_
=
token_dispatch
(
x
,
rm
,
global_num_out_tokens
,
align_size
=
align_size
)
combined
=
token_combine
(
dispatched
,
rid_map
,
mprobs
,
pad_offsets
)
return
jnp
.
sum
(
combined
**
2
)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn
=
jax
.
jit
(
jax
.
grad
(
roundtrip_loss_with_padding
,
argnums
=
0
))
computed_grad
=
grad_fn
(
inp_sharded
,
routing_sharded
,
merging_sharded
)
expected_grad
=
2.0
*
inp_sharded
assert_allclose
(
jax
.
device_get
(
computed_grad
),
jax
.
device_get
(
expected_grad
),
dtype
=
dtype
)
tests/jax/test_distributed_softmax.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -16,7 +16,7 @@ from distributed_test_base import generate_configs, generate_collectives_count
from
distributed_test_base
import
compare_ops
from
utils
import
make_causal_mask
,
make_self_mask
from
transformer_engine.jax
import
autocast
from
transformer_engine.jax.softmax
import
SoftmaxType
,
softmax
from
transformer_engine.jax.softmax
import
Softmax
Fusion
Type
,
softmax
DTYPES
=
[
jnp
.
float16
,
jnp
.
bfloat16
]
...
...
@@ -29,12 +29,12 @@ class TestDistributedSoftmax:
return
generate_collectives_count
(
allreduce
=
all_reduce_loss_bytes
,
allgather
=
0
,
other
=
0
)
def
generate_inputs
(
self
,
shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
self
,
shape
,
mesh_resource
,
softmax_
fusion_
type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
):
batch
,
_
,
sqelen
,
_
=
shape
x
=
random
.
normal
(
random
.
PRNGKey
(
1124
),
shape
,
dtype
=
dtype
)
if
softmax_type
==
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
if
softmax_
fusion_
type
==
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
:
mask
=
make_causal_mask
(
batch
,
sqelen
)
else
:
mask
=
make_self_mask
(
1
if
broadcast_batch_mask
else
batch
,
sqelen
)
...
...
@@ -56,8 +56,10 @@ class TestDistributedSoftmax:
return
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
@
staticmethod
def
target_func
(
x
,
mask
,
scale_factor
=
1.0
,
softmax_type
=
SoftmaxType
.
SCALED
):
return
jnp
.
mean
(
softmax
(
x
,
mask
,
scale_factor
=
scale_factor
,
softmax_type
=
softmax_type
))
def
target_func
(
x
,
mask
,
scale_factor
=
1.0
,
softmax_fusion_type
=
SoftmaxFusionType
.
SCALED
):
return
jnp
.
mean
(
softmax
(
x
,
mask
,
scale_factor
=
scale_factor
,
softmax_fusion_type
=
softmax_fusion_type
)
)
@
staticmethod
def
ref_func
(
x
,
mask
,
scale_factor
=
1.0
,
dtype
=
jnp
.
float16
):
...
...
@@ -80,24 +82,29 @@ class TestDistributedSoftmax:
mesh_axes
,
mesh_resource
,
data_shape
,
softmax_type
,
softmax_
fusion_
type
,
scale_factor
,
dtype
,
bad_sharding
,
broadcast_batch_mask
,
use_shardy
,
):
if
broadcast_batch_mask
and
softmax_type
!=
SoftmaxType
.
SCALED_MASKED
:
if
broadcast_batch_mask
and
softmax_
fusion_
type
!=
Softmax
Fusion
Type
.
SCALED_MASKED
:
pytest
.
skip
(
"Softmax type has no mask."
)
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
target_func
=
partial
(
self
.
target_func
,
scale_factor
=
scale_factor
,
softmax_type
=
softmax_type
self
.
target_func
,
scale_factor
=
scale_factor
,
softmax_
fusion_
type
=
softmax_
fusion_
type
)
ref_func
=
partial
(
self
.
ref_func
,
scale_factor
=
scale_factor
,
dtype
=
dtype
)
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
=
self
.
generate_inputs
(
data_shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
data_shape
,
mesh_resource
,
softmax_fusion_type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
,
)
collective_count_ref
=
self
.
generate_collectives_count_ref
()
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
...
...
@@ -139,8 +146,12 @@ class TestDistributedSoftmax:
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[[
32
,
12
,
128
,
128
],
[
8
,
8
,
1024
,
1024
]])
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
[
SoftmaxType
.
SCALED
,
SoftmaxType
.
SCALED_MASKED
,
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
],
"softmax_fusion_type"
,
[
SoftmaxFusionType
.
SCALED
,
SoftmaxFusionType
.
SCALED_MASKED
,
SoftmaxFusionType
.
SCALED_UPPER_TRIANG_MASKED
,
],
)
@
pytest
.
mark
.
parametrize
(
"scale_factor"
,
[
1.0
,
3.0
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -153,7 +164,7 @@ class TestDistributedSoftmax:
mesh_axes
,
mesh_resource
,
data_shape
,
softmax_type
,
softmax_
fusion_
type
,
scale_factor
,
dtype
,
bad_sharding
,
...
...
@@ -165,7 +176,7 @@ class TestDistributedSoftmax:
mesh_axes
,
mesh_resource
,
data_shape
,
softmax_type
,
softmax_
fusion_
type
,
scale_factor
,
dtype
,
bad_sharding
,
...
...
@@ -174,7 +185,9 @@ class TestDistributedSoftmax:
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
[
SoftmaxType
.
SCALED
,
SoftmaxType
.
SCALED_MASKED
])
@
pytest
.
mark
.
parametrize
(
"softmax_fusion_type"
,
[
SoftmaxFusionType
.
SCALED
,
SoftmaxFusionType
.
SCALED_MASKED
]
)
@
pytest
.
mark
.
parametrize
(
"bad_sharding"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"broadcast_batch_mask"
,
[
False
,
True
])
def
test_softmax_gspmd
(
...
...
@@ -183,7 +196,7 @@ class TestDistributedSoftmax:
mesh_shape
,
mesh_axes
,
mesh_resource
,
softmax_type
,
softmax_
fusion_
type
,
bad_sharding
,
broadcast_batch_mask
,
):
...
...
@@ -193,7 +206,7 @@ class TestDistributedSoftmax:
mesh_axes
,
mesh_resource
,
data_shape
=
[
32
,
12
,
128
,
128
],
softmax_type
=
softmax_type
,
softmax_
fusion_
type
=
softmax_
fusion_
type
,
scale_factor
=
1.0
,
dtype
=
DTYPES
[
0
],
bad_sharding
=
bad_sharding
,
...
...
tests/jax/test_functions.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_fused_attn.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
...
...
@@ -27,6 +27,7 @@ from transformer_engine.jax.sharding import MeshResource
from
transformer_engine.jax.attention
import
(
AttnBiasType
,
AttnMaskType
,
AttnSoftmaxType
,
QKVLayout
,
QKVFormat
,
reorder_causal_load_balancing
,
...
...
@@ -59,14 +60,16 @@ def init():
yield
@
partial
(
jax
.
jit
,
static_argnums
=
(
5
,
6
,
7
,
9
))
@
partial
(
jax
.
jit
,
static_argnums
=
(
6
,
7
,
8
,
9
,
11
))
def
general_dot_product_attention
(
query
:
ArrayLike
,
key
:
ArrayLike
,
value
:
ArrayLike
,
softmax_offset
:
Optional
[
ArrayLike
],
bias
:
ArrayLike
,
mask
:
ArrayLike
,
deterministic
:
bool
,
softmax_type
:
AttnSoftmaxType
,
scale_factor
:
float
,
dropout_rate
:
float
,
dropout_rng
:
ArrayLike
,
...
...
@@ -99,7 +102,25 @@ def general_dot_product_attention(
mask
=
jnp
.
expand_dims
(
mask
,
axis
=-
3
)
logits
=
jnp
.
where
(
mask
,
jnp
.
finfo
(
dtype
).
min
,
logits
)
softmax_out
=
jax
.
nn
.
softmax
(
logits
).
astype
(
dtype
)
match
softmax_type
:
case
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_out
=
jax
.
nn
.
softmax
(
logits
).
astype
(
dtype
)
case
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
# Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
# Append a zero logit, apply standard softmax, then remove last column
zero_logit
=
jnp
.
zeros
(
logits
.
shape
[:
-
1
]
+
(
1
,),
dtype
=
logits
.
dtype
)
logits_with_extra
=
jnp
.
concatenate
([
logits
,
zero_logit
],
axis
=-
1
)
softmax_with_extra
=
jax
.
nn
.
softmax
(
logits_with_extra
,
axis
=-
1
)
softmax_out
=
softmax_with_extra
[...,
:
-
1
].
astype
(
dtype
)
case
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
# Append learnable offset logit, apply standard softmax, then remove last column
learnable_logit
=
softmax_offset
.
reshape
(
1
,
h_kv
,
num_groups
,
1
,
1
)
learnable_logit
=
jnp
.
broadcast_to
(
learnable_logit
,
logits
.
shape
[:
-
1
]
+
(
1
,))
logits_with_extra
=
jnp
.
concatenate
([
logits
,
learnable_logit
],
axis
=-
1
)
softmax_with_extra
=
jax
.
nn
.
softmax
(
logits_with_extra
,
axis
=-
1
)
softmax_out
=
softmax_with_extra
[...,
:
-
1
].
astype
(
dtype
)
case
_
:
raise
NotImplementedError
(
f
"Unknown
{
softmax_type
=
}
"
)
if
not
deterministic
and
dropout_rate
>
0.0
:
keep_prob
=
1.0
-
dropout_rate
...
...
@@ -238,7 +259,7 @@ def _split_valid_and_invalid(primitive, reference, pad):
return
primitive_valid
,
primitive_invalid
,
reference_valid
,
reference_invalid
def
jax_dpa
(
query
,
key
,
value
,
bias
,
mask
,
dropout_rng
,
**
kwargs
):
def
jax_dpa
(
query
,
key
,
value
,
bias
,
softmax_offset
,
mask
,
dropout_rng
,
**
kwargs
):
"""
JAX native dot product attention implementation
"""
...
...
@@ -246,11 +267,13 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
query
,
key
,
value
,
softmax_offset
,
bias
,
mask
,
deterministic
=
not
kwargs
[
"is_training"
],
scale_factor
=
kwargs
[
"scaling_factor"
],
dropout_rate
=
kwargs
[
"dropout_probability"
],
softmax_type
=
kwargs
[
"softmax_type"
],
dropout_rng
=
dropout_rng
,
dtype
=
jnp
.
float32
,
)
...
...
@@ -262,6 +285,7 @@ def customcall_fused_dpa(
key
,
value
,
bias
,
softmax_offset
,
sequence_descriptor
,
dropout_rng
,
**
kwargs
,
...
...
@@ -283,9 +307,9 @@ def customcall_fused_dpa(
qkv_args
=
(
query
,
key
,
value
)
case
_
:
raise
ValueError
(
f
"Unsupported
{
qkv_layout
=
}
"
)
return
fused_attn
(
qkv_args
,
bias
,
sequence_descriptor
,
dropout_rng
,
**
kwargs
).
astype
(
q
uery
.
dtype
)
return
fused_attn
(
q
kv_args
,
bias
,
sequence_descriptor
,
dropout_rng
,
softmax_offset
=
softmax_offset
,
**
kwargs
)
.
astype
(
query
.
dtype
)
class
BiasShape
(
Enum
):
...
...
@@ -320,6 +344,7 @@ class FusedAttnRunner:
head_dim_v
:
int
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
dropout_prob
:
float
dtype
:
DTypeLike
is_training
:
bool
...
...
@@ -327,6 +352,8 @@ class FusedAttnRunner:
bias_shape
:
BiasShape
window_size
:
Tuple
[
int
,
int
]
seq_desc_format
:
SeqDescFormat
stripe_size
:
int
|
None
=
None
num_segments_per_seq
:
int
|
None
=
None
# Specifies sharding resources for distributed tests
number_of_devices
:
int
=
1
...
...
@@ -341,6 +368,14 @@ class FusedAttnRunner:
# dictionary of expected collective comm bytes
coll_count_ref
:
Optional
[
Dict
[
str
,
int
]]
=
None
def
__post_init__
(
self
):
# Reset defaults for num_segments_per_seq if not explicitly passed
if
self
.
num_segments_per_seq
is
None
:
if
self
.
qkv_layout
.
is_thd
():
self
.
num_segments_per_seq
=
2
else
:
self
.
num_segments_per_seq
=
1
# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def
_get_max_segments_per_sequence
(
self
):
...
...
@@ -402,6 +437,7 @@ class FusedAttnRunner:
self
.
qkv_layout
,
self
.
attn_bias_type
,
self
.
attn_mask_type
,
self
.
softmax_type
,
self
.
dropout_prob
,
self
.
num_heads_q
,
self
.
num_heads_kv
,
...
...
@@ -439,7 +475,7 @@ class FusedAttnRunner:
self
.
tp_size
=
self
.
mesh
.
shape
.
get
(
self
.
mesh_resource
.
tpsp_resource
,
1
)
key
=
jax
.
random
.
PRNGKey
(
0
)
q_key
,
k_key
,
v_key
,
bias_key
,
dropout_key
=
jax
.
random
.
split
(
key
,
5
)
q_key
,
k_key
,
v_key
,
bias_key
,
dropout_key
,
softmax_key
=
jax
.
random
.
split
(
key
,
6
)
q_shape
=
(
self
.
batch_size
,
self
.
max_seqlen_q
,
self
.
num_heads_q
,
self
.
head_dim_qk
)
k_shape
=
(
self
.
batch_size
,
self
.
max_seqlen_kv
,
self
.
num_heads_kv
,
self
.
head_dim_qk
)
...
...
@@ -490,6 +526,13 @@ class FusedAttnRunner:
else
:
pad_ratio
=
0.0
if
self
.
softmax_type
==
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
self
.
softmax_offset
=
jax
.
random
.
uniform
(
softmax_key
,
(
1
,
self
.
num_heads_q
,
1
,
1
),
jnp
.
float32
,
-
1.0
)
else
:
self
.
softmax_offset
=
None
def
gen_valid
(
bs
,
max_seqlen
,
pad_ratio
):
pad_len
=
int
(
max_seqlen
*
pad_ratio
)
valid_len
=
max_seqlen
-
pad_len
...
...
@@ -544,7 +587,6 @@ class FusedAttnRunner:
return
segment_ids
,
segment_pos
,
segment_pad
if
self
.
qkv_layout
.
is_thd
():
self
.
num_segments_per_seq
=
2
self
.
segment_ids_q
,
self
.
segment_pos_q
,
self
.
pad_q
=
generate_random_segment_ids
(
self
.
batch_size
,
self
.
max_seqlen_q
,
self
.
num_segments_per_seq
,
seed
=
42
)
...
...
@@ -570,7 +612,6 @@ class FusedAttnRunner:
)
self
.
seqlens_kv
,
self
.
offsets_kv
=
get_seqlens_and_offsets
(
self
.
segment_ids_kv
)
else
:
self
.
num_segments_per_seq
=
1
self
.
segment_ids_q
,
self
.
pad_q
=
gen_valid
(
self
.
batch_size
,
self
.
max_seqlen_q
,
pad_ratio
)
...
...
@@ -602,12 +643,14 @@ class FusedAttnRunner:
strategy
=
reorder_strategy
,
cp_size
=
self
.
cp_size
,
seq_dim
=
seq_dim
,
stripe_size
=
self
.
stripe_size
,
)
self
.
cp_inverse_reorder_fn
=
partial
(
inverse_reorder_causal_load_balancing
,
strategy
=
reorder_strategy
,
cp_size
=
self
.
cp_size
,
seq_dim
=
seq_dim
,
stripe_size
=
self
.
stripe_size
,
)
else
:
# no-ops for non cp or non load balanced
...
...
@@ -625,14 +668,24 @@ class FusedAttnRunner:
(
self
.
offsets_q
,
self
.
offsets_kv
),
)
case
SeqDescFormat
.
SegmentIDs
:
# Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
# if no CP and load balancing, else explicitly pass the segment_pos
self
.
sequence_desciptor
=
SequenceDescriptor
.
from_segment_ids_and_pos
(
(
self
.
cp_reorder_fn
(
self
.
segment_ids_q
),
self
.
cp_reorder_fn
(
self
.
segment_ids_kv
),
),
(
self
.
cp_reorder_fn
(
self
.
segment_pos_q
),
self
.
cp_reorder_fn
(
self
.
segment_pos_kv
),
(
self
.
cp_reorder_fn
(
self
.
segment_pos_q
),
self
.
cp_reorder_fn
(
self
.
segment_pos_kv
),
)
if
self
.
cp_size
>
1
and
self
.
cp_load_balanced
else
None
),
is_thd
=
self
.
qkv_layout
.
is_thd
(),
is_segment_ids_reordered
=
(
True
if
self
.
cp_size
>
1
and
self
.
cp_load_balanced
else
False
),
)
case
_
:
...
...
@@ -661,6 +714,8 @@ class FusedAttnRunner:
self
.
sequence_desciptor
=
SequenceDescriptor
.
from_segment_ids_and_pos
(
(
self
.
segment_ids_q
,
self
.
segment_ids_kv
),
None
,
is_thd
=
self
.
qkv_layout
.
is_thd
(),
is_segment_ids_reordered
=
False
,
)
case
_
:
raise
ValueError
(
f
"Unknown
{
self
.
seq_desc_format
=
}
"
)
...
...
@@ -713,6 +768,16 @@ class FusedAttnRunner:
self
.
bias_pspec
=
PartitionSpec
()
self
.
bias_sharding
=
NamedSharding
(
self
.
mesh
,
self
.
bias_pspec
)
# Softmax offset sharding (1, num_heads, 1, 1)
# Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
head_resource
=
(
self
.
mesh_resource
.
tpsp_resource
if
self
.
mesh_resource
.
tpsp_resource
is
not
None
else
self
.
mesh_resource
.
tp_resource
)
self
.
softmax_offset_pspec
=
PartitionSpec
(
None
,
head_resource
,
None
,
None
)
self
.
softmax_offset_sharding
=
NamedSharding
(
self
.
mesh
,
self
.
softmax_offset_pspec
)
self
.
dropout_rng_pspec
=
PartitionSpec
(
None
,
)
...
...
@@ -728,11 +793,11 @@ class FusedAttnRunner:
def
test_forward
(
self
):
"""
Test forward with
out
JIT
Test forward with JIT
ted primitive and unJITted reference
"""
self
.
_setup_inputs
()
args
=
[
self
.
q
,
self
.
k
,
self
.
v
,
self
.
bias
,
self
.
mask
,
self
.
dropout_rng
]
args
=
[
self
.
q
,
self
.
k
,
self
.
v
,
self
.
bias
,
self
.
softmax_offset
,
self
.
mask
,
self
.
dropout_rng
]
customcall_args
=
[
# Put test data onto each GPU for distributed.
...
...
@@ -742,12 +807,14 @@ class FusedAttnRunner:
jax
.
device_put
(
self
.
cp_reorder_fn
(
self
.
k
),
self
.
qkvo_sharding
),
jax
.
device_put
(
self
.
cp_reorder_fn
(
self
.
v
),
self
.
qkvo_sharding
),
jax
.
device_put
(
self
.
bias
,
self
.
bias_sharding
),
jax
.
device_put
(
self
.
softmax_offset
,
self
.
softmax_offset_sharding
),
jax
.
device_put
(
self
.
sequence_desciptor
,
self
.
seq_desc_sharding
),
jax
.
device_put
(
self
.
dropout_rng
,
self
.
dropout_rng_sharding
),
]
kwargs
=
{
"attn_bias_type"
:
self
.
attn_bias_type
,
"attn_mask_type"
:
self
.
attn_mask_type
,
"softmax_type"
:
self
.
softmax_type
,
"scaling_factor"
:
self
.
scaling_factor
,
"dropout_probability"
:
self
.
dropout_prob
,
"is_training"
:
self
.
is_training
,
...
...
@@ -756,6 +823,7 @@ class FusedAttnRunner:
"window_size"
:
self
.
window_size
,
"context_parallel_strategy"
:
self
.
cp_strategy
,
"context_parallel_causal_load_balanced"
:
self
.
cp_load_balanced
,
"stripe_size"
:
self
.
stripe_size
,
}
customcall_fused_dpa_jit
=
jit
(
...
...
@@ -766,6 +834,7 @@ class FusedAttnRunner:
self
.
qkvo_sharding
,
self
.
qkvo_sharding
,
self
.
bias_sharding
,
self
.
softmax_offset_sharding
,
self
.
seq_desc_sharding
,
self
.
dropout_rng_sharding
,
],
...
...
@@ -826,7 +895,7 @@ class FusedAttnRunner:
jnp
.
mean
(
ret_valid
.
astype
(
jnp
.
float32
),
dtype
=
jnp
.
float32
)
*
gradient_multiplier
).
astype
(
self
.
dtype
)
args
=
[
self
.
q
,
self
.
k
,
self
.
v
,
self
.
bias
,
self
.
mask
,
self
.
dropout_rng
]
args
=
[
self
.
q
,
self
.
k
,
self
.
v
,
self
.
bias
,
self
.
softmax_offset
,
self
.
mask
,
self
.
dropout_rng
]
customcall_args
=
[
# TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
# THD params once we support those features on CP.
...
...
@@ -834,12 +903,14 @@ class FusedAttnRunner:
jax
.
device_put
(
self
.
cp_reorder_fn
(
self
.
k
),
self
.
qkvo_sharding
),
jax
.
device_put
(
self
.
cp_reorder_fn
(
self
.
v
),
self
.
qkvo_sharding
),
jax
.
device_put
(
self
.
bias
,
self
.
bias_sharding
),
jax
.
device_put
(
self
.
softmax_offset
,
self
.
softmax_offset_sharding
),
jax
.
device_put
(
self
.
sequence_desciptor
,
self
.
seq_desc_sharding
),
jax
.
device_put
(
self
.
dropout_rng
,
self
.
dropout_rng_sharding
),
]
kwargs
=
{
"attn_bias_type"
:
self
.
attn_bias_type
,
"attn_mask_type"
:
self
.
attn_mask_type
,
"softmax_type"
:
self
.
softmax_type
,
"scaling_factor"
:
self
.
scaling_factor
,
"dropout_probability"
:
self
.
dropout_prob
,
"is_training"
:
self
.
is_training
,
...
...
@@ -848,6 +919,7 @@ class FusedAttnRunner:
"window_size"
:
self
.
window_size
,
"context_parallel_strategy"
:
self
.
cp_strategy
,
"context_parallel_causal_load_balanced"
:
self
.
cp_load_balanced
,
"stripe_size"
:
self
.
stripe_size
,
}
# We can compute dBias only for the [1, h, s, s] layout
...
...
@@ -866,8 +938,16 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive
=
jit
(
value_and_grad
(
lambda
q
,
k
,
v
,
bias
,
*
args
:
grad_func
(
customcall_fused_dpa
,
q
,
k
,
v
,
bias
,
*
args
,
cp_reverse_out
=
True
,
**
kwargs
lambda
q
,
k
,
v
,
bias
,
softmax_offset
,
*
args
:
grad_func
(
customcall_fused_dpa
,
q
,
k
,
v
,
bias
,
softmax_offset
,
*
args
,
cp_reverse_out
=
True
,
**
kwargs
,
),
arg_nums
,
),
...
...
@@ -876,6 +956,7 @@ class FusedAttnRunner:
self
.
qkvo_sharding
,
self
.
qkvo_sharding
,
self
.
bias_sharding
,
self
.
softmax_offset_sharding
,
self
.
seq_desc_sharding
,
self
.
dropout_rng_sharding
,
),
...
...
@@ -883,7 +964,9 @@ class FusedAttnRunner:
)
jitted_reference
=
jit
(
value_and_grad
(
lambda
q
,
k
,
v
,
bias
,
*
args
:
grad_func
(
jax_dpa
,
q
,
k
,
v
,
bias
,
*
args
,
**
kwargs
),
lambda
q
,
k
,
v
,
bias
,
softmax_offset
,
*
args
:
grad_func
(
jax_dpa
,
q
,
k
,
v
,
bias
,
softmax_offset
,
*
args
,
**
kwargs
),
arg_nums
,
)
)
...
...
@@ -977,41 +1060,78 @@ class FusedAttnRunner:
],
)
@
pytest
.
mark
.
parametrize
(
"
qkv_layout
"
,
"
softmax_type
"
,
[
pytest
.
param
(
QKVLayout
.
BS3HD
,
id
=
"QKV_PACKED"
),
pytest
.
param
(
QKVLayout
.
BSHD_BS2HD
,
id
=
"KV_PACKED"
),
pytest
.
param
(
QKVLayout
.
BSHD_BSHD_BSHD
,
id
=
"SEPARATE"
),
pytest
.
param
(
QKVLayout
.
T3HD
,
id
=
"RAGGED_QKV_PACKED"
),
pytest
.
param
(
QKVLayout
.
THD_T2HD
,
id
=
"RAGGED_KV_PACKED"
),
pytest
.
param
(
QKVLayout
.
THD_THD_THD
,
id
=
"RAGGED_SEPARATE"
),
pytest
.
param
(
AttnSoftmaxType
.
VANILLA_SOFTMAX
,
id
=
"VANILLA_SOFTMAX"
),
pytest
.
param
(
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
,
id
=
"OFF_BY_ONE_SOFTMAX"
),
pytest
.
param
(
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
,
id
=
"LEARNABLE_SOFTMAX"
),
],
)
@
pytest
.
mark
.
parametrize
(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype"
,
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype
, qkv_layout
"
,
[
# large data size + bf16 + qkv packed
pytest
.
param
(
2
,
2048
,
2048
,
12
,
12
,
64
,
64
,
jnp
.
bfloat16
,
id
=
"2-2048-2048-12-12-64-64-BF16-SELF"
2
,
2048
,
2048
,
12
,
12
,
64
,
64
,
jnp
.
bfloat16
,
QKVLayout
.
BS3HD
,
id
=
"2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED"
,
),
pytest
.
param
(
2
,
512
,
1024
,
2048
,
2048
,
12
,
12
,
64
,
64
,
jnp
.
bfloat16
,
id
=
"2-512-1024-12-12-64-64-BF16-CROSS"
,
QKVLayout
.
T3HD
,
id
=
"2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED"
,
),
# mid data size + bf16 + cross attn + kv packed
pytest
.
param
(
2
,
2048
,
2048
,
12
,
6
,
64
,
64
,
jnp
.
bfloat16
,
id
=
"2-2048-2048-12-6-64-64-BF16-GQA"
2
,
512
,
1024
,
12
,
12
,
64
,
64
,
jnp
.
bfloat16
,
QKVLayout
.
BSHD_BS2HD
,
id
=
"2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED"
,
),
pytest
.
param
(
4
,
128
,
128
,
16
,
16
,
64
,
64
,
jnp
.
float16
,
id
=
"4-128-128-16-16-64-64-FP16-SELF"
2
,
512
,
1024
,
12
,
12
,
64
,
64
,
jnp
.
bfloat16
,
QKVLayout
.
THD_T2HD
,
id
=
"2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED"
,
),
# large data size + bf16 + cross attn + diff hidden v dim + qkv separate
pytest
.
param
(
4
,
128
,
128
,
16
,
16
,
64
,
32
,
jnp
.
float16
,
id
=
"4-128-128-16-16-64-32-FP16-SELF"
2
,
2048
,
1024
,
12
,
12
,
64
,
32
,
jnp
.
bfloat16
,
QKVLayout
.
BSHD_BSHD_BSHD
,
id
=
"2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE"
,
),
pytest
.
param
(
2
,
...
...
@@ -1022,10 +1142,108 @@ class FusedAttnRunner:
64
,
32
,
jnp
.
bfloat16
,
id
=
"2-2048-1024-12-12-64-32-BF16-CROSS"
,
QKVLayout
.
THD_THD_THD
,
id
=
"2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE"
,
),
# large data size + bf16 + gqa + kv packed
pytest
.
param
(
2
,
2048
,
2048
,
12
,
6
,
64
,
64
,
jnp
.
bfloat16
,
QKVLayout
.
BSHD_BS2HD
,
id
=
"2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED"
,
),
pytest
.
param
(
2
,
2048
,
2048
,
12
,
6
,
64
,
64
,
jnp
.
bfloat16
,
QKVLayout
.
THD_T2HD
,
id
=
"2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED"
,
),
# small data size + fp16 + diff hidden v dim + qkv packed
pytest
.
param
(
4
,
128
,
128
,
16
,
16
,
64
,
32
,
jnp
.
float16
,
QKVLayout
.
BS3HD
,
id
=
"4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED"
,
),
pytest
.
param
(
2
,
2048
,
2048
,
12
,
6
,
128
,
64
,
jnp
.
float16
,
id
=
"2-2048-2048-12-6-128-64-FP16-GQA"
4
,
128
,
128
,
16
,
16
,
64
,
32
,
jnp
.
float16
,
QKVLayout
.
T3HD
,
id
=
"4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED"
,
),
# small data size + fp16 + kv packed
pytest
.
param
(
4
,
128
,
128
,
16
,
16
,
64
,
64
,
jnp
.
float16
,
QKVLayout
.
BSHD_BS2HD
,
id
=
"4-128-128-16-16-64-64-FP16-SELF-KV_PACKED"
,
),
pytest
.
param
(
4
,
128
,
128
,
16
,
16
,
64
,
64
,
jnp
.
float16
,
QKVLayout
.
THD_T2HD
,
id
=
"4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED"
,
),
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest
.
param
(
2
,
1024
,
2048
,
12
,
6
,
128
,
64
,
jnp
.
float16
,
QKVLayout
.
BSHD_BSHD_BSHD
,
id
=
"2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE"
,
),
pytest
.
param
(
2
,
1024
,
2048
,
12
,
6
,
128
,
64
,
jnp
.
float16
,
QKVLayout
.
THD_THD_THD
,
id
=
"2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE"
,
),
],
)
...
...
@@ -1084,6 +1302,7 @@ class TestFusedAttn:
d_v
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_prob
,
dtype
,
is_training
,
...
...
@@ -1110,6 +1329,7 @@ class TestFusedAttn:
d_v
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_prob
,
dtype
,
is_training
,
...
...
@@ -1138,6 +1358,7 @@ class TestFusedAttn:
d_v
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_prob
,
dtype
,
qkv_layout
,
...
...
@@ -1161,6 +1382,7 @@ class TestFusedAttn:
d_v
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_prob
,
dtype
,
True
,
...
...
tests/jax/test_layer.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
...
...
@@ -83,6 +83,7 @@ _KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS
=
"use_bias"
_KEY_OF_RELATIVE_EMBEDDING
=
"enable_relative_embedding"
_KEY_OF_WINDOW_SIZE
=
"window_size"
_KEY_OF_SOFTMAX_TYPE
=
"softmax_type"
BASE_ATTRS
=
{
_KEY_OF_TRANSPOSE_BS
:
True
,
...
...
@@ -276,6 +277,14 @@ ATTRS = [
_KEY_OF_RELATIVE_EMBEDDING
:
True
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"post_scale_bias"
,
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE
:
"off_by_one"
,
},
# attrs31
{
_KEY_OF_SOFTMAX_TYPE
:
"learnable"
,
},
]
ATTRS
=
[{
**
BASE_ATTRS
,
**
attr
}
for
attr
in
ATTRS
]
...
...
@@ -418,6 +427,12 @@ class EncoderRunner(BaseRunner):
"attention/qkv/ln_bias"
:
"pre_attention_layer_norm/ln_bias"
,
"attention/query/scale"
:
"pre_attention_layer_norm/scale"
,
"attention/query/ln_bias"
:
"pre_attention_layer_norm/ln_bias"
,
"attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset"
:
(
"attention/DotProductAttention_0/softmax_offset"
),
"attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset"
:
(
"attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel"
:
"mlp/wi/kernel"
,
"mlp/wi_bias"
:
"mlp/wi/bias"
,
"mlp/wo_kernel"
:
"mlp/wo/kernel"
,
...
...
@@ -463,10 +478,22 @@ class DecoderRunner(BaseRunner):
"encoder_decoder_attention/qkv/ln_bias"
:
"pre_cross_attention_layer_norm/ln_bias"
,
"encoder_decoder_attention/query/scale"
:
"pre_cross_attention_layer_norm/scale"
,
"encoder_decoder_attention/query/ln_bias"
:
"pre_cross_attention_layer_norm/ln_bias"
,
"encoder_decoder_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset"
:
(
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"encoder_decoder_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset"
:
(
"encoder_decoder_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/qkv/scale"
:
"pre_self_attention_layer_norm/scale"
,
"self_attention/qkv/ln_bias"
:
"pre_self_attention_layer_norm/ln_bias"
,
"self_attention/query/scale"
:
"pre_self_attention_layer_norm/scale"
,
"self_attention/query/ln_bias"
:
"pre_self_attention_layer_norm/ln_bias"
,
"self_attention/DotProductAttention_0/_UnfusedDotProductAttention_0/softmax_offset"
:
(
"self_attention/DotProductAttention_0/softmax_offset"
),
"self_attention/DotProductAttention_0/_FusedDotProductAttention_0/softmax_offset"
:
(
"self_attention/DotProductAttention_0/softmax_offset"
),
"mlp/wi_kernel"
:
"mlp/wi/kernel"
,
"mlp/wi_bias"
:
"mlp/wi/bias"
,
"mlp/wo_kernel"
:
"mlp/wo/kernel"
,
...
...
@@ -534,7 +561,7 @@ class BaseTester:
"""Test forward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_recipe"
,
QUANTIZE_RECIPES
)
...
...
@@ -542,7 +569,7 @@ class BaseTester:
"""Test backward with fp8 enabled"""
# Empty MeshResource is used as we are running on a single device
with
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
mesh_resource
=
MeshResource
()):
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
)
class
TestEncoderLayer
(
BaseTester
):
...
...
tests/jax/test_misc.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_multi_process_distributed_grouped_gemm.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_permutation.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for permutation Triton kernels and high-level APIs"""
import
functools
import
jax
import
jax.numpy
as
jnp
import
pytest
# High-level API with VJP support
from
transformer_engine.jax.permutation
import
(
token_dispatch
,
token_combine
,
sort_chunks_by_index
,
)
from
utils
import
assert_allclose
,
pytest_parametrize_wrapper
ALL_DISPATCH_COMBINE_CASES
=
[
(
128
,
5
,
128
,
3
),
(
1024
,
8
,
128
,
8
),
(
4096
,
32
,
1280
,
2
),
(
4096
,
256
,
4096
,
6
),
]
DISPATCH_COMBINE_CASES
=
{
"L0"
:
ALL_DISPATCH_COMBINE_CASES
[
0
:
2
],
"L2"
:
ALL_DISPATCH_COMBINE_CASES
,
}
ALL_SORT_CHUNKS_CASES
=
[
(
8
,
4096
,
1280
),
(
64
,
4096
,
4096
),
(
256
,
4096
,
9216
),
]
SORT_CHUNKS_CASES
=
{
"L0"
:
ALL_SORT_CHUNKS_CASES
[
0
:
2
],
"L2"
:
ALL_SORT_CHUNKS_CASES
,
}
ALL_DISPATCH_COMBINE_PADDING_CASES
=
[
(
128
,
5
,
128
,
3
,
8
),
(
1024
,
8
,
128
,
8
,
16
),
(
4096
,
32
,
1280
,
2
,
128
),
(
4096
,
256
,
4096
,
6
,
16
),
]
DISPATCH_COMBINE_PADDING_CASES
=
{
"L0"
:
ALL_DISPATCH_COMBINE_PADDING_CASES
[
0
:
2
],
"L2"
:
ALL_DISPATCH_COMBINE_PADDING_CASES
,
}
ALL_DTYPES
=
[
jnp
.
float32
,
jnp
.
bfloat16
]
DTYPES
=
{
"L0"
:
ALL_DTYPES
,
"L2"
:
ALL_DTYPES
,
}
ALL_WITH_PROBS
=
[
True
,
False
]
WITH_PROBS
=
{
"L0"
:
[
True
],
"L2"
:
ALL_WITH_PROBS
,
}
def
reference_make_row_id_map
(
routing_map
:
jnp
.
ndarray
,
)
->
jnp
.
ndarray
:
"""
Vectorized reference implementation of make_row_id_map using JAX primitives.
Parameters
----------
routing_map : jnp.ndarray
Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts
are routed to which tokens (1 = routed, 0 = not routed).
Returns
-------
row_id_map : jnp.ndarray
The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1].
"""
num_tokens
,
num_experts
=
routing_map
.
shape
# For each expert, compute cumulative sum to get destination indices
cumsum_per_expert
=
jnp
.
cumsum
(
routing_map
,
axis
=
0
)
# Compute total tokens per expert and expert offsets
tokens_per_expert
=
jnp
.
sum
(
routing_map
,
axis
=
0
)
expert_offsets
=
jnp
.
concatenate
(
[
jnp
.
array
([
0
],
dtype
=
jnp
.
int32
),
jnp
.
cumsum
(
tokens_per_expert
)[:
-
1
].
astype
(
jnp
.
int32
)]
)
# Compute destination rows for all (token, expert) pairs
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
dest_rows_all
=
(
expert_offsets
[
None
,
:]
+
cumsum_per_expert
-
1
)
*
routing_map
+
(
-
1
)
*
(
1
-
routing_map
)
# Count routed experts per token
n_routed_per_token
=
jnp
.
sum
(
routing_map
,
axis
=
1
)
# For each token, we need to sort by descending dest_row and pack into row_id_map
# Use a large negative value for non-routed experts so they sort to the end
sort_keys
=
jnp
.
where
(
routing_map
==
1
,
-
dest_rows_all
,
jnp
.
iinfo
(
jnp
.
int32
).
max
)
sorted_expert_indices
=
jnp
.
argsort
(
sort_keys
,
axis
=
1
)
# Gather the sorted destination rows and expert indices using advanced indexing
# Create indices for gathering
token_idx
=
jnp
.
broadcast_to
(
jnp
.
arange
(
num_tokens
,
dtype
=
jnp
.
int32
)[:,
None
],
(
num_tokens
,
num_experts
)
)
sorted_dest_rows
=
dest_rows_all
[
token_idx
,
sorted_expert_indices
]
# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
row_id_map
=
jnp
.
concatenate
(
[
sorted_dest_rows
.
astype
(
jnp
.
int32
),
sorted_expert_indices
.
astype
(
jnp
.
int32
),
n_routed_per_token
.
astype
(
jnp
.
int32
)[:,
None
],
],
axis
=
1
,
)
return
row_id_map
def
_reference_permute_impl
(
inp
:
jnp
.
ndarray
,
row_id_map
:
jnp
.
ndarray
,
probs
:
jnp
.
ndarray
,
num_out_tokens
:
int
,
)
->
tuple
:
"""
Vectorized internal helper for reference permutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
probs : jnp.ndarray
The probabilities of the input tensor.
num_out_tokens : int
Number of tokens in the permuted tensor.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
permuted_probs : jnp.ndarray
Permuted probabilities if probs was provided, None otherwise.
"""
num_tokens
,
hidden_size
=
inp
.
shape
num_experts
=
(
row_id_map
.
shape
[
1
]
-
1
)
//
2
# Extract destination rows, expert indices, and n_routed from row_id_map
dest_rows
=
row_id_map
[:,
:
num_experts
]
# [num_tokens, num_experts]
expert_indices
=
row_id_map
[:,
num_experts
:
2
*
num_experts
]
# [num_tokens, num_experts]
n_routed
=
row_id_map
[:,
2
*
num_experts
]
# [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
# (slots beyond n_routed may contain garbage, not -1)
slot_indices
=
jnp
.
arange
(
num_experts
)[
None
,
:]
# [1, num_experts]
valid_mask
=
slot_indices
<
n_routed
[:,
None
]
# [num_tokens, num_experts]
# Flatten for scatter operations
flat_dest_rows
=
dest_rows
.
flatten
()
# [num_tokens * num_experts]
flat_valid_mask
=
valid_mask
.
flatten
()
flat_token_indices
=
jnp
.
repeat
(
jnp
.
arange
(
num_tokens
),
num_experts
)
flat_expert_indices
=
expert_indices
.
flatten
()
# Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
# This avoids overwriting valid entries at index 0 with zeros
flat_dest_rows_clamped
=
jnp
.
where
(
flat_valid_mask
,
flat_dest_rows
,
num_out_tokens
)
# Gather input tokens and scatter to output
output
=
jnp
.
zeros
((
num_out_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
)
gathered_inp
=
inp
[
flat_token_indices
]
# [num_tokens * num_experts, hidden_size]
# Use segment_sum-like operation via scatter
# For each valid (token, expert) pair, write inp[token] to output[dest_row]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
output
=
output
.
at
[
flat_dest_rows_clamped
].
set
(
gathered_inp
,
mode
=
"drop"
,
)
permuted_probs
=
None
if
probs
is
not
None
:
permuted_probs
=
jnp
.
zeros
((
num_out_tokens
,),
dtype
=
probs
.
dtype
)
# Vectorized approach: gather probs and scatter to permuted_probs
if
probs
.
ndim
==
1
:
flat_probs
=
probs
[
flat_token_indices
]
else
:
# Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
# The result for invalid entries will be ignored anyway since they target num_out_tokens
# Cast to int32 explicitly for consistent indexing behavior
flat_expert_indices_clamped
=
jnp
.
where
(
flat_valid_mask
,
flat_expert_indices
,
0
).
astype
(
jnp
.
int32
)
flat_probs
=
probs
[
flat_token_indices
.
astype
(
jnp
.
int32
),
flat_expert_indices_clamped
]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
permuted_probs
=
permuted_probs
.
at
[
flat_dest_rows_clamped
.
astype
(
jnp
.
int32
)].
set
(
flat_probs
,
mode
=
"drop"
,
)
return
output
,
permuted_probs
def
_reference_unpermute_impl
(
inp
:
jnp
.
ndarray
,
row_id_map
:
jnp
.
ndarray
,
merging_probs
:
jnp
.
ndarray
,
permuted_probs
:
jnp
.
ndarray
,
)
->
tuple
:
"""
Vectorized internal helper for reference unpermutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_out_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
merging_probs : jnp.ndarray
The merging probabilities for weighted reduction.
permuted_probs : jnp.ndarray
The permuted probabilities.
Returns
-------
output : jnp.ndarray
Unpermuted output tensor of shape [num_tokens, hidden_size].
unpermuted_probs : jnp.ndarray
Unpermuted probabilities if permuted_probs was provided, None otherwise.
"""
num_tokens
=
row_id_map
.
shape
[
0
]
num_experts
=
(
row_id_map
.
shape
[
1
]
-
1
)
//
2
# Extract source rows, expert indices, and n_routed from row_id_map
src_rows
=
row_id_map
[:,
:
num_experts
]
# [num_tokens, num_experts]
expert_indices
=
row_id_map
[:,
num_experts
:
2
*
num_experts
]
# [num_tokens, num_experts]
n_routed
=
row_id_map
[:,
2
*
num_experts
]
# [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
slot_indices
=
jnp
.
arange
(
num_experts
)[
None
,
:]
# [1, num_experts]
valid_mask
=
slot_indices
<
n_routed
[:,
None
]
# [num_tokens, num_experts]
# Clamp invalid src_rows to 0 (they won't be used due to masking)
src_rows_clamped
=
jnp
.
where
(
valid_mask
,
src_rows
,
0
)
# Gather input from permuted positions
gathered_inp
=
inp
[
src_rows_clamped
]
# [num_tokens, num_experts, hidden_size]
# Apply merging probs if provided
if
merging_probs
is
not
None
:
# Gather the merging weights for each (token, expert) pair using advanced indexing
token_idx
=
jnp
.
broadcast_to
(
jnp
.
arange
(
num_tokens
)[:,
None
],
(
num_tokens
,
num_experts
))
weights
=
merging_probs
[
token_idx
,
expert_indices
]
# [num_tokens, num_experts]
gathered_inp
=
gathered_inp
*
weights
[:,
:,
None
]
# Mask out invalid entries and sum across experts
gathered_inp
=
jnp
.
where
(
valid_mask
[:,
:,
None
],
gathered_inp
,
0.0
)
output
=
jnp
.
sum
(
gathered_inp
,
axis
=
1
)
# [num_tokens, hidden_size]
unpermuted_probs
=
None
if
permuted_probs
is
not
None
:
gathered_probs
=
permuted_probs
[
src_rows_clamped
]
# [num_tokens, num_experts]
unpermuted_probs
=
jnp
.
zeros
((
num_tokens
,
num_experts
),
dtype
=
permuted_probs
.
dtype
)
token_idx
=
jnp
.
broadcast_to
(
jnp
.
arange
(
num_tokens
)[:,
None
],
(
num_tokens
,
num_experts
))
unpermuted_probs
=
unpermuted_probs
.
at
[
token_idx
,
expert_indices
].
set
(
jnp
.
where
(
valid_mask
,
gathered_probs
,
0.0
)
)
return
output
,
unpermuted_probs
def reference_token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: jnp.ndarray = None,
) -> tuple:
    """Reference implementation of token_dispatch using JAX primitives.

    Builds the row_id_map from the routing mask, then permutes the input
    (and optionally the probabilities) through the reference permute helper.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        Routing mask of shape [num_tokens, num_experts].
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    probs : jnp.ndarray, optional
        The probabilities of shape [num_tokens, num_experts].

    Returns
    -------
    output : jnp.ndarray
        Permuted output tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray or None
        Permuted probabilities of shape [num_out_tokens], or None if probs
        was not provided.
    row_id_map : jnp.ndarray
        The row_id_map for the permutation.
    """
    id_map = reference_make_row_id_map(routing_map)
    permuted, permuted_probs = _reference_permute_impl(inp, id_map, probs, num_out_tokens)
    return permuted, permuted_probs, id_map
def reference_token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
) -> jnp.ndarray:
    """Reference implementation of token_combine using JAX primitives.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape
        [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        The merging probabilities for weighted reduction.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted output tensor of shape [num_tokens, hidden_size].
    """
    # The helper also returns unpermuted probs; only the tensor is needed here.
    return _reference_unpermute_impl(inp, row_id_map, merging_probs, None)[0]
def reference_make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
) -> jnp.ndarray:
    """Vectorized reference implementation of make_chunk_sort_map.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        The sizes of the chunks of shape [num_splits,].
    sorted_indices : jnp.ndarray
        The indices of the sorted chunks of shape [num_splits,].
    num_tokens : int
        Number of tokens.

    Returns
    -------
    row_id_map : jnp.ndarray
        Row ID map for chunk sorting of shape [num_tokens,]; entry i is the
        destination row of source token i.
    """
    zero = jnp.array([0], dtype=jnp.int32)

    # Source chunk start offsets: exclusive prefix sum of the original sizes.
    src_starts = jnp.concatenate([zero, jnp.cumsum(split_sizes).astype(jnp.int32)])

    # Destination chunk start offsets follow the requested chunk order.
    reordered_sizes = split_sizes[sorted_indices]
    dest_starts = jnp.concatenate([zero, jnp.cumsum(reordered_sizes).astype(jnp.int32)])

    # argsort inverts the permutation: position of chunk i in the sorted order.
    chunk_positions = jnp.argsort(sorted_indices).astype(jnp.int32)
    chunk_dest_offsets = dest_starts[chunk_positions]

    # Map every token position to (owning chunk, offset within that chunk).
    positions = jnp.arange(num_tokens, dtype=jnp.int32)
    owning_chunk = jnp.searchsorted(src_starts[1:], positions, side="right").astype(
        jnp.int32
    )
    offset_in_chunk = positions - src_starts[owning_chunk]

    # destination[i] = start of chunk i's new location + offset inside the chunk
    return (chunk_dest_offsets[owning_chunk] + offset_in_chunk).astype(jnp.int32)
def reference_sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    is_forward: bool,
) -> tuple:
    """Vectorized reference implementation of sort_chunks_by_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        The token to destination mapping of shape [num_tokens,].
    probs : jnp.ndarray
        The probabilities, or None.
    is_forward : bool
        Whether this is forward (scatter) or backward (gather).

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Sorted probabilities if probs was provided, None otherwise.
    """
    num_rows = inp.shape[0]
    width = inp.shape[1]

    if is_forward:
        # Forward: scatter — output[row_id_map[src]] = inp[src].
        output = jnp.zeros((num_rows, width), dtype=inp.dtype).at[row_id_map].set(inp)
        permuted_probs = (
            jnp.zeros((num_rows,), dtype=probs.dtype).at[row_id_map].set(probs)
            if probs is not None
            else None
        )
    else:
        # Backward: gather — output[dest] = inp[row_id_map[dest]].
        output = inp[row_id_map]
        permuted_probs = probs[row_id_map] if probs is not None else None

    return output, permuted_probs
class TestHighLevelPermutationAPI:
    """Test high-level permutation APIs (token_dispatch, token_combine, etc.)

    These tests compare the high-level APIs against reference implementations
    to verify correctness of both forward and backward passes.
    """

    @staticmethod
    def generate_routing_map(
        num_tokens: int,
        num_experts: int,
        tokens_per_expert: int = 2,
        key: jax.Array = None,
    ):
        """Generate random routing map for testing"""
        if key is None:
            key = jax.random.PRNGKey(0)
        routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
        for token_idx in range(num_tokens):
            key, subkey = jax.random.split(key)
            # Each token is routed to `tokens_per_expert` distinct experts
            # (replace=False guarantees no duplicates within a row).
            expert_indices = jax.random.choice(
                subkey, num_experts, shape=(tokens_per_expert,), replace=False
            )
            routing_map = routing_map.at[token_idx, expert_indices].set(1)
        return routing_map

    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_token_dispatch(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Individual test for token_dispatch forward and backward passes.

        This test validates dispatch in isolation to catch errors that might be
        masked when combined with token_combine in the roundtrip test.
        Uses value_and_grad to validate both forward (via loss comparison) and
        backward (via gradient comparison) passes against reference implementation.
        """
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Generate input data
        key, inp_key, prob_key = jax.random.split(key, 3)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )
        # Generate reference row_id_map for comparison
        ref_row_id_map = reference_make_row_id_map(routing_map)
        # =====================================================================
        # Test forward and backward pass using value_and_grad
        # (value validates forward, grad validates backward)
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                # Sum-of-squares loss so every permuted element contributes a
                # distinct gradient through the scatter.
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)
            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
            # Validate forward loss matches
            assert_allclose(loss_val, ref_loss_val, dtype=dtype)
            # Validate gradients
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)

    # =========================================================================
    # Consolidated dispatch + combine tests
    # =========================================================================
    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        DISPATCH_COMBINE_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
    ):
        """
        Comprehensive test for token_dispatch and token_combine.

        Tests:
        1. Dispatch forward pass against reference (element-by-element)
        2. Dispatch backward pass against reference
        3. Combine forward pass against reference (element-by-element)
        4. Combine backward pass against reference
        5. Roundtrip: dispatch + combine recovers original input
        6. row_id_map n_routed column validation
        7. Probs permutation (when with_probs=True)
        """
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Generate input data
        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )
        # Generate merging probs (normalized per token)
        merging_probs = jax.random.uniform(
            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
        merging_probs = merging_probs / jnp.maximum(
            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
        )
        # =====================================================================
        # Test 1: Dispatch forward pass
        # =====================================================================
        output, permuted_probs, row_id_map, _, _ = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs
        )
        ref_output, ref_permuted_probs = _reference_permute_impl(
            inp, row_id_map, probs, num_out_tokens
        )
        # Validate row_id_map structure: n_routed column should match routing_map sum
        n_routed_actual = row_id_map[:, -1]
        n_routed_expected = jnp.sum(routing_map, axis=1)
        assert jnp.array_equal(
            n_routed_actual, n_routed_expected
        ), "make_row_id_map n_routed column mismatch"
        # Compare dispatch output
        assert_allclose(output, ref_output, dtype=dtype)
        if with_probs:
            assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)
        # =====================================================================
        # Test 2: Dispatch backward pass
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            @jax.jit
            def ref_dispatch_loss(x, p):
                out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            _, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
                inp, probs
            )
            _, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
                ref_dispatch_loss, argnums=(0, 1)
            )(inp, probs)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
            assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
                return jnp.sum(out**2)

            @jax.jit
            def ref_dispatch_loss_no_probs(x):
                out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
                return jnp.sum(out**2)

            _, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
            _, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
            assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
        # =====================================================================
        # Test 3: Combine forward pass
        # =====================================================================
        combined = token_combine(output, row_id_map, merging_probs)
        ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
        assert_allclose(combined, ref_combined, dtype=dtype)
        # =====================================================================
        # Test 4: Combine backward pass
        # =====================================================================
        @jax.jit
        def combine_loss(x):
            return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)

        @jax.jit
        def ref_combine_loss(x):
            return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)

        _, combine_grad = jax.value_and_grad(combine_loss)(output)
        _, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
        assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)
        # =====================================================================
        # Test 5: Roundtrip (dispatch + combine = original)
        # =====================================================================
        # Use uniform merging probs for perfect roundtrip
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
            return token_combine(dispatched, rid_map, uniform_merging_probs)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

    # =========================================================================
    # sort_chunks_by_index tests
    # =========================================================================
    @pytest_parametrize_wrapper(
        "num_splits,total_tokens,hidden_size",
        SORT_CHUNKS_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
        """Test sort_chunks_by_index forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)
        # Generate random split sizes
        key, size_key = jax.random.split(key)
        split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits)
        # Force the last chunk to absorb the remainder so sizes sum to total_tokens.
        split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1]))
        # Generate sorted indices
        key, sort_key = jax.random.split(key)
        sorted_indices = jax.random.permutation(sort_key, num_splits)
        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        # Get reference row_id_map
        row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)

        # Define loss functions (JIT compiled for performance)
        @jax.jit
        def loss_fn(x):
            output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
            return jnp.sum(output**2)

        @jax.jit
        def ref_loss_fn(x):
            output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
            return jnp.sum(output**2)

        # Test forward pass
        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
        ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)
        # Test backward pass with JIT
        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
        # Compare forward and backward
        assert_allclose(output, ref_output)
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
    # Consolidated dispatch + combine with padding tests
    # =========================================================================
    @pytest_parametrize_wrapper(
        "num_tokens,num_experts,hidden_size,topk,align_size",
        DISPATCH_COMBINE_PADDING_CASES,
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("with_probs", WITH_PROBS)
    def test_dispatch_and_combine_with_padding(
        self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
    ):
        """
        Comprehensive test for token_dispatch and token_combine with padding/unpadding.

        Tests:
        1. Dispatch with padding: output shape and alignment
        2. Dispatch backward pass with padding
        3. Combine with unpad: output shape
        4. Combine backward pass with unpad
        5. Roundtrip with padding: dispatch + combine recovers original
        6. Probs permutation with padding (when with_probs=True)
        """
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Compute worst-case padded size
        worst_case_size = (
            (num_out_tokens + num_experts * (align_size - 1)) // align_size
        ) * align_size
        # Generate input data
        key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
        probs = None
        if with_probs:
            probs = jax.random.uniform(
                prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
            )
        # Generate merging probs (normalized per token)
        merging_probs = jax.random.uniform(
            merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )
        merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
        merging_probs = merging_probs / jnp.maximum(
            jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
        )
        # =====================================================================
        # Test 1: Dispatch with padding - forward pass
        # =====================================================================
        output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
        )
        # Check output shape
        assert output.shape == (worst_case_size, hidden_size)
        if with_probs:
            assert permuted_probs is not None
            assert permuted_probs.shape == (worst_case_size,)
        else:
            assert permuted_probs is None
        # Check alignment: each expert's tokens should be aligned
        for expert_idx in range(num_experts):
            expert_tokens = int(target_tokens_per_expert[expert_idx])
            assert expert_tokens % align_size == 0 or expert_tokens == 0
        # =====================================================================
        # Test 2: Dispatch with padding - backward pass
        # =====================================================================
        if with_probs:

            @jax.jit
            def dispatch_loss(x, p):
                out, perm_probs, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, probs=p, align_size=align_size
                )
                return jnp.sum(out**2) + jnp.sum(perm_probs**2)

            inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs)
            assert inp_grad.shape == inp.shape
            assert probs_grad.shape == probs.shape
            assert not jnp.any(jnp.isnan(inp_grad))
            assert not jnp.any(jnp.isnan(probs_grad))
        else:

            @jax.jit
            def dispatch_loss_no_probs(x):
                out, _, _, _, _ = token_dispatch(
                    x, routing_map, num_out_tokens, align_size=align_size
                )
                return jnp.sum(out**2)

            inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
            assert inp_grad.shape == inp.shape
            assert not jnp.any(jnp.isnan(inp_grad))
        # =====================================================================
        # Test 3: Combine with unpad - forward pass
        # =====================================================================
        combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
        assert combined.shape == (num_tokens, hidden_size)
        # =====================================================================
        # Test 4: Combine with unpad - backward pass
        # =====================================================================
        @jax.jit
        def combine_loss(x):
            return jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2)

        combine_grad = jax.grad(combine_loss)(output)
        assert combine_grad.shape == output.shape
        assert not jnp.any(jnp.isnan(combine_grad))
        # =====================================================================
        # Test 5: Roundtrip with padding (dispatch + combine = original)
        # =====================================================================
        # Use uniform merging probs for perfect roundtrip
        uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )

        @jax.jit
        def roundtrip(x):
            dispatched, _, rid_map, p_offsets, _ = token_dispatch(
                x, routing_map, num_out_tokens, align_size=align_size
            )
            return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)

        roundtrip_output = roundtrip(inp)
        assert_allclose(roundtrip_output, inp, dtype=dtype)

        # Test roundtrip gradient
        @jax.jit
        def roundtrip_loss(x):
            return jnp.sum(roundtrip(x) ** 2)

        roundtrip_grad = jax.grad(roundtrip_loss)(inp)
        assert roundtrip_grad.shape == inp.shape
        assert not jnp.any(jnp.isnan(roundtrip_grad))
tests/jax/test_recipe_characteristics.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_sanity_import.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/jax/test_softmax.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
...
...
@@ -17,7 +17,8 @@ from jax.typing import DTypeLike
from
utils
import
assert_allclose
from
transformer_engine.jax.cpp_extensions
import
is_softmax_kernel_available
from
transformer_engine.jax.softmax
import
SoftmaxType
,
softmax
from
transformer_engine.jax.cpp_extensions.attention
import
AttnSoftmaxType
from
transformer_engine.jax.softmax
import
SoftmaxFusionType
,
softmax
from
transformer_engine.jax.flax.module
import
Softmax
...
...
@@ -50,8 +51,9 @@ class SoftmaxRunner:
max_seqlen_kv
:
int
num_heads
:
int
scale_factor
:
float
softmax_type
:
SoftmaxType
softmax_
fusion_
type
:
Softmax
Fusion
Type
dtype
:
DTypeLike
softmax_type
:
AttnSoftmaxType
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
@
staticmethod
def
reference_softmax
(
logits
,
mask
,
scale_factor
,
**
_
):
...
...
@@ -68,6 +70,7 @@ class SoftmaxRunner:
def
_is_support
(
self
):
return
is_softmax_kernel_available
(
self
.
softmax_fusion_type
,
self
.
softmax_type
,
self
.
batch_size
,
self
.
num_heads
,
...
...
@@ -85,22 +88,22 @@ class SoftmaxRunner:
self
.
logits
=
jax
.
random
.
uniform
(
logits_key
,
logits_shape
,
self
.
dtype
,
-
1.0
)
match
self
.
softmax_type
:
case
SoftmaxType
.
SCALED
:
match
self
.
softmax_
fusion_
type
:
case
Softmax
Fusion
Type
.
SCALED
:
self
.
mask
=
None
case
SoftmaxType
.
SCALED_MASKED
:
case
Softmax
Fusion
Type
.
SCALED_MASKED
:
self
.
mask
=
jax
.
random
.
bernoulli
(
mask_key
,
shape
=
mask_shape
).
astype
(
jnp
.
uint8
)
case
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
case
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
:
self
.
mask
=
(
1.0
-
jnp
.
tril
(
jnp
.
ones_like
(
self
.
logits
))).
astype
(
jnp
.
uint8
)
case
_
:
raise
ValueError
(
f
"Unknown
{
self
.
softmax_type
=
}
"
)
raise
ValueError
(
f
"Unknown
{
self
.
softmax_
fusion_
type
=
}
"
)
def
test_forward
(
self
):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self
.
_setup_inputs
()
primitive_out
=
softmax
(
self
.
logits
,
self
.
mask
,
self
.
scale_factor
,
self
.
softmax_type
)
primitive_out
=
softmax
(
self
.
logits
,
self
.
mask
,
self
.
scale_factor
,
self
.
softmax_
fusion_
type
)
reference_out
=
__class__
.
reference_softmax
(
self
.
logits
,
self
.
mask
,
self
.
scale_factor
)
assert_allclose
(
primitive_out
,
reference_out
,
dtype
=
self
.
dtype
)
...
...
@@ -117,7 +120,7 @@ class SoftmaxRunner:
args
=
[
self
.
logits
,
self
.
mask
]
kwargs
=
{
"scale_factor"
:
self
.
scale_factor
,
"softmax_type"
:
self
.
softmax_type
,
"softmax_
fusion_
type"
:
self
.
softmax_
fusion_
type
,
}
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
...
...
@@ -175,7 +178,7 @@ class SoftmaxModuleRunner:
rng
=
jax
.
random
.
PRNGKey
(
0
)
softmax_module
=
Softmax
(
scale_factor
=
runner
.
scale_factor
,
softmax_type
=
runner
.
softmax_type
,
softmax_
fusion_
type
=
runner
.
softmax_
fusion_
type
,
)
softmax_vars
=
softmax_module
.
init
(
rng
,
runner
.
logits
,
runner
.
mask
)
module_out
=
softmax_module
.
apply
(
softmax_vars
,
runner
.
logits
,
runner
.
mask
)
...
...
@@ -194,11 +197,11 @@ class SoftmaxModuleRunner:
)
@
pytest
.
mark
.
parametrize
(
"scale_factor"
,
[
0.125
])
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
"softmax_
fusion_
type"
,
[
pytest
.
param
(
SoftmaxType
.
SCALED
,
id
=
"SCALED"
),
pytest
.
param
(
SoftmaxType
.
SCALED_MASKED
,
id
=
"SCALED_MASKED"
),
pytest
.
param
(
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
id
=
"SCALED_UPPER_TRIANG_MASKED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED
,
id
=
"SCALED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED_MASKED
,
id
=
"SCALED_MASKED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
,
id
=
"SCALED_UPPER_TRIANG_MASKED"
),
],
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -214,19 +217,19 @@ class TestSoftmaxPrimitives:
"""
@
staticmethod
def
test_forward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
):
def
test_forward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
):
"""
Test forward with parameterized configs
"""
runner
=
SoftmaxPrimitivesRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
=
SoftmaxPrimitivesRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
)
runner
.
test_forward
()
@
staticmethod
def
test_backward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
):
def
test_backward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
):
"""
Test forward with parameterized configs
"""
runner
=
SoftmaxPrimitivesRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
=
SoftmaxPrimitivesRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
)
runner
.
test_backward
()
...
...
@@ -243,11 +246,11 @@ class TestSoftmaxPrimitives:
)
@
pytest
.
mark
.
parametrize
(
"scale_factor"
,
[
0.125
])
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
"softmax_
fusion_
type"
,
[
pytest
.
param
(
SoftmaxType
.
SCALED
,
id
=
"SCALED"
),
pytest
.
param
(
SoftmaxType
.
SCALED_MASKED
,
id
=
"SCALED_MASKED"
),
pytest
.
param
(
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
id
=
"SCALED_UPPER_TRIANG_MASKED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED
,
id
=
"SCALED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED_MASKED
,
id
=
"SCALED_MASKED"
),
pytest
.
param
(
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
,
id
=
"SCALED_UPPER_TRIANG_MASKED"
),
],
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -263,11 +266,11 @@ class TestSoftmaxModule:
"""
@
staticmethod
def
test_forward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
):
def
test_forward
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
):
"""
Test forward with parameterized configs
"""
module_runner
=
SoftmaxRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
module_runner
=
SoftmaxRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_
fusion_
type
,
dtype
)
bias
=
None
runner
=
SoftmaxModuleRunner
(
module_runner
,
bias
)
runner
.
test_forward
()
tests/jax/test_triton_custom_calls.py
0 → 100644
View file @
0d874a4e
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for Triton-based custom calls in TE JAX."""
import
jax
import
jax.numpy
as
jnp
import
pytest
from
utils
import
assert_allclose
,
pytest_parametrize_wrapper
import
triton
import
triton.language
as
tl
from
transformer_engine.jax.cpp_extensions.base
import
BasePrimitive
,
register_primitive
from
transformer_engine.jax.triton_extensions
import
triton_call_lowering
@pytest.fixture(autouse=True, scope="module")
def init():
    """WAR for CUDA uninitialize error"""
    # Touch the JAX backend once (allocate an empty array) before any test in
    # this module runs, so device/runtime initialization happens up front
    # rather than inside a test body.
    _ = jnp.zeros(0)
    yield
class TestTritonBinding:
    """Test Triton binding primitive.

    Defines an autotuned Triton amax kernel, wraps it in a TE JAX
    ``BasePrimitive`` whose MLIR lowering goes through
    ``triton_call_lowering``, and checks the result against a pure-JAX
    reference under ``jax.jit``.
    """

    # Define autotuned Triton kernel
    @staticmethod
    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 256}),  # Uses defaults: num_warps=4, num_stages=3
            triton.Config({"BLOCK_SIZE": 512}, num_warps=8),  # Custom num_warps
        ],
        key=["n_elements"],  # Autotune based on input size
    )
    @triton.jit
    def amax_kernel(
        x_ptr,  # pointer to the flattened input tensor
        amax_ptr,  # pointer to a single-element output accumulator
        n_elements: tl.constexpr,  # total element count of the input
        BLOCK_SIZE: tl.constexpr,  # supplied by the autotuner config, not the caller
    ):
        """Compute amax using Triton with autotuning."""
        # Each program instance reduces one BLOCK_SIZE-wide slice of the input.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Mask the tail block; masked lanes load 0.0, which cannot win the
        # abs-max reduction.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        abs_x = tl.abs(x)
        block_max = tl.max(abs_x)
        # Combine per-block maxima into the single global result.
        tl.atomic_max(amax_ptr, block_max)

    # Define test primitive
    class AmaxTritonPrimitive(BasePrimitive):
        """Test primitive using Triton kernel."""

        name = "te_amax_triton_test"
        multiple_results = False
        impl_static_args = ()

        @staticmethod
        def abstract(x_aval):
            # Output is always a single float32 amax value, regardless of the
            # input's shape or dtype.
            return jax.core.ShapedArray((1,), jnp.float32)

        @staticmethod
        def impl(x):
            # register_primitive populates inner_primitive; fail loudly if the
            # primitive was never registered.
            assert TestTritonBinding.AmaxTritonPrimitive.inner_primitive is not None
            return TestTritonBinding.AmaxTritonPrimitive.inner_primitive.bind(x)

        @staticmethod
        def lowering(ctx, x):
            """MLIR lowering using Triton kernel."""
            # Total element count of the (possibly multi-dimensional) input.
            n_elements = 1
            for dim in ctx.avals_in[0].shape:
                n_elements *= dim
            # For autotuned kernels, use the minimum BLOCK_SIZE from configs
            # to ensure all elements are processed by all configs
            block_size = min(
                config.kwargs.get("BLOCK_SIZE") for config in TestTritonBinding.amax_kernel.configs
            )
            grid = (triton.cdiv(n_elements, block_size),)
            return triton_call_lowering(
                ctx,
                TestTritonBinding.amax_kernel,  # Autotuned kernel
                x,
                grid=grid,
                constexprs={"n_elements": n_elements},
                # BLOCK_SIZE comes from autotuner config, not passed here
            )

    # Register at class-definition time so inner/outer primitives exist before
    # any test method binds them.
    register_primitive(AmaxTritonPrimitive)

    @staticmethod
    def _triton_amax(x: jnp.ndarray) -> jnp.ndarray:
        """Compute amax using Triton kernel."""
        return TestTritonBinding.AmaxTritonPrimitive.outer_primitive.bind(x)

    @pytest_parametrize_wrapper("shape", [(1024, 1024)])
    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16])
    def test_triton_amax(self, shape, dtype):
        """Test Triton amax with JIT."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, dtype)
        # Reference: full reduction in plain JAX, cast to the primitive's
        # fixed float32 output dtype.
        expected = jnp.max(jnp.abs(x), keepdims=False).astype(jnp.float32)
        jitted_amax = jax.jit(self._triton_amax)
        result = jitted_amax(x)
        assert_allclose(result, expected, dtype=jnp.float32)
tests/jax/utils.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""
...
...
@@ -21,6 +21,7 @@ from jax import random as jax_random
import
pytest
from
transformer_engine.jax.attention
import
(
AttnSoftmaxType
,
canonicalize_attn_mask_type
,
make_swa_mask
,
)
...
...
@@ -46,6 +47,13 @@ def is_devices_enough(required):
return
len
(
jax
.
devices
())
>=
required
def is_devices_equal(required):
    """Return whether the number of visible JAX devices equals *required* exactly."""
    return required == len(jax.devices())
def
_generate_drop_path_shape
(
shape
:
Sequence
[
int
],
batch_dim
:
int
)
->
Sequence
[
int
]:
# Generate broadcast dims for drop_path.
drop_path_shape
=
list
(
range
(
0
,
len
(
shape
)))
...
...
@@ -162,6 +170,7 @@ class DotProductAttention(nn.Module):
dropout_rate
:
float
=
0.0
dtype
:
DType
=
jnp
.
float32
float32_logits
:
bool
=
False
softmax_type
:
AttnSoftmaxType
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
...
...
@@ -211,6 +220,24 @@ class DotProductAttention(nn.Module):
assert
key
.
shape
[
-
2
]
==
value
.
shape
[
-
2
],
"k, v num_heads must match."
assert
query
.
shape
[
-
1
]
==
key
.
shape
[
-
1
],
"q, k head_dim must match."
# Infer number of attention heads from query shape
# query shape: [..., h, d] where h is num_attention_heads
num_attention_heads
=
query
.
shape
[
-
2
]
# Initialize softmax_offset for off-by-one or learnable softmax
softmax_offset
=
None
if
self
.
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
# For off-by-one softmax, use zeros with shape (1, h, 1, 1)
softmax_offset
=
jnp
.
zeros
((
1
,
num_attention_heads
,
1
,
1
),
dtype
=
input_dtype
)
elif
self
.
softmax_type
==
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
# For learnable softmax, create a learnable parameter with shape (1, h, 1, 1)
softmax_offset
=
self
.
param
(
"softmax_offset"
,
nn
.
initializers
.
zeros
,
(
1
,
num_attention_heads
,
1
,
1
),
jnp
.
float32
,
)
if
self
.
scale_attn_logits
:
head_dim
=
query
.
shape
[
-
1
]
depth_scaling
=
jnp
.
sqrt
(
head_dim
).
astype
(
input_dtype
)
...
...
@@ -241,9 +268,23 @@ class DotProductAttention(nn.Module):
if
bias
is
not
None
:
attn_weights
=
attn_weights
+
bias
.
astype
(
attn_weights
.
dtype
)
# Add attention sink to the last column if not vanilla softmax
if
self
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
# Add extra column with softmax_offset
# softmax_offset shape: (1, h, 1, 1), attn_weights shape: [b, h, q, k]
extra_col
=
jnp
.
broadcast_to
(
softmax_offset
,
(
attn_weights
.
shape
[
0
],
attn_weights
.
shape
[
1
],
attn_weights
.
shape
[
2
],
1
),
)
attn_weights
=
jnp
.
concatenate
([
attn_weights
,
extra_col
],
axis
=-
1
)
# Normalize the attention weights across `kv_length` dimension.
attn_weights
=
jax_nn
.
softmax
(
attn_weights
).
astype
(
input_dtype
)
# Remove the extra column after softmax if not vanilla softmax
if
self
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
attn_weights
=
attn_weights
[...,
:
-
1
]
# Apply attention dropout.
if
not
deterministic
and
self
.
dropout_rate
>
0.0
:
keep_prob
=
1.0
-
self
.
dropout_rate
...
...
@@ -535,6 +576,7 @@ class MultiHeadAttention(nn.Module):
rotary_pos_emb_group_method
:
str
=
"consecutive"
fuse_qkv
:
bool
=
True
use_bias
:
bool
=
False
softmax_type
:
AttnSoftmaxType
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
def
__post_init__
(
self
):
if
self
.
kernel_init
is
None
:
...
...
@@ -801,6 +843,7 @@ class MultiHeadAttention(nn.Module):
dropout_rate
=
self
.
dropout_rate
,
dtype
=
self
.
dtype
,
float32_logits
=
self
.
float32_logits
,
softmax_type
=
self
.
softmax_type
,
)(
query
,
key
,
value
,
bias
=
attention_bias
,
deterministic
=
deterministic
)
x
=
x
.
reshape
((
x
.
shape
[
0
],
x
.
shape
[
1
],
x
.
shape
[
2
]
*
x
.
shape
[
3
]))
...
...
@@ -1058,6 +1101,7 @@ class EncoderLayer(nn.Module):
self_attn_bias_type
:
Any
=
None
self_attn_mask_type
:
str
=
"no_mask"
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
)
softmax_type
:
str
=
"vanilla"
def
__post_init__
(
self
):
if
self
.
num_gqa_groups
is
None
:
...
...
@@ -1111,6 +1155,9 @@ class EncoderLayer(nn.Module):
else
:
x
=
inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type
=
AttnSoftmaxType
.
from_str
(
self
.
softmax_type
)
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x
=
MultiHeadAttention
(
num_heads
=
self
.
num_attention_heads
,
...
...
@@ -1126,6 +1173,7 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb
=
self
.
enable_rotary_pos_emb
,
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
use_bias
=
self
.
use_bias
,
softmax_type
=
attn_softmax_type
,
name
=
"attention"
,
)(
x
,
x
,
encoder_mask
,
encoder_bias
,
deterministic
=
deterministic
)
x
=
nn
.
Dropout
(
rate
=
self
.
hidden_dropout
,
broadcast_dims
=
self
.
hidden_dropout_dims
)(
...
...
@@ -1222,6 +1270,7 @@ class DecoderLayer(nn.Module):
self_attn_bias_type
:
Any
=
None
self_attn_mask_type
:
str
=
"no_mask"
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
)
softmax_type
:
str
=
"vanilla"
def
__post_init__
(
self
):
if
self
.
num_gqa_groups
is
None
:
...
...
@@ -1290,6 +1339,9 @@ class DecoderLayer(nn.Module):
else
:
x
=
inputs
# Convert softmax_type string to AttnSoftmaxType enum
attn_softmax_type
=
AttnSoftmaxType
.
from_str
(
self
.
softmax_type
)
# Self-attention block
x
=
MultiHeadAttention
(
num_heads
=
self
.
num_attention_heads
,
...
...
@@ -1305,6 +1357,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
fuse_qkv
=
self
.
fuse_qkv_params
,
use_bias
=
self
.
use_bias
,
softmax_type
=
attn_softmax_type
,
name
=
"self_attention"
,
)(
x
,
x
,
decoder_mask
,
decoder_bias
,
deterministic
=
deterministic
,
decode
=
decode
)
x
=
nn
.
Dropout
(
rate
=
self
.
hidden_dropout
,
broadcast_dims
=
self
.
hidden_dropout_dims
)(
...
...
@@ -1343,6 +1396,7 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
fuse_qkv
=
self
.
fuse_qkv_params
,
use_bias
=
self
.
use_bias
,
softmax_type
=
attn_softmax_type
,
name
=
"encoder_decoder_attention"
,
)(
y
,
encoded
,
encoder_decoder_mask
,
deterministic
=
deterministic
)
y
=
nn
.
Dropout
(
rate
=
self
.
hidden_dropout
,
broadcast_dims
=
self
.
hidden_dropout_dims
)(
...
...
tests/pytorch/attention/run_attention_with_cp.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -89,40 +89,47 @@ def generate_input_shapes(
cu_seqlens_q_padded
=
None
cu_seqlens_kv_padded
=
None
elif
qkv_format
==
"thd"
:
seqlens_q
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
+
1
,
[
config
.
batch_size
]).
to
(
torch
.
int32
)
seqlens_q_padded
=
(
seqlens_q
+
2
*
world_size
-
1
)
//
(
world_size
*
2
)
*
(
world_size
*
2
)
cu_seqlens_q_padded
=
torch
.
cat
(
[
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
),
seqlens_q_padded
.
cumsum
(
0
,
dtype
=
torch
.
int32
),
]
).
cuda
()
cu_seqlens_q
=
torch
.
clone
(
cu_seqlens_q_padded
)
# Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does,
# cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only.
if
kernel_backend
==
"FusedAttention"
:
cu_seqlens_q
[
1
:]
=
seqlens_q
.
cumsum
(
0
,
dtype
=
torch
.
int32
).
cuda
()
# NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded`
# will not be the same as `cu_seqlens_q` and `cu_seqlens_q_padded` respectively.
cu_seqlens_kv
=
cu_seqlens_q
cu_seqlens_kv_padded
=
cu_seqlens_q_padded
total_tokens
=
cu_seqlens_q_padded
[
-
1
]
q_input_shape
=
(
config
.
batch_size
*
config
.
max_seqlen_q
,
total_tokens
,
config
.
num_heads
,
config
.
head_dim_qk
,
)
k_input_shape
=
(
config
.
batch_size
*
config
.
max_seqlen_q
,
total_tokens
,
config
.
num_gqa_groups
,
config
.
head_dim_qk
,
)
v_input_shape
=
(
config
.
batch_size
*
config
.
max_seqlen_q
,
total_tokens
,
config
.
num_gqa_groups
,
config
.
head_dim_v
,
)
attn_output_shape
=
(
config
.
batch_size
*
config
.
max_seqlen_q
,
total_tokens
,
config
.
num_heads
*
config
.
head_dim_v
,
)
seqlens_q
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
+
1
,
[
config
.
batch_size
]).
to
(
torch
.
int32
)
seqlens_q_padded
=
(
seqlens_q
+
2
*
world_size
-
1
)
//
(
world_size
*
2
)
*
(
world_size
*
2
)
cu_seqlens_q_padded
=
torch
.
cat
(
[
torch
.
zeros
([
1
],
dtype
=
torch
.
int32
),
seqlens_q_padded
.
cumsum
(
0
,
dtype
=
torch
.
int32
),
torch
.
tensor
([
q_input_shape
[
0
]],
dtype
=
torch
.
int32
),
]
).
cuda
()
cu_seqlens_q
=
torch
.
clone
(
cu_seqlens_q_padded
)
if
kernel_backend
==
"FusedAttention"
:
cu_seqlens_q
[
1
:
-
1
]
=
seqlens_q
.
cumsum
(
0
,
dtype
=
torch
.
int32
).
cuda
()
cu_seqlens_q
[
-
1
]
=
cu_seqlens_q
[
-
2
]
cu_seqlens_kv
=
cu_seqlens_q
cu_seqlens_kv_padded
=
cu_seqlens_q_padded
else
:
assert
False
,
f
"
{
qkv_format
=
}
is not supported!"
...
...
tests/pytorch/attention/test_attention.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
logging
...
...
@@ -119,7 +119,14 @@ model_configs_base = {
@
pytest
.
mark
.
parametrize
(
"swa"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"pad_between_seqs"
,
[
False
])
def
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
ckpt_attn
,
workspace_opt
,
qkv_layout
,
swa
,
pad_between_seqs
dtype
,
model_configs
,
model
,
ckpt_attn
,
workspace_opt
,
qkv_layout
,
swa
,
pad_between_seqs
,
):
"""Test DotProductAttention module"""
...
...
@@ -310,6 +317,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
False
,
True
,
qkv_layout
,
False
,
False
)
model_configs_num_splits
=
{
# test: ModelConfig(b, sq, hq, dqk)
"num_splits_1_0"
:
ModelConfig
(
2
,
2048
,
24
,
128
,
num_splits
=
2
),
"num_splits_1_1"
:
ModelConfig
(
1
,
2048
,
24
,
128
,
max_seqlen_kv
=
4096
,
num_splits
=
4
),
}
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
8
,
9
,
1
),
reason
=
"cuDNN 8.9.1+ is required."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_num_splits
])
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_num_splits
.
keys
())
def
test_dpa_num_splits
(
dtype
,
model_configs
,
model
):
"""Test DotProductAttention with FlashAttention-3 num_splits enabled"""
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
False
,
True
,
None
,
False
,
False
,
)
model_configs_softmax
=
{
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0"
:
ModelConfig
(
2
,
2048
,
64
,
64
,
num_gqa_groups
=
8
),
...
...
@@ -1155,6 +1187,8 @@ def _run_dot_product_attention(
core_attention_bias
=
bias
,
alibi_slopes
=
alibi_slopes
,
fast_zero_fill
=
True
,
# Only pass num_splits when exercising the FlashAttention path
num_splits
=
config
.
num_splits
if
backend
==
"FlashAttention"
else
1
,
)
max_logit
=
None
if
config
.
return_max_logit
:
...
...
@@ -1789,9 +1823,10 @@ def test_mha_fp8_vs_f16(
fp8_meta
=
fp8_meta
,
is_training
=
is_training
,
)
flash_attn_supported
,
fused_attn_supported
,
unfused_attn_supported
=
available_backends
if
flash_attn_supported
+
fused_attn_supported
<
1
:
flash_attn_supported
,
fused_attn_supported
_fp8
,
unfused_attn_supported
=
available_backends
if
flash_attn_supported
+
fused_attn_supported
_fp8
<
1
:
pytest
.
skip
(
"No FP8 attention backend available."
)
fused_attn_supported_f16
=
False
if
not
fp8_dpa_bwd
:
available_backends
,
_
,
fused_attn_backends
=
get_available_attention_backends
(
config
,
...
...
@@ -1799,8 +1834,8 @@ def test_mha_fp8_vs_f16(
qkv_layout
=
qkv_format
.
replace
(
"hd"
,
"h3d"
),
is_training
=
is_training
,
)
_
,
fused_attn_supported
,
_
=
available_backends
if
not
fused_attn_supported
:
_
,
fused_attn_supported
_f16
,
_
=
available_backends
if
not
fused_attn_supported
_f16
:
pytest
.
skip
(
"No attention backend available."
)
if
flash_attn_supported
:
...
...
@@ -1812,23 +1847,28 @@ def test_mha_fp8_vs_f16(
dtype
,
config
,
True
,
qkv_format
,
input_layernorm
,
RoPE
,
is_training
,
fp8_recipe
)
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN"
]
=
"1"
_attention_backends
[
"backend_selection_requires_update"
]
=
True
logging
.
info
(
"[test_mha_fp8_vs_f16]: run with fp8_mha = True"
)
fused_attn_fwd_fp8
,
param_names
,
fused_attn_bwd_fp8
=
_run_mha_fp8_vs_f16
(
dtype
,
config
,
True
,
qkv_format
,
input_layernorm
,
RoPE
,
is_training
,
fp8_recipe
)
if
fused_attn_supported_fp8
:
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN"
]
=
"1"
_attention_backends
[
"backend_selection_requires_update"
]
=
True
logging
.
info
(
"[test_mha_fp8_vs_f16]: run with fp8_mha = True"
)
fused_attn_fwd_fp8
,
param_names
,
fused_attn_bwd_fp8
=
_run_mha_fp8_vs_f16
(
dtype
,
config
,
True
,
qkv_format
,
input_layernorm
,
RoPE
,
is_training
,
fp8_recipe
)
logging
.
info
(
"[test_mha_fp8_vs_f16]: run with fp8_mha = False"
)
fused_attn_fwd_f16
,
param_names
,
fused_attn_bwd_f16
=
_run_mha_fp8_vs_f16
(
dtype
,
config
,
False
,
qkv_format
,
input_layernorm
,
RoPE
,
is_training
,
fp8_recipe
)
if
fused_attn_supported_f16
:
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN"
]
=
"1"
_attention_backends
[
"backend_selection_requires_update"
]
=
True
logging
.
info
(
"[test_mha_fp8_vs_f16]: run with fp8_mha = False"
)
fused_attn_fwd_f16
,
param_names
,
fused_attn_bwd_f16
=
_run_mha_fp8_vs_f16
(
dtype
,
config
,
False
,
qkv_format
,
input_layernorm
,
RoPE
,
is_training
,
fp8_recipe
)
atol
=
5e-1
rtol
=
5e-1
rmse_tol
=
0.15
if
flash_attn_supported
:
if
flash_attn_supported
and
fused_attn_supported_f16
:
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"flash fp8 vs fused f16:"
))
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"forward output"
))
compare_and_assert
(
...
...
@@ -1841,32 +1881,33 @@ def test_mha_fp8_vs_f16(
rmse_tol
,
True
,
)
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"fused fp8 vs fused f16:"
))
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"forward output"
))
compare_and_assert
(
fused_attn_fwd_fp8
,
fused_attn_fwd_f16
,
"fused_attn_fwd_fp8"
,
"fused_attn_fwd_f16"
,
atol
,
rtol
,
rmse_tol
,
True
,
)
if
fused_attn_supported_fp8
and
fused_attn_supported_f16
:
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"fused fp8 vs fused f16:"
))
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
"forward output"
))
compare_and_assert
(
fused_attn_fwd_fp8
,
fused_attn_fwd_f16
,
"fused_attn_fwd_fp8"
,
"fused_attn_fwd_f16"
,
atol
,
rtol
,
rmse_tol
,
True
,
)
if
is_training
:
for
i
in
range
(
len
(
param_names
[:
1
])):
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
param_names
[
i
]))
compare_and_assert
(
fused_attn_bwd_fp8
[
i
],
fused_attn_bwd_f16
[
i
],
f
"fused_attn_bwd_fp8[
{
i
}
]"
,
f
"fused_attn_bwd_f16[
{
i
}
]"
,
atol
,
rtol
,
rmse_tol
,
True
,
)
if
is_training
:
for
i
in
range
(
len
(
param_names
[:
1
])):
logging
.
debug
(
"========== {:^25s} =========="
.
format
(
param_names
[
i
]))
compare_and_assert
(
fused_attn_bwd_fp8
[
i
],
fused_attn_bwd_f16
[
i
],
f
"fused_attn_bwd_fp8[
{
i
}
]"
,
f
"fused_attn_bwd_f16[
{
i
}
]"
,
atol
,
rtol
,
rmse_tol
,
True
,
)
def
_run_mha_fp8_vs_f16
(
...
...
@@ -2492,7 +2533,6 @@ class _custom_mha_fp8(torch.autograd.Function):
max_s
:
int
,
fast_zero_fill
:
bool
,
fp8_meta
:
Dict
[
str
,
Any
],
workspace
:
torch
.
Tensor
,
is_training
:
bool
,
mask_type
:
str
,
quantizers
:
list
[
Quantizer
],
...
...
@@ -2521,7 +2561,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv
,
*
_
=
ext
.
general_gemm
(
qkv_weight_fp8
,
inp_fp8
,
workspace
,
bias
=
qkv_bias
,
out_dtype
=
qkv_weight_fp8
.
dtype
,
quantization_params
=
qkv_quantizer
,
...
...
@@ -2563,9 +2602,7 @@ class _custom_mha_fp8(torch.autograd.Function):
s_quantizer
=
s_quantizer
,
)
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
q
,
k
,
v
,
inp_fp8
,
qkv_weight_fp8
,
workspace
,
out
)
tensors_to_save
,
tensor_objects
=
prepare_for_saving
(
q
,
k
,
v
,
inp_fp8
,
qkv_weight_fp8
,
out
)
ctx
.
save_for_backward
(
*
tensors_to_save
)
ctx
.
tensor_objects
=
tensor_objects
...
...
@@ -2595,7 +2632,7 @@ class _custom_mha_fp8(torch.autograd.Function):
def
backward
(
ctx
,
grad_output
:
torch
.
Tensor
)
->
Tuple
[
Union
[
torch
.
Tensor
,
None
],
...]:
with
torch
.
cuda
.
nvtx
.
range
(
"_DPA"
):
saved_tensors
=
ctx
.
saved_tensors
(
q
,
k
,
v
,
inp_fp8
,
qkv_weight_fp8
,
workspace
,
out
)
=
restore_from_saved
(
(
q
,
k
,
v
,
inp_fp8
,
qkv_weight_fp8
,
out
)
=
restore_from_saved
(
ctx
.
tensor_objects
,
saved_tensors
)
...
...
@@ -2651,7 +2688,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_dgrad
,
*
_
=
ext
.
general_gemm
(
qkv_weight_fp8
,
dqkv_c
,
workspace
,
ctx
.
dtype
,
use_split_accumulator
=
_2X_ACC_DGRAD
,
layout
=
"NN"
,
...
...
@@ -2661,7 +2697,6 @@ class _custom_mha_fp8(torch.autograd.Function):
qkv_wgrad
,
*
_
=
ext
.
general_gemm
(
inp_fp8
,
dqkv
,
workspace
,
ctx
.
dtype
,
use_split_accumulator
=
_2X_ACC_WGRAD
,
layout
=
"NT"
,
...
...
@@ -2712,9 +2747,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
with
torch
.
no_grad
():
self
.
qkv_bias
.
zero_
()
self
.
qkv_weight
.
fill_
(
1.0
)
self
.
workspace
=
torch
.
empty
(
_CUBLASLT_WORKSPACE_SIZE_BYTES
,
dtype
=
torch
.
int8
,
device
=
"cuda"
)
def
forward
(
self
,
...
...
@@ -2733,7 +2765,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
max_s
,
self
.
fast_zero_fill
,
self
.
fp8_meta
,
self
.
workspace
,
self
.
training
,
self
.
mask_type
,
self
.
quantizers
,
...
...
tests/pytorch/attention/test_attention_with_cp.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -7,7 +7,7 @@ import subprocess
import
sys
import
pathlib
import
logging
import
copy
import
pytest
import
torch
from
transformer_engine.pytorch
import
(
...
...
@@ -74,7 +74,7 @@ dtypes = ["bf16", "fp16"]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
if
test_essential
:
configs
=
[
"cp_1_0"
,
"cp_2_1"
,
"cp_3_2"
,
"cp_3_3"
]
configs
=
[
"cp_1_0"
,
"cp_1_2"
,
"cp_2_1"
,
"cp_3_2"
,
"cp_3_3"
]
model_configs_flash_attn
=
{
k
:
model_configs_flash_attn
[
k
]
for
k
in
configs
}
dtypes
=
[
"bf16"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
...
...
@@ -97,12 +97,16 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if
"p2p"
in
cp_comm_type
and
config
.
window_size
!=
(
-
1
,
0
)
and
config
.
window_size
!=
(
-
1
,
-
1
):
pytest
.
skip
(
"CP implementation with KV P2P does not support sliding window yet!"
)
if
cp_comm_type
==
"all_gather"
and
qkv_format
==
"thd"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support THD format yet!"
)
if
cp_comm_type
==
"all_gather"
and
config
.
attn_bias_type
!=
"no_bias"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support bias yet!"
)
if
"a2a"
in
cp_comm_type
and
qkv_format
==
"thd"
:
pytest
.
skip
(
"CP implementation with QKVO A2A does not support THD format yet!"
)
if
qkv_format
==
"thd"
:
if
cp_comm_type
==
"all_gather"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support THD format yet!"
)
if
cp_comm_type
==
"a2a+p2p"
:
pytest
.
skip
(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if
"a2a"
in
cp_comm_type
and
config
.
attn_bias_type
!=
"no_bias"
:
pytest
.
skip
(
"CP implementation with QKVO A2A does not support bias yet!"
)
if
"a2a"
in
cp_comm_type
and
(
config
.
num_heads
%
2
!=
0
or
config
.
num_gqa_groups
%
2
!=
0
):
...
...
@@ -184,7 +188,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
if
test_essential
:
configs
=
[
"cp_1_0"
,
"cp_1_1"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
configs
=
[
"cp_1_0"
,
"cp_1_1"
,
"cp_1_4"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
model_configs_fused_attn
=
{
k
:
model_configs_fused_attn
[
k
]
for
k
in
configs
}
dtypes
=
[
"bf16"
,
"fp8"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
...
...
@@ -225,10 +229,14 @@ def test_cp_with_fused_attention(
if
qkv_format
==
"thd"
and
config
.
attn_bias_type
==
"post_scale_bias"
:
pytest
.
skip
(
"THD format does not support post_scale_bias yet!"
)
if
qkv_format
==
"thd"
and
cp_comm_type
==
"all_gather"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support THD format yet!"
)
if
qkv_format
==
"thd"
and
"a2a"
in
cp_comm_type
:
pytest
.
skip
(
"CP implementation with QKVO A2A does not support THD format yet!"
)
if
qkv_format
==
"thd"
:
if
cp_comm_type
==
"all_gather"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support THD format yet!"
)
if
cp_comm_type
==
"a2a+p2p"
:
pytest
.
skip
(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
" yet!"
)
if
dtype
==
"fp8"
and
cp_comm_type
==
"all_gather"
:
pytest
.
skip
(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
...
...
@@ -282,6 +290,14 @@ def test_cp_with_fused_attention(
)
dtypes
=
{
"fp16"
:
torch
.
float16
,
"bf16"
:
torch
.
bfloat16
,
"fp8"
:
torch
.
bfloat16
}
if
qkv_format
==
"thd"
:
config
=
copy
.
deepcopy
(
config
)
if
"causal"
in
config
.
attn_mask_type
:
config
.
attn_mask_type
=
"padding_causal"
else
:
config
.
attn_mask_type
=
"padding"
fp8_meta
=
{}
fp8_meta
[
"recipe"
]
=
None
fp8_meta
[
"local_recipes"
]
=
[]
...
...
tests/pytorch/attention/test_cp_utils.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
tests/pytorch/attention/test_kv_cache.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment