Unverified Commit 5ba01faa authored by xiaoxi-wangfj's avatar xiaoxi-wangfj Committed by GitHub
Browse files

[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization (#1921)



* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization

1. Fuse `moe_permute_with_probs` + `Fp8Padding` and `moe_unpermute` + `Fp8Unpadding`,
  which removes the explicit padding/unpadding of MoE experts, improving performance and reducing peak GPU memory usage.
2. Add tests for the fused permute/pad and unpermute/unpad ops.
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_merging_probs
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [PyTorch]format code
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [Common]perf expert_idx loaded once
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* fix: pad_offsets can be None
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* add padding + merging probs bwd support. Not tested
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Fix garbage initialized act grad
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* all test passing for jax permutation + pad
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* change tokens_per_experts APIs to num_out_tokens with conservative allocation of worst case padding for output buffer
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* change test permutation to reduce test time
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* triggering PR refresh
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Remove some test cases from the PyTorch side. Add a separate token_dispatch test as a sanity check, in case combine accidentally undoes a dispatch error in the roundtrip test. Add a distinction between L0 and L2 test cases in JAX
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* remove chance for inefficiency in moving between CPU and GPU, remove redundant primitive using a new static bool for padding, add assert for align size
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix lint in jax
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* account for both jax newer and older than version 0.8.2. Adjusted gpu triton binding accordingly
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix typo
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatartdophung <tdophung@nvidia.com>
parent 97a09c29
......@@ -4,6 +4,8 @@
"""Tests for permutation Triton kernels and high-level APIs"""
import functools
import jax
import jax.numpy as jnp
import pytest
......@@ -14,68 +16,117 @@ from transformer_engine.jax.permutation import (
token_combine,
sort_chunks_by_index,
)
from utils import assert_allclose
from utils import assert_allclose, pytest_parametrize_wrapper
# =============================================================================
# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
# =============================================================================
# Each ALL_* list below holds the full set of cases; the matching dict maps a
# test level name ("L0" / "L2") to the subset of cases run at that level.
# The dicts are passed to pytest_parametrize_wrapper in place of a plain list.
# All dispatch/combine test cases
# Tuple layout: (num_tokens, num_experts, hidden_size, tokens_per_expert),
# matching the parametrize name string used with DISPATCH_COMBINE_CASES.
ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
(4096, 32, 1280, 2),
(4096, 256, 4096, 6),
]
# L0 runs only the first two (smallest) cases; L2 runs every case.
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:2],
"L2": ALL_DISPATCH_COMBINE_CASES,
}
# All sort chunks test cases
# NOTE(review): the consuming test is not visible here; the tuple is presumably
# (num_splits, total_tokens, hidden_size) - confirm against the sort_chunks tests.
ALL_SORT_CHUNKS_CASES = [
(8, 4096, 1280),
(64, 4096, 4096),
(256, 4096, 9216),
]
# L0 runs only the first two cases; L2 runs every case.
SORT_CHUNKS_CASES = {
"L0": ALL_SORT_CHUNKS_CASES[0:2],
"L2": ALL_SORT_CHUNKS_CASES,
}
# All dispatch/combine with padding test cases
# NOTE(review): the first four fields presumably mirror the dispatch/combine
# layout above and the fifth is presumably the padding alignment size -
# confirm against the padding tests, which are not visible here.
ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
(4096, 32, 1280, 2, 128),
(4096, 256, 4096, 6, 16),
]
# L0 runs only the first two cases; L2 runs every case.
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:2],
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}
# Dtypes for testing
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
# Both levels exercise the same dtype list.
DTYPES = {
"L0": ALL_DTYPES,
"L2": ALL_DTYPES,
}
# With probs options
ALL_WITH_PROBS = [True, False]
# L0 only covers the with-probs path; L2 covers both with and without probs.
WITH_PROBS = {
"L0": [True],
"L2": ALL_WITH_PROBS,
}
def reference_make_row_id_map(
routing_map: jnp.ndarray,
num_tokens: int,
num_experts: int,
) -> jnp.ndarray:
"""
Reference implementation of make_row_id_map using JAX primitives.
Vectorized reference implementation of make_row_id_map using JAX primitives.
Parameters
----------
routing_map : jnp.ndarray
Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts
are routed to which tokens (1 = routed, 0 = not routed).
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
Number of experts in the input tensor.
Returns
-------
row_id_map : jnp.ndarray
The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1].
"""
row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32)
num_tokens, num_experts = routing_map.shape
# For each expert, compute cumulative sum to get destination indices
cumsum_per_expert = jnp.cumsum(routing_map, axis=0)
# Compute total tokens per expert
# Compute total tokens per expert and expert offsets
tokens_per_expert = jnp.sum(routing_map, axis=0)
expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
# Build the row_id_map
for token_idx in range(num_tokens):
routed_experts = jnp.where(routing_map[token_idx] == 1)[0]
n_routed = len(routed_experts)
# Store number of routed experts in the last position
row_id_map = row_id_map.at[token_idx, -1].set(n_routed)
# For each routed expert, compute destination row and store it
dest_rows = []
expert_indices = []
for expert_idx in routed_experts:
# Destination row = expert offset + (cumsum - 1)
dest_row = expert_offsets[expert_idx] + cumsum_per_expert[token_idx, expert_idx] - 1
dest_rows.append(dest_row)
expert_indices.append(expert_idx)
# Sort by destination row
if n_routed > 0:
sort_indices = jnp.argsort(-jnp.array(dest_rows)) # Negative for descending sort
sorted_dest_rows = jnp.array(dest_rows)[sort_indices]
sorted_expert_indices = jnp.array(expert_indices)[sort_indices]
# Store sorted destination rows and expert indices
for i in range(n_routed):
row_id_map = row_id_map.at[token_idx, i].set(sorted_dest_rows[i])
row_id_map = row_id_map.at[token_idx, num_experts + i].set(sorted_expert_indices[i])
# Compute destination rows for all (token, expert) pairs
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
dest_rows_all = (expert_offsets[None, :] + cumsum_per_expert - 1) * routing_map + (-1) * (
1 - routing_map
)
# Count routed experts per token
n_routed_per_token = jnp.sum(routing_map, axis=1)
# For each token, we need to sort by descending dest_row and pack into row_id_map
# Use the largest int32 value as the sort key for non-routed experts so they sort to the end
sort_keys = jnp.where(routing_map == 1, -dest_rows_all, jnp.iinfo(jnp.int32).max)
sorted_expert_indices = jnp.argsort(sort_keys, axis=1)
# Gather the sorted destination rows and expert indices using advanced indexing
# Create indices for gathering
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
row_id_map = jnp.concatenate(
[
sorted_dest_rows.astype(jnp.int32),
sorted_expert_indices.astype(jnp.int32),
n_routed_per_token.astype(jnp.int32)[:, None],
],
axis=1,
)
return row_id_map
......@@ -84,13 +135,10 @@ def _reference_permute_impl(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: jnp.ndarray,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
) -> tuple:
"""
Internal helper for reference permutation implementation.
Vectorized internal helper for reference permutation implementation.
Parameters
----------
......@@ -100,14 +148,8 @@ def _reference_permute_impl(
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
probs : jnp.ndarray
The probabilities of the input tensor.
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
Number of experts.
num_out_tokens : int
Number of tokens in the permuted tensor.
hidden_size : int
Hidden size of the input tensor.
Returns
-------
......@@ -116,33 +158,63 @@ def _reference_permute_impl(
permuted_probs : jnp.ndarray
Permuted probabilities if probs was provided, None otherwise.
"""
num_tokens, hidden_size = inp.shape
num_experts = (row_id_map.shape[1] - 1) // 2
# Extract destination rows, expert indices, and n_routed from row_id_map
dest_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts]
expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts]
n_routed = row_id_map[:, 2 * num_experts] # [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
# (slots beyond n_routed may contain garbage, not -1)
slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts]
valid_mask = slot_indices < n_routed[:, None] # [num_tokens, num_experts]
# Flatten for scatter operations
flat_dest_rows = dest_rows.flatten() # [num_tokens * num_experts]
flat_valid_mask = valid_mask.flatten()
flat_token_indices = jnp.repeat(jnp.arange(num_tokens), num_experts)
flat_expert_indices = expert_indices.flatten()
# Set invalid dest_rows to num_out_tokens (out of bounds, will be dropped)
# This avoids overwriting valid entries at index 0 with zeros
flat_dest_rows_clamped = jnp.where(flat_valid_mask, flat_dest_rows, num_out_tokens)
# Gather input tokens and scatter to output
output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype)
for token_idx in range(num_tokens):
n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range()
for i in range(n_routed):
# Don't use int() here - JAX can index with traced values,
# and int() breaks autodiff gradient tracking
dest_row = row_id_map[token_idx, i]
expert_idx = row_id_map[token_idx, num_experts + i]
# Get probability for this expert
if probs is not None:
if probs.ndim == 1:
prob = probs[token_idx]
else:
prob = probs[token_idx, expert_idx]
# Match kernel behavior: if prob == 0.0, zero out the output (padding indicator)
if prob == 0.0:
output = output.at[dest_row].set(0.0)
else:
output = output.at[dest_row].set(inp[token_idx])
permuted_probs = permuted_probs.at[dest_row].set(prob)
else:
output = output.at[dest_row].set(inp[token_idx])
gathered_inp = inp[flat_token_indices] # [num_tokens * num_experts, hidden_size]
# Scatter (set, not accumulate) the gathered rows into the output
# For each valid (token, expert) pair, write inp[token] to output[dest_row]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
output = output.at[flat_dest_rows_clamped].set(
gathered_inp,
mode="drop",
)
permuted_probs = None
if probs is not None:
permuted_probs = jnp.zeros((num_out_tokens,), dtype=probs.dtype)
# Vectorized approach: gather probs and scatter to permuted_probs
if probs.ndim == 1:
flat_probs = probs[flat_token_indices]
else:
# Clamp invalid expert indices to 0 to avoid wraparound indexing with -1
# The result for invalid entries will be ignored anyway since they target num_out_tokens
# Cast to int32 explicitly for consistent indexing behavior
flat_expert_indices_clamped = jnp.where(flat_valid_mask, flat_expert_indices, 0).astype(
jnp.int32
)
flat_probs = probs[flat_token_indices.astype(jnp.int32), flat_expert_indices_clamped]
# Invalid entries target num_out_tokens and get dropped by mode="drop"
permuted_probs = permuted_probs.at[flat_dest_rows_clamped.astype(jnp.int32)].set(
flat_probs,
mode="drop",
)
return output, permuted_probs
......@@ -152,12 +224,9 @@ def _reference_unpermute_impl(
row_id_map: jnp.ndarray,
merging_probs: jnp.ndarray,
permuted_probs: jnp.ndarray,
num_tokens: int,
num_experts: int,
hidden_size: int,
) -> tuple:
"""
Internal helper for reference unpermutation implementation.
Vectorized internal helper for reference unpermutation implementation.
Parameters
----------
......@@ -169,12 +238,6 @@ def _reference_unpermute_impl(
The merging probabilities for weighted reduction.
permuted_probs : jnp.ndarray
The permuted probabilities.
num_tokens : int
Number of tokens.
num_experts : int
Number of experts.
hidden_size : int
Hidden size.
Returns
-------
......@@ -183,31 +246,44 @@ def _reference_unpermute_impl(
unpermuted_probs : jnp.ndarray
Unpermuted probabilities if permuted_probs was provided, None otherwise.
"""
output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
unpermuted_probs = (
None
if permuted_probs is None
else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
)
num_tokens = row_id_map.shape[0]
num_experts = (row_id_map.shape[1] - 1) // 2
for token_idx in range(num_tokens):
n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range()
for i in range(n_routed):
# Don't use int() here - JAX can index with traced values,
# and int() breaks autodiff gradient tracking
src_row = row_id_map[token_idx, i]
expert_idx = row_id_map[token_idx, num_experts + i]
if merging_probs is not None:
weight = merging_probs[token_idx, expert_idx]
output = output.at[token_idx].add(inp[src_row] * weight)
else:
output = output.at[token_idx].add(inp[src_row])
if permuted_probs is not None:
unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set(
permuted_probs[src_row]
)
# Extract source rows, expert indices, and n_routed from row_id_map
src_rows = row_id_map[:, :num_experts] # [num_tokens, num_experts]
expert_indices = row_id_map[:, num_experts : 2 * num_experts] # [num_tokens, num_experts]
n_routed = row_id_map[:, 2 * num_experts] # [num_tokens]
# Create mask for valid entries: slot_idx < n_routed[token]
# The kernel's row_id_map only guarantees valid data in the first n_routed slots
slot_indices = jnp.arange(num_experts)[None, :] # [1, num_experts]
valid_mask = slot_indices < n_routed[:, None] # [num_tokens, num_experts]
# Clamp invalid src_rows to 0 (they won't be used due to masking)
src_rows_clamped = jnp.where(valid_mask, src_rows, 0)
# Gather input from permuted positions
gathered_inp = inp[src_rows_clamped] # [num_tokens, num_experts, hidden_size]
# Apply merging probs if provided
if merging_probs is not None:
# Gather the merging weights for each (token, expert) pair using advanced indexing
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
weights = merging_probs[token_idx, expert_indices] # [num_tokens, num_experts]
gathered_inp = gathered_inp * weights[:, :, None]
# Mask out invalid entries and sum across experts
gathered_inp = jnp.where(valid_mask[:, :, None], gathered_inp, 0.0)
output = jnp.sum(gathered_inp, axis=1) # [num_tokens, hidden_size]
unpermuted_probs = None
if permuted_probs is not None:
gathered_probs = permuted_probs[src_rows_clamped] # [num_tokens, num_experts]
unpermuted_probs = jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
unpermuted_probs = unpermuted_probs.at[token_idx, expert_indices].set(
jnp.where(valid_mask, gathered_probs, 0.0)
)
return output, unpermuted_probs
......@@ -241,13 +317,8 @@ def reference_token_dispatch(
row_id_map : jnp.ndarray
The row_id_map for the permutation.
"""
num_tokens, num_experts = routing_map.shape
hidden_size = inp.shape[1]
row_id_map = reference_make_row_id_map(routing_map, num_tokens, num_experts)
output, permuted_probs = _reference_permute_impl(
inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size
)
row_id_map = reference_make_row_id_map(routing_map)
output, permuted_probs = _reference_permute_impl(inp, row_id_map, probs, num_out_tokens)
return output, permuted_probs, row_id_map
......@@ -274,13 +345,7 @@ def reference_token_combine(
output : jnp.ndarray
Unpermuted output tensor of shape [num_tokens, hidden_size].
"""
num_tokens = row_id_map.shape[0]
num_experts = (row_id_map.shape[1] - 1) // 2
hidden_size = inp.shape[1]
output, _ = _reference_unpermute_impl(
inp, row_id_map, merging_probs, None, num_tokens, num_experts, hidden_size
)
output, _ = _reference_unpermute_impl(inp, row_id_map, merging_probs, None)
return output
......@@ -289,10 +354,9 @@ def reference_make_chunk_sort_map(
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
num_tokens: int,
num_splits: int,
) -> jnp.ndarray:
"""
Reference implementation of make_chunk_sort_map using JAX primitives.
Vectorized reference implementation of make_chunk_sort_map using JAX primitives.
Parameters
----------
......@@ -302,45 +366,48 @@ def reference_make_chunk_sort_map(
The indices of the sorted chunks of shape [num_splits,].
num_tokens : int
Number of tokens.
num_splits : int
Number of splits.
Returns
-------
row_id_map : jnp.ndarray
Row ID map for chunk sorting of shape [num_tokens,].
"""
row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32)
# Compute source chunk boundaries (cumulative sum of original split_sizes)
src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
# Compute cumulative positions
cumsum_sizes = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
# Compute destination chunk boundaries based on sorted order
sorted_sizes = split_sizes[sorted_indices]
dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)])
# For each chunk, compute the destination indices
dest_offset = 0
for sorted_idx in sorted_indices:
chunk_start = cumsum_sizes[sorted_idx]
chunk_end = cumsum_sizes[sorted_idx + 1]
chunk_size = chunk_end - chunk_start
# For each source chunk, compute its destination offset
# inverse_indices[i] = position of chunk i in sorted order
inverse_indices = jnp.argsort(sorted_indices)
dest_offsets = dest_cumsum[inverse_indices]
# Map source positions to destination positions
for i in range(chunk_size):
row_id_map = row_id_map.at[chunk_start + i].set(dest_offset + i)
# Create row_id_map: for each token position, compute its destination
# First, figure out which chunk each position belongs to
position_indices = jnp.arange(num_tokens)
dest_offset += chunk_size
# chunk_ids[i] = which chunk position i belongs to
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
return row_id_map
# within_chunk_offset[i] = position i's offset within its chunk
within_chunk_offset = position_indices - src_cumsum[chunk_ids]
# destination[i] = dest_offsets[chunk_ids[i]] + within_chunk_offset[i]
row_id_map = dest_offsets[chunk_ids] + within_chunk_offset
return row_id_map.astype(jnp.int32)
def reference_sort_chunks_by_map(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: jnp.ndarray,
num_tokens: int,
hidden_size: int,
is_forward: bool,
) -> tuple:
"""
Reference implementation of sort_chunks_by_map using JAX primitives.
Vectorized reference implementation of sort_chunks_by_map using JAX primitives.
Parameters
----------
......@@ -350,10 +417,6 @@ def reference_sort_chunks_by_map(
The token to destination mapping of shape [num_tokens,].
probs : jnp.ndarray
The probabilities.
num_tokens : int
Number of tokens.
hidden_size : int
Hidden size.
is_forward : bool
Whether this is forward or backward.
......@@ -364,25 +427,25 @@ def reference_sort_chunks_by_map(
permuted_probs : jnp.ndarray
Sorted probabilities if probs was provided, None otherwise.
"""
output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype)
num_tokens = inp.shape[0]
hidden_size = inp.shape[1]
if is_forward:
# Forward: src -> dest
for src_idx in range(num_tokens):
# Don't use int() - JAX can index with traced values
dest_idx = row_id_map[src_idx]
output = output.at[dest_idx].set(inp[src_idx])
if probs is not None:
permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx])
# Forward: scatter inp[src] to output[dest] where dest = row_id_map[src]
output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
output = output.at[row_id_map].set(inp)
if probs is not None:
permuted_probs = jnp.zeros((num_tokens,), dtype=probs.dtype)
permuted_probs = permuted_probs.at[row_id_map].set(probs)
else:
permuted_probs = None
else:
# Backward: dest -> src
for dest_idx in range(num_tokens):
# Don't use int() - JAX can index with traced values
src_idx = row_id_map[dest_idx]
output = output.at[dest_idx].set(inp[src_idx])
if probs is not None:
permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx])
# Backward: gather output[dest] = inp[src] where src = row_id_map[dest]
output = inp[row_id_map]
if probs is not None:
permuted_probs = probs[row_id_map]
else:
permuted_probs = None
return output, permuted_probs
......@@ -415,20 +478,24 @@ class TestHighLevelPermutationAPI:
return routing_map
# =========================================================================
# token_dispatch tests
# =========================================================================
@pytest.mark.parametrize(
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,tokens_per_expert",
[
(32, 8, 256, 2),
(64, 16, 512, 3),
],
DISPATCH_COMBINE_CASES,
)
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype):
"""Test token_dispatch forward and backward pass against reference"""
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_token_dispatch(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
):
"""
Individual test for token_dispatch forward and backward passes.
This test validates dispatch in isolation to catch errors that might be
masked when combined with token_combine in the roundtrip test.
Uses value_and_grad to validate both forward (via loss comparison) and
backward (via gradient comparison) passes against reference implementation.
"""
key = jax.random.PRNGKey(42)
# Generate routing map
......@@ -436,173 +503,231 @@ class TestHighLevelPermutationAPI:
num_out_tokens = int(jnp.sum(routing_map))
# Generate input data
key, inp_key = jax.random.split(key)
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
# Define loss functions
def loss_fn(x):
output, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(output**2)
# Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
probs = None
if with_probs:
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
def ref_loss_fn(x):
output, _, _ = reference_token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(output**2)
# Generate reference row_id_map for comparison
ref_row_id_map = reference_make_row_id_map(routing_map)
loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
# =====================================================================
# Test forward and backward pass using value_and_grad
# (value validates forward, grad validates backward)
# =====================================================================
if with_probs:
# Compare forward outputs
output, _, _ = token_dispatch(inp, routing_map, num_out_tokens)
ref_output, _, _ = reference_token_dispatch(inp, routing_map, num_out_tokens)
assert_allclose(output, ref_output)
@jax.jit
def dispatch_loss(x, p):
out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
# Compare loss and gradient
assert_allclose(loss_val, ref_loss_val)
assert_allclose(computed_grad, ref_grad)
@jax.jit
def ref_dispatch_loss(x, p):
out, perm_probs = _reference_permute_impl(x, ref_row_id_map, p, num_out_tokens)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
loss_val, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
inp, probs
)
ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
ref_dispatch_loss, argnums=(0, 1)
)(inp, probs)
# Validate forward loss matches
assert_allclose(loss_val, ref_loss_val, dtype=dtype)
# Validate gradients
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
else:
@jax.jit
def dispatch_loss_no_probs(x):
out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(out**2)
@jax.jit
def ref_dispatch_loss_no_probs(x):
out, _ = _reference_permute_impl(x, ref_row_id_map, None, num_out_tokens)
return jnp.sum(out**2)
loss_val, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
ref_loss_val, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
# Validate forward loss matches
assert_allclose(loss_val, ref_loss_val, dtype=dtype)
# Validate gradients
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
# =========================================================================
# token_dispatch with probs tests
# Consolidated dispatch + combine tests
# =========================================================================
@pytest.mark.parametrize(
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,tokens_per_expert",
[
(32, 8, 256, 2),
(64, 16, 512, 3),
],
DISPATCH_COMBINE_CASES,
)
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
def test_token_dispatch_with_probs(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_dispatch_and_combine(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs
):
"""Test token_dispatch with probs forward and backward pass against reference"""
"""
Comprehensive test for token_dispatch and token_combine.
Tests:
1. Dispatch forward pass against reference (element-by-element)
2. Dispatch backward pass against reference
3. Combine forward pass against reference (element-by-element)
4. Combine backward pass against reference
5. Roundtrip: dispatch + combine recovers original input
6. row_id_map n_routed column validation
7. Probs permutation (when with_probs=True)
"""
key = jax.random.PRNGKey(42)
# Generate routing map
routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
num_out_tokens = int(jnp.sum(routing_map))
# Generate input data and probs
key, inp_key, prob_key = jax.random.split(key, 3)
# Generate input data
key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
)
# Define loss function that uses token_dispatch with probs
# We compute gradients w.r.t. both inp and probs
def loss_fn(x, p):
output, permuted_probs, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
return jnp.sum(output**2) + jnp.sum(permuted_probs**2)
def ref_loss_fn(x, p):
output, permuted_probs, _ = reference_token_dispatch(
x, routing_map, num_out_tokens, probs=p
# Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
probs = None
if with_probs:
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
return jnp.sum(output**2) + jnp.sum(permuted_probs**2)
loss_val, (inp_grad, probs_grad) = jax.value_and_grad(loss_fn, argnums=(0, 1))(inp, probs)
ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
ref_loss_fn, argnums=(0, 1)
)(inp, probs)
output, permuted_probs, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs)
ref_output, ref_permuted_probs, _ = reference_token_dispatch(
inp, routing_map, num_out_tokens, probs=probs
# Generate merging probs (normalized per token)
merging_probs = jax.random.uniform(
merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
# Compare forward outputs
assert_allclose(output, ref_output)
assert_allclose(permuted_probs, ref_permuted_probs)
# Compare loss and gradients
assert_allclose(loss_val, ref_loss_val)
assert_allclose(inp_grad, ref_inp_grad)
assert_allclose(probs_grad, ref_probs_grad)
# =========================================================================
# token_combine tests
# =========================================================================
@pytest.mark.parametrize(
"num_tokens,num_experts,hidden_size,tokens_per_expert",
[
(32, 8, 256, 2),
(64, 16, 512, 3),
],
)
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_token_combine(
self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_merging_probs
):
"""Test token_combine forward and backward pass against reference"""
key = jax.random.PRNGKey(42)
# Generate routing map
routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
num_out_tokens = int(jnp.sum(routing_map))
# Get row_id_map from reference_token_dispatch
key, dummy_key = jax.random.split(key)
dummy_inp = jax.random.uniform(
dummy_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
merging_probs = merging_probs * routing_map.astype(dtype) # Zero out non-routed
merging_probs = merging_probs / jnp.maximum(
jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
)
_, _, row_id_map = reference_token_dispatch(dummy_inp, routing_map, num_out_tokens)
# Generate input data (from expert outputs)
key, inp_key, merge_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
# =====================================================================
# Test 1: Dispatch forward pass
# =====================================================================
output, permuted_probs, row_id_map, _, _ = token_dispatch(
inp, routing_map, num_out_tokens, probs=probs
)
ref_output, ref_permuted_probs = _reference_permute_impl(
inp, row_id_map, probs, num_out_tokens
)
if with_merging_probs:
merging_probs = jax.random.uniform(
merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
# Validate row_id_map structure: n_routed column should match routing_map sum
n_routed_actual = row_id_map[:, -1]
n_routed_expected = jnp.sum(routing_map, axis=1)
assert jnp.array_equal(
n_routed_actual, n_routed_expected
), "make_row_id_map n_routed column mismatch"
# Compare dispatch output
assert_allclose(output, ref_output, dtype=dtype)
if with_probs:
assert_allclose(permuted_probs, ref_permuted_probs, dtype=dtype)
# =====================================================================
# Test 2: Dispatch backward pass
# =====================================================================
if with_probs:
@jax.jit
def dispatch_loss(x, p):
out, perm_probs, _, _, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
@jax.jit
def ref_dispatch_loss(x, p):
out, perm_probs = _reference_permute_impl(x, row_id_map, p, num_out_tokens)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
_, (inp_grad, probs_grad) = jax.value_and_grad(dispatch_loss, argnums=(0, 1))(
inp, probs
)
# Normalize per token
merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8)
_, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
ref_dispatch_loss, argnums=(0, 1)
)(inp, probs)
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
assert_allclose(probs_grad, ref_probs_grad, dtype=dtype)
else:
merging_probs = None
# Define loss functions
def loss_fn(x):
output = token_combine(x, row_id_map, merging_probs)
return jnp.sum(output**2)
def ref_loss_fn(x):
output = reference_token_combine(x, row_id_map, merging_probs)
return jnp.sum(output**2)
loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
@jax.jit
def dispatch_loss_no_probs(x):
out, _, _, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return jnp.sum(out**2)
@jax.jit
def ref_dispatch_loss_no_probs(x):
out, _ = _reference_permute_impl(x, row_id_map, None, num_out_tokens)
return jnp.sum(out**2)
_, inp_grad = jax.value_and_grad(dispatch_loss_no_probs)(inp)
_, ref_inp_grad = jax.value_and_grad(ref_dispatch_loss_no_probs)(inp)
assert_allclose(inp_grad, ref_inp_grad, dtype=dtype)
# =====================================================================
# Test 3: Combine forward pass
# =====================================================================
combined = token_combine(output, row_id_map, merging_probs)
ref_combined = _reference_unpermute_impl(output, row_id_map, merging_probs, None)[0]
assert_allclose(combined, ref_combined, dtype=dtype)
# =====================================================================
# Test 4: Combine backward pass
# =====================================================================
@jax.jit
def combine_loss(x):
return jnp.sum(token_combine(x, row_id_map, merging_probs) ** 2)
@jax.jit
def ref_combine_loss(x):
return jnp.sum(_reference_unpermute_impl(x, row_id_map, merging_probs, None)[0] ** 2)
_, combine_grad = jax.value_and_grad(combine_loss)(output)
_, ref_combine_grad = jax.value_and_grad(ref_combine_loss)(output)
assert_allclose(combine_grad, ref_combine_grad, dtype=dtype)
# =====================================================================
# Test 5: Roundtrip (dispatch + combine = original)
# =====================================================================
# Use uniform merging probs for perfect roundtrip
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
# Compare forward outputs
output = token_combine(inp, row_id_map, merging_probs)
ref_output = reference_token_combine(inp, row_id_map, merging_probs)
assert_allclose(output, ref_output)
@jax.jit
def roundtrip(x):
dispatched, _, rid_map, _, _ = token_dispatch(x, routing_map, num_out_tokens)
return token_combine(dispatched, rid_map, uniform_merging_probs)
# Compare loss and gradient
assert_allclose(loss_val, ref_loss_val)
assert_allclose(computed_grad, ref_grad)
roundtrip_output = roundtrip(inp)
assert_allclose(roundtrip_output, inp, dtype=dtype)
# =========================================================================
# sort_chunks_by_index tests
# =========================================================================
@pytest.mark.parametrize(
@pytest_parametrize_wrapper(
"num_splits,total_tokens,hidden_size",
[
(4, 128, 256),
(8, 256, 512),
],
SORT_CHUNKS_CASES,
)
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@pytest_parametrize_wrapper("dtype", DTYPES)
def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
"""Test sort_chunks_by_index forward and backward pass against reference"""
key = jax.random.PRNGKey(42)
......@@ -622,73 +747,181 @@ class TestHighLevelPermutationAPI:
inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
row_id_map = reference_make_chunk_sort_map(
split_sizes, sorted_indices, total_tokens, num_splits
)
# Get reference row_id_map
row_id_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens)
# Define loss functions
# Define loss functions (JIT compiled for performance)
@jax.jit
def loss_fn(x):
output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
return jnp.sum(output**2)
@jax.jit
def ref_loss_fn(x):
output, _ = reference_sort_chunks_by_map(
x, row_id_map, None, total_tokens, hidden_size, is_forward=True
)
output, _ = reference_sort_chunks_by_map(x, row_id_map, None, is_forward=True)
return jnp.sum(output**2)
# Test forward pass
output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
ref_output, _ = reference_sort_chunks_by_map(inp, row_id_map, None, is_forward=True)
# Test backward pass with JIT
loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
# Compare forward outputs
output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
ref_output, _ = reference_sort_chunks_by_map(
inp, row_id_map, None, total_tokens, hidden_size, is_forward=True
)
# Compare forward and backward
assert_allclose(output, ref_output)
# Compare loss and gradient
assert_allclose(loss_val, ref_loss_val)
assert_allclose(computed_grad, ref_grad)
# =========================================================================
# Round-trip tests (token_dispatch -> expert processing -> token_combine)
# Consolidated dispatch + combine with padding tests
# =========================================================================
@pytest_parametrize_wrapper(
    "num_tokens,num_experts,hidden_size,topk,align_size",
    DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("with_probs", WITH_PROBS)
def test_dispatch_and_combine_with_padding(
    self, num_tokens, num_experts, hidden_size, topk, align_size, dtype, with_probs
):
    """
    Comprehensive test for token_dispatch and token_combine with padding/unpadding.

    Tests:
    1. Dispatch with padding: output/permuted_probs shapes and per-expert alignment
    2. Dispatch backward pass with padding: gradient shapes, no NaNs
    3. Combine with unpad: output shape
    4. Combine backward pass with unpad: gradient shape, no NaNs
    5. Roundtrip with padding: dispatch + combine recovers the original input,
       and the roundtrip gradient has the input's shape and no NaNs
    """
    key = jax.random.PRNGKey(42)

    # Generate routing map
    routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
    num_out_tokens = int(jnp.sum(routing_map))

    # Compute worst-case padded size: every expert may need up to
    # (align_size - 1) padding rows, rounded down to a multiple of align_size.
    worst_case_size = (
        (num_out_tokens + num_experts * (align_size - 1)) // align_size
    ) * align_size

    # Generate input data
    key, inp_key, prob_key, merge_key = jax.random.split(key, 4)
    inp = jax.random.uniform(
        inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
    )

    # Generate probs if needed (minval > 0 to avoid kernel's special prob==0 handling)
    probs = None
    if with_probs:
        probs = jax.random.uniform(
            prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
        )

    # Generate merging probs (normalized per token)
    merging_probs = jax.random.uniform(
        merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
    )
    merging_probs = merging_probs * routing_map.astype(dtype)  # Zero out non-routed
    merging_probs = merging_probs / jnp.maximum(
        jnp.sum(merging_probs, axis=1, keepdims=True), 1e-8
    )

    # =====================================================================
    # Test 1: Dispatch with padding - forward pass
    # =====================================================================
    output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert = token_dispatch(
        inp, routing_map, num_out_tokens, probs=probs, align_size=align_size
    )

    # Check output shape
    assert output.shape == (worst_case_size, hidden_size)
    if with_probs:
        assert permuted_probs is not None
        assert permuted_probs.shape == (worst_case_size,)
    else:
        assert permuted_probs is None

    # Check alignment: each expert's token count must be a multiple of
    # align_size (an empty expert, 0 tokens, satisfies this trivially).
    for expert_idx in range(num_experts):
        expert_tokens = int(target_tokens_per_expert[expert_idx])
        assert expert_tokens % align_size == 0

    # =====================================================================
    # Test 2: Dispatch with padding - backward pass
    # =====================================================================
    if with_probs:

        @jax.jit
        def dispatch_loss(x, p):
            out, perm_probs, _, _, _ = token_dispatch(
                x, routing_map, num_out_tokens, probs=p, align_size=align_size
            )
            return jnp.sum(out**2) + jnp.sum(perm_probs**2)

        inp_grad, probs_grad = jax.grad(dispatch_loss, argnums=(0, 1))(inp, probs)
        assert inp_grad.shape == inp.shape
        assert probs_grad.shape == probs.shape
        assert not jnp.any(jnp.isnan(inp_grad))
        assert not jnp.any(jnp.isnan(probs_grad))
    else:

        @jax.jit
        def dispatch_loss_no_probs(x):
            out, _, _, _, _ = token_dispatch(
                x, routing_map, num_out_tokens, align_size=align_size
            )
            return jnp.sum(out**2)

        inp_grad = jax.grad(dispatch_loss_no_probs)(inp)
        assert inp_grad.shape == inp.shape
        assert not jnp.any(jnp.isnan(inp_grad))

    # =====================================================================
    # Test 3: Combine with unpad - forward pass
    # =====================================================================
    combined = token_combine(output, row_id_map, merging_probs, pad_offsets)
    assert combined.shape == (num_tokens, hidden_size)

    # =====================================================================
    # Test 4: Combine with unpad - backward pass
    # =====================================================================
    @jax.jit
    def combine_loss(x):
        return jnp.sum(token_combine(x, row_id_map, merging_probs, pad_offsets) ** 2)

    combine_grad = jax.grad(combine_loss)(output)
    assert combine_grad.shape == output.shape
    assert not jnp.any(jnp.isnan(combine_grad))

    # =====================================================================
    # Test 5: Roundtrip with padding (dispatch + combine = original)
    # =====================================================================
    # Use uniform merging probs for perfect roundtrip: each routed expert of a
    # token gets weight 1 / (number of routed experts for that token).
    uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
        jnp.sum(routing_map, axis=1, keepdims=True), 1.0
    )

    @jax.jit
    def roundtrip(x):
        dispatched, _, rid_map, p_offsets, _ = token_dispatch(
            x, routing_map, num_out_tokens, align_size=align_size
        )
        return token_combine(dispatched, rid_map, uniform_merging_probs, p_offsets)

    roundtrip_output = roundtrip(inp)
    assert_allclose(roundtrip_output, inp, dtype=dtype)

    # Test roundtrip gradient
    @jax.jit
    def roundtrip_loss(x):
        return jnp.sum(roundtrip(x) ** 2)

    roundtrip_grad = jax.grad(roundtrip_loss)(inp)
    assert roundtrip_grad.shape == inp.shape
    assert not jnp.any(jnp.isnan(roundtrip_grad))
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
import os
import random
import torch
......@@ -13,6 +14,7 @@ from transformer_engine.common import recipe
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
moe_permute_and_pad_with_probs as te_permute_and_pad_with_probs,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
......@@ -24,6 +26,7 @@ from transformer_engine.pytorch import (
MXFP8Quantizer,
)
import transformer_engine_torch as tex
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding
import copy
seed = 1234
......@@ -653,6 +656,522 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs=False,
    align_size=16,
    BENCHMARK=False,
):
    """Check the fused permute+pad / unpermute+unpad ops against the unfused path.

    The reference path runs ``te_permute_with_probs`` followed by ``Fp8Padding``,
    then ``Fp8Unpadding`` followed by ``te_unpermute``. The fused path runs
    ``te_permute_and_pad_with_probs`` and then ``te_unpermute`` with
    ``pad_offsets``. Both paths are fed the same inputs and upstream gradients.

    Args:
        te_dtype: ``tex.DType`` selecting float32, float16 or bfloat16; any other
            value skips the test.
        num_tokens: Number of rows of the activation tensor.
        num_expert: Number of columns of the routing map.
        hidden_size: Number of columns of the activation tensor.
        topK: Used for the default ``num_out_tokens`` and the skip condition
            ``topK > num_expert``.
        num_out_tokens: Number of ones placed in the routing map; ``None`` means
            ``num_tokens * topK``.
        with_merging_probs: If True, ``probs`` is passed to ``te_unpermute`` as
            ``merging_probs`` and the probs gradients are compared as well.
        align_size: Alignment handed to the padding ops.
        BENCHMARK: If True, skip the numerical checks and print timings instead.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "permutation and padding:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
        f" with_merging_probs:{with_merging_probs} align_size:{align_size} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    # Scatter num_out_tokens ones at random over the (token, expert) grid.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # Per-token normalized probabilities, zero at non-routed positions.
    # A token row with no routed expert has row_sums == 0, so that row is NaN (0/0).
    # NOTE(review): presumably the ops only read probs at routed positions - confirm.
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs = probs.to(dtype)
    probs.requires_grad_(True)

    # Expected padded row count: each expert's count rounded up to align_size.
    tokens_per_expert = routing_map.sum(dim=0).cpu()
    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_bwd_input = torch.rand(
        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
    ).cuda()
    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_fwd_input.requires_grad_(True)

    restore_shape = permute_pad_fwd_input.shape

    # ------------------------------------------------------------------
    # Reference: moe_permute_with_probs + Fp8Padding,
    #            Fp8Unpadding + moe_unpermute
    # ------------------------------------------------------------------
    # permute + padding
    permuted_output, permuted_probs, row_id_map = te_permute_with_probs(
        permute_pad_fwd_input,
        probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )
    tokens_per_expert_list = tokens_per_expert.tolist()
    fp8_padding = Fp8Padding(num_expert, align_size)
    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
    permuted_paded_probs, _ = fp8_padding(permuted_probs.unsqueeze(-1), tokens_per_expert_list)

    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

    # unpadding + unpermute
    unpermute_unpad_fwd_input = permuted_paded_output.detach()
    unpermute_unpad_fwd_input.requires_grad_(True)

    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
    unpaded_output = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)

    probs_naive = probs
    unpermuted_unpaded_output = te_unpermute(
        unpaded_output,
        row_id_map,
        merging_probs=probs_naive if with_merging_probs else None,
        restore_shape=restore_shape,
    )
    unpermuted_unpaded_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

    # ------------------------------------------------------------------
    # Fused: moe_permute_and_pad_with_probs, moe_unpermute with pad_offsets
    # ------------------------------------------------------------------
    # fusion permute_and_pad
    fusion_permute_and_pad_fwd_input = permute_pad_fwd_input.detach()
    fusion_permute_and_pad_fwd_input.requires_grad_(True)
    probs_fusion = probs_naive.detach().clone()
    probs_fusion.requires_grad_(True)

    # The fused row-id map gets its own name so the reference map above is not
    # shadowed and each path keeps using the map produced by its own permute op.
    (
        fusion_permuted_padded_output,
        fusion_permuted_padded_probs,
        fused_row_id_map,
        pad_offsets,
        _,
    ) = te_permute_and_pad_with_probs(
        fusion_permute_and_pad_fwd_input,
        probs_fusion,
        routing_map,
        tokens_per_expert,
        align_size,
    )
    # Match the (rows, 1) layout of the padded reference probs.
    fusion_permuted_padded_probs = fusion_permuted_padded_probs.unsqueeze(-1)

    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

    # fusion unpad and unpermute
    fusion_unpermute_unpad_fwd_input = fusion_permuted_padded_output.detach()
    fusion_unpermute_unpad_fwd_input.requires_grad_(True)

    fusion_unpermuted_unpaded_output = te_unpermute(
        fusion_unpermute_unpad_fwd_input,
        fused_row_id_map,
        merging_probs=probs_fusion if with_merging_probs else None,
        restore_shape=restore_shape,
        pad_offsets=pad_offsets,
    )

    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
    fusion_unpermuted_unpaded_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

    # ------------------------------------------------------------------
    # Results Check
    # ------------------------------------------------------------------
    tols = dtype_tols(te_dtype)

    # Compare everything in float32 regardless of the tested dtype.
    permuted_paded_output_ = permuted_paded_output.float()
    fusion_permuted_padded_output_ = fusion_permuted_padded_output.float()
    permute_pad_fwd_input_grad = permute_pad_fwd_input.grad.float()
    fusion_permute_and_pad_fwd_input_grad = fusion_permute_and_pad_fwd_input.grad.float()

    unpermuted_unpaded_output_ = unpermuted_unpaded_output.float()
    fusion_unpermuted_unpaded_output_ = fusion_unpermuted_unpaded_output.float()
    unpermute_unpad_fwd_input_grad = unpermute_unpad_fwd_input.grad.float()
    fusion_unpermute_unpad_fwd_input_grad = fusion_unpermute_unpad_fwd_input.grad.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            permuted_paded_output_,
            fusion_permuted_padded_output_,
            msg="Mismatch in te_permute_and_pad fwd",
            **tols,
        )
        torch.testing.assert_close(
            permute_pad_fwd_input_grad,
            fusion_permute_and_pad_fwd_input_grad,
            msg="Mismatch in te_permute_and_pad bwd",
            **tols,
        )
        torch.testing.assert_close(
            unpermuted_unpaded_output_,
            fusion_unpermuted_unpaded_output_,
            msg="Mismatch in te_unpermute fwd",
            **tols,
        )
        torch.testing.assert_close(
            unpermute_unpad_fwd_input_grad,
            fusion_unpermute_unpad_fwd_input_grad,
            msg="Mismatch in te_unpermute bwd",
            **tols,
        )
        # This compares forward outputs (the permuted, padded probs), so the
        # message names the forward pass rather than the backward pass.
        torch.testing.assert_close(
            permuted_paded_probs.float(),
            fusion_permuted_padded_probs.float(),
            msg="Mismatch in te_permute_and_pad fwd (permuted probs)",
            **tols,
        )
        if with_merging_probs:
            torch.testing.assert_close(
                probs_naive.grad.float(),
                probs_fusion.grad.float(),
                msg="Mismatch in te_unpermute bwd",
                **tols,
            )

    # ------------------------------------------------------------------
    # Benchmark
    # ------------------------------------------------------------------
    if BENCHMARK:

        def permute_and_pad():
            bench_out, bench_probs, _ = te_permute_with_probs(
                permute_pad_fwd_input,
                probs,
                routing_map,
                num_out_tokens=num_out_tokens,
            )
            fp8_padding(bench_out, tokens_per_expert_list)
            fp8_padding(bench_probs.unsqueeze(-1), tokens_per_expert_list)

        def fusion_permute_and_pad():
            _, bench_probs, _, _, _ = te_permute_and_pad_with_probs(
                fusion_permute_and_pad_fwd_input,
                probs,
                routing_map,
                tokens_per_expert,
                align_size,
            )
            bench_probs.unsqueeze(-1)

        t1 = perf_test_cuda_kernel(lambda: permute_and_pad())
        t2 = perf_test_cuda_kernel(lambda: fusion_permute_and_pad())
        print(f"permute_and_pad\t\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                permuted_paded_output,
                permute_pad_bwd_input,
                forward_input=[permute_pad_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_permuted_padded_output,
                fusion_permute_pad_bwd_input,
                forward_input=[fusion_permute_and_pad_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"permute_and_pad\t\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

        def unpad_unpermute():
            # Forward only, so the "fwd" timing is comparable with the fused
            # forward-only call below; backward is timed separately.
            bench_unpaded = fp8_unpadding(unpermute_unpad_fwd_input, tokens_per_expert_list)
            te_unpermute(bench_unpaded, row_id_map, restore_shape=restore_shape)

        t1 = perf_test_cuda_kernel(lambda: unpad_unpermute())
        t2 = perf_test_cuda_kernel(
            lambda: te_unpermute(
                fusion_unpermute_unpad_fwd_input,
                fused_row_id_map,
                restore_shape=restore_shape,
                pad_offsets=pad_offsets,
            )
        )
        print(f"unpermute_and_unpad\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                unpermuted_unpaded_output,
                unpermute_unpad_bwd_input,
                forward_input=([unpermute_unpad_fwd_input, probs]),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_unpermuted_unpaded_output,
                fusion_unpermute_bwd_input,
                forward_input=([fusion_unpermute_unpad_fwd_input, probs]),
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute_and_unpad\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
def _test_permutation_and_padding_with_merging_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    align_size=16,
    BENCHMARK=False,
):
    """
    Test the combination of merging_probs AND pad_offsets together in moe_unpermute.

    This specifically tests the backward pass fix where pad_offsets must be used
    when computing gradients with merging_probs.

    The reference path is ``te_permute_with_probs`` + ``Fp8Padding`` followed by
    ``Fp8Unpadding`` + ``te_unpermute`` with merging probs. The fused path is
    ``te_permute_and_pad_with_probs`` followed by ``te_unpermute`` with both
    merging probs and ``pad_offsets``. The unpermute forward output, activation
    gradient and probs gradient are compared.

    Args:
        te_dtype: ``tex.DType`` selecting float32, float16 or bfloat16; any other
            value skips the test.
        num_tokens: Number of rows of the activation tensor.
        num_expert: Number of columns of the routing map.
        hidden_size: Number of columns of the activation tensor.
        topK: Used for the default ``num_out_tokens`` and the skip condition
            ``topK > num_expert``.
        num_out_tokens: Number of ones placed in the routing map; ``None`` means
            ``num_tokens * topK``.
        align_size: Alignment handed to the padding ops.
        BENCHMARK: If True, skip the numerical checks and print timings instead.
    """
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")

    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    print(
        "permutation and padding with merging probs:"
        f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK}"
        f" align_size:{align_size} {te_dtype}"
    )

    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    else:
        pytest.skip("Invalid dtype.")

    # Scatter num_out_tokens ones at random over the (token, expert) grid.
    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # Per-token normalized probabilities, zero at non-routed positions.
    # A token row with no routed expert has row_sums == 0, so that row is NaN (0/0).
    # NOTE(review): presumably the ops only read probs at routed positions - confirm.
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
    probs = probs.to(dtype)
    probs.requires_grad_(True)

    # Expected padded row count: each expert's count rounded up to align_size.
    tokens_per_expert = routing_map.sum(dim=0).cpu()
    target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
    num_permute_pad_out_tokens = target_tokens_per_expert.sum().item()

    permute_pad_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_bwd_input = torch.rand(
        (num_permute_pad_out_tokens, hidden_size), dtype=dtype
    ).cuda()
    unpermute_unpad_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    permute_pad_fwd_input.requires_grad_(True)

    restore_shape = permute_pad_fwd_input.shape

    # ------------------------------------------------------------------
    # Reference: moe_permute_with_probs + Fp8Padding,
    #            then Fp8Unpadding + moe_unpermute with merging_probs
    # ------------------------------------------------------------------
    # permute + padding
    permuted_output, _, row_id_map = te_permute_with_probs(
        permute_pad_fwd_input,
        probs,
        routing_map,
        num_out_tokens=num_out_tokens,
    )
    tokens_per_expert_list = tokens_per_expert.tolist()
    fp8_padding = Fp8Padding(num_expert, align_size)
    permuted_paded_output, _ = fp8_padding(permuted_output, tokens_per_expert_list)
    permuted_paded_output.backward(permute_pad_bwd_input, retain_graph=True)

    # Reference: unpadding + unpermute WITH merging_probs
    ref_unpermute_fwd_input = permuted_paded_output.detach()
    ref_unpermute_fwd_input.requires_grad_(True)
    ref_probs = probs.detach()
    ref_probs.requires_grad_(True)

    fp8_unpadding = Fp8Unpadding(num_expert, align_size)
    unpaded_output = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
    ref_unpermuted_output = te_unpermute(
        unpaded_output, row_id_map, ref_probs, restore_shape=restore_shape
    )
    ref_unpermuted_output.backward(unpermute_unpad_bwd_input, retain_graph=True)

    # ------------------------------------------------------------------
    # Fused: moe_permute_and_pad_with_probs,
    #        then moe_unpermute with BOTH merging_probs AND pad_offsets
    # ------------------------------------------------------------------
    # fusion permute_and_pad
    fusion_permute_fwd_input = permute_pad_fwd_input.detach()
    fusion_permute_fwd_input.requires_grad_(True)
    fusion_probs = probs.detach()
    fusion_probs.requires_grad_(True)

    (
        fusion_permuted_padded_output,
        _,
        fused_row_id_map,
        pad_offsets,
        _,
    ) = te_permute_and_pad_with_probs(
        fusion_permute_fwd_input,
        fusion_probs,
        routing_map,
        tokens_per_expert,
        align_size,
    )

    fusion_permute_pad_bwd_input = permute_pad_bwd_input.detach()
    fusion_permuted_padded_output.backward(fusion_permute_pad_bwd_input, retain_graph=True)

    # Fused: unpermute with BOTH merging_probs AND pad_offsets
    fusion_unpermute_fwd_input = fusion_permuted_padded_output.detach()
    fusion_unpermute_fwd_input.requires_grad_(True)
    fusion_merging_probs = probs.detach()
    fusion_merging_probs.requires_grad_(True)

    fusion_unpermuted_output = te_unpermute(
        fusion_unpermute_fwd_input,
        fused_row_id_map,
        fusion_merging_probs,
        restore_shape=restore_shape,
        pad_offsets=pad_offsets,
    )

    fusion_unpermute_bwd_input = unpermute_unpad_bwd_input.detach()
    fusion_unpermuted_output.backward(fusion_unpermute_bwd_input, retain_graph=True)

    # ------------------------------------------------------------------
    # Results Check
    # ------------------------------------------------------------------
    tols = dtype_tols(te_dtype)

    # Check forward pass (compared in float32 regardless of the tested dtype)
    ref_unpermuted_output_ = ref_unpermuted_output.float()
    fusion_unpermuted_output_ = fusion_unpermuted_output.float()

    if not BENCHMARK:
        torch.testing.assert_close(
            ref_unpermuted_output_,
            fusion_unpermuted_output_,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets fwd",
            **tols,
        )

        # Check backward pass - activation gradients
        ref_unpermute_fwd_input_grad = ref_unpermute_fwd_input.grad.float()
        fusion_unpermute_fwd_input_grad = fusion_unpermute_fwd_input.grad.float()
        torch.testing.assert_close(
            ref_unpermute_fwd_input_grad,
            fusion_unpermute_fwd_input_grad,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (act_grad)",
            **tols,
        )

        # Check backward pass - probs gradients
        ref_probs_grad = ref_probs.grad.float()
        fusion_probs_grad = fusion_merging_probs.grad.float()
        torch.testing.assert_close(
            ref_probs_grad,
            fusion_probs_grad,
            msg="Mismatch in te_unpermute with merging_probs and pad_offsets bwd (probs_grad)",
            **tols,
        )

    # ------------------------------------------------------------------
    # Benchmark
    # ------------------------------------------------------------------
    if BENCHMARK:

        def ref_unpad_unpermute():
            unpaded = fp8_unpadding(ref_unpermute_fwd_input, tokens_per_expert_list)
            return te_unpermute(unpaded, row_id_map, ref_probs, restore_shape=restore_shape)

        def fused_unpermute():
            return te_unpermute(
                fusion_unpermute_fwd_input,
                fused_row_id_map,
                fusion_merging_probs,
                restore_shape=restore_shape,
                pad_offsets=pad_offsets,
            )

        t1 = perf_test_cuda_kernel(lambda: ref_unpad_unpermute())
        t2 = perf_test_cuda_kernel(lambda: fused_unpermute())
        print(f"unpermute_unpad_with_probs\tfwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                ref_unpermuted_output,
                unpermute_unpad_bwd_input,
                forward_input=[ref_unpermute_fwd_input, ref_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                fusion_unpermuted_output,
                fusion_unpermute_bwd_input,
                forward_input=[fusion_unpermute_fwd_input, fusion_merging_probs],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"unpermute_unpad_with_probs\tbwd: naive: {t1:.3f} ms, fusion: {t2:.3f} ms")
def _test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
......@@ -1126,7 +1645,7 @@ if te.is_bf16_available():
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map(
te_dtype,
......@@ -1155,7 +1674,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map(
te_dtype,
......@@ -1180,6 +1699,74 @@ def test_permutation_mask_map(
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
@pytest.mark.parametrize("with_merging_probs", [True, False])
def test_permutation_and_padding_mask_map(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    with_merging_probs,
):
    """Run the fused permute+pad / unpermute+unpad comparison in checking mode."""
    # The first six arguments follow the helper's positional order; benchmarking
    # is always off here so the numerical assertions run.
    _test_permutation_and_padding_mask_map(
        te_dtype,
        num_tokens,
        num_expert,
        hidden_size,
        topK,
        num_out_tokens,
        with_merging_probs=with_merging_probs,
        BENCHMARK=False,
    )
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_out_tokens", [None])
@pytest.mark.parametrize(
    "num_tokens, num_expert, hidden_size, topK",
    [
        (4096, 8, 1280, 2),
        (4096, 64, 4096, 6),
        (4096, 256, 7168, 6),
        (4096, 512, 9216, 8),
    ],
)
def test_permutation_and_padding_with_merging_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
):
    """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets."""
    # Arguments follow the helper's positional order; benchmarking is always off
    # here so the numerical assertions run.
    _test_permutation_and_padding_with_merging_probs(
        te_dtype,
        num_tokens,
        num_expert,
        hidden_size,
        topK,
        num_out_tokens,
        BENCHMARK=False,
    )
@pytest.mark.parametrize("te_dtype", _te_dtypes)
def test_permutation_mask_map_empty_input(te_dtype):
with_probs = True
......@@ -1201,9 +1788,9 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_permutation_mask_map_alongside_probs(
te_dtype,
num_tokens,
......@@ -1253,10 +1840,10 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8(
......@@ -1341,7 +1928,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("tp_size", [2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
te_dtype,
......@@ -1376,6 +1963,10 @@ def test_chunk_permutation_empty_input(te_dtype):
)
@pytest.mark.skipif(
os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k single_case",
)
def test_permutation_single_case():
print("GPU:", torch.cuda.get_device_name(0))
......@@ -1413,6 +2004,26 @@ def test_permutation_single_case():
BENCHMARK=Benchmark,
)
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)
_test_permutation_and_padding_with_merging_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=Benchmark,
)
_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
......@@ -1479,6 +2090,30 @@ def benchmark_single_case(
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_and_padding_mask_map")
_test_permutation_and_padding_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_and_padding_with_merging_probs")
_test_permutation_and_padding_with_merging_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
......@@ -1495,7 +2130,12 @@ def benchmark_single_case(
torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases():
@pytest.mark.skipif(
os.getenv("RUN_BENCHMARK_TESTS", "0") != "1",
reason="Benchmark test - run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark",
)
def test_benchmark_multiple_cases():
"""Benchmark test - skipped by default. Run with: RUN_BENCHMARK_TESTS=1 pytest -k benchmark"""
print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32
......@@ -1537,4 +2177,4 @@ def benchmark_multiple_cases():
if __name__ == "__main__":
benchmark_multiple_cases()
test_benchmark_multiple_cases()
......@@ -200,6 +200,7 @@ def _permute_kernel(
probs_ptr,
scale_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# sizes
scale_hidden_dim,
# strides
......@@ -224,8 +225,11 @@ def _permute_kernel(
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
......@@ -246,6 +250,15 @@ def _permute_kernel(
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
if FUSION_PAD or PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_PAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
......@@ -253,11 +266,6 @@ def _permute_kernel(
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
......@@ -297,6 +305,7 @@ def _unpermute_kernel(
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -318,10 +327,12 @@ def _unpermute_kernel(
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
......@@ -348,15 +359,19 @@ def _unpermute_kernel(
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
if FUSION_UNPAD or WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
src_row = src_row + pad_off
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
......@@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
fwd_input_ptr,
merging_probs_ptr,
row_id_map_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
......@@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
......
......@@ -25,8 +25,11 @@ import jax.numpy as jnp
from transformer_engine.jax.triton_extensions.permutation import (
make_row_id_map,
permute_with_mask_map,
permute_with_mask_map_and_pad,
unpermute_with_mask_map,
unpermute_with_mask_map_and_unpad,
unpermute_bwd_with_merging_probs,
unpermute_bwd_with_merging_probs_and_unpad,
make_chunk_sort_map,
sort_chunks_by_map,
)
......@@ -43,7 +46,14 @@ def token_dispatch(
routing_map: jnp.ndarray,
num_out_tokens: int,
probs: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
align_size: Optional[int] = None,
) -> Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
]:
"""
Dispatch tokens to experts based on routing map.
......@@ -51,6 +61,10 @@ def token_dispatch(
to their designated experts according to the routing map. The row_id_map
is computed internally from the routing_map.
Optionally supports fused padding for alignment when `align_size` is provided.
This is useful for efficient matrix multiplications that require aligned tensor
dimensions. The padding is computed internally from the routing_map.
Parameters
----------
inp : jnp.ndarray
......@@ -59,36 +73,99 @@ def token_dispatch(
Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
Values: 1 = routed, 0 = not routed.
num_out_tokens : int
The number of output tokens after permutation. This should equal the sum of
routing_map and must be provided explicitly for JIT compatibility.
The number of output tokens after permutation (before padding). For the dropless
case, this should be equal to the sum of routing_map. Must be provided explicitly
for JIT compatibility since output shape must be known at compile time.
probs : Optional[jnp.ndarray]
Optional routing probabilities of shape [batch, sequence, num_experts] or
[num_tokens, num_experts]. If provided, permuted_probs will be returned.
align_size : Optional[int]
Optional alignment size for padding. If provided, outputs will be padded to
align each expert's tokens to a multiple of this size. The output buffer is
allocated with worst-case size, rounded down to align_size:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This enables full JIT compatibility.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
or [worst_case_padded_size, hidden_size] when using padding fusion.
With padding, the actual used portion may be smaller than the buffer; the number of
rows actually used is the sum of target_tokens_per_expert.
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities of shape [num_out_tokens], or None if probs was not provided.
Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
or None if probs was not provided.
row_id_map : jnp.ndarray
Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] when using padding,
None otherwise. Pass this to token_combine when unpadding is needed.
target_tokens_per_expert : Optional[jnp.ndarray]
Aligned token counts per expert of shape [num_experts] when using padding,
None otherwise.
Note
----
**JIT Compatibility:**
This function is fully JIT-compatible. When using padding (align_size provided),
the output buffer is allocated with a fixed worst-case size that depends only on
compile-time constants (num_out_tokens, num_experts, align_size). The actual
padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
are computed internally from the routing_map and can be traced values.
The worst-case output size is:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This accounts for the maximum possible padding when each expert needs (align_size - 1)
extra tokens to align, rounded down to align_size for buffer alignment.
"""
return _token_dispatch(inp, routing_map, probs, num_out_tokens)
use_padding = align_size is not None
num_experts = routing_map.shape[-1]
if use_padding:
# Compute worst-case output size (compile-time constant)
# This is the maximum possible size when each expert needs max padding
worst_case_out_tokens = (
(num_out_tokens + num_experts * (align_size - 1)) // align_size
) * align_size
else:
worst_case_out_tokens = num_out_tokens
return _token_dispatch(
inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
)
# Positions 1, 3, 4, 5 and 6 (routing_map, num_out_tokens, worst_case_out_tokens,
# align_size, use_padding) are declared non-differentiable; only `inp` and
# `probs` receive gradients.
@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
def _token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
    worst_case_out_tokens: int,
    align_size: Optional[int],
    use_padding: bool,
) -> Tuple[
    jnp.ndarray,
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
    Optional[jnp.ndarray],
]:
    """Internal token_dispatch with custom VJP.

    Delegates to ``_token_dispatch_fwd_rule`` and discards the residuals it
    produces, keeping only the primal outputs.

    Returns
    -------
    Tuple of ``(output, permuted_probs, row_id_map, pad_offsets,
    target_tokens_per_expert)`` exactly as produced by the forward rule.
    """
    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
        _token_dispatch_fwd_rule(
            inp,
            routing_map,
            probs,
            num_out_tokens,
            worst_case_out_tokens,
            align_size,
            use_padding,
        )
    )
    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
def _token_dispatch_fwd_rule(
......@@ -96,9 +173,18 @@ def _token_dispatch_fwd_rule(
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
worst_case_out_tokens: int,
align_size: Optional[int],
use_padding: bool,
) -> Tuple[
Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
Tuple[jnp.ndarray, int, int, int, bool],
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
"""Forward pass rule for token_dispatch."""
# Validate input dimensions
......@@ -126,42 +212,102 @@ def _token_dispatch_fwd_rule(
with_probs = probs is not None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if use_padding:
# Compute tokens_per_expert internally from routing_map
# This can be a traced value since output shape uses worst_case_out_tokens
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
# Calculate aligned token counts per expert
target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
jnp.int32
)
# Compute pad_offsets: cumulative padding for each expert
# pad_offsets[i] = sum of (target - actual) for experts 0..i-1
pad_lengths = target_tokens_per_expert - tokens_per_expert
cum_pad = jnp.cumsum(pad_lengths)
pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])
# Use worst_case_out_tokens as the output buffer size (compile-time constant)
# The actual used size is sum(target_tokens_per_expert), which may be smaller.
# Unused positions will be zero-initialized by the kernel.
output, permuted_probs = permute_with_mask_map_and_pad(
inp,
row_id_map,
probs,
pad_offsets,
num_tokens,
num_experts,
worst_case_out_tokens,
hidden_size,
)
else:
# No padding
pad_offsets = None
target_tokens_per_expert = None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# Return (primals, residuals)
# Include with_probs flag to know how to handle backward pass
residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs)
return (output, permuted_probs, row_id_map), residuals
residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
return (
output,
permuted_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
), residuals
def _token_dispatch_bwd_rule(
    _routing_map: jnp.ndarray,
    _num_out_tokens: int,
    _worst_case_out_tokens: int,
    _align_size: Optional[int],
    _use_padding: bool,
    residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
    g: Tuple[
        jnp.ndarray,
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
        Optional[jnp.ndarray],
    ],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Backward pass rule for token_dispatch.

    The leading underscore-prefixed parameters are the non-differentiable
    arguments of ``_token_dispatch``; they are accepted but not used here.

    Parameters
    ----------
    residuals
        ``(row_id_map, pad_offsets, num_tokens, num_experts, hidden_size,
        with_probs)`` as saved by the forward rule. ``pad_offsets`` is None
        when the forward pass ran without padding.
    g
        Cotangents for the five forward outputs. Only the first two
        (``output`` and ``permuted_probs``) are consumed.

    Returns
    -------
    ``(inp_grad, probs_grad)``; ``probs_grad`` is None when the forward pass
    had no probs.
    """
    row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
    output_grad, permuted_probs_grad, _, _, _ = g  # Ignore row_id_map, pad_offsets, target grads

    # Backward: unpermute gradients (gather from experts back to tokens).
    # The unpad variant is chosen whenever the forward pass saved pad_offsets.
    if pad_offsets is not None:
        inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
            output_grad,
            row_id_map,
            None,  # No merging probs
            permuted_probs_grad if with_probs else None,
            pad_offsets,
            num_tokens,
            num_experts,
            hidden_size,
        )
    else:
        inp_grad, probs_grad = unpermute_with_mask_map(
            output_grad,
            row_id_map,
            None,  # No merging probs
            permuted_probs_grad if with_probs else None,
            num_tokens,
            num_experts,
            hidden_size,
        )
    return inp_grad, probs_grad if with_probs else None
......@@ -178,6 +324,7 @@ def token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray] = None,
pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""
Combine tokens from experts back to original token positions.
......@@ -185,33 +332,42 @@ def token_combine(
This is the forward pass of MoE unpermutation. Tokens are gathered from
experts and merged (optionally weighted by merging_probs).
Optionally supports fused unpadding when `pad_offsets` is provided (from
token_dispatch with padding enabled).
Parameters
----------
inp : jnp.ndarray
Input tensor from experts of shape [num_out_tokens, hidden_size].
Input tensor from experts of shape [num_out_tokens, hidden_size]
(or [num_out_tokens_padded, hidden_size] when using unpadding).
row_id_map : jnp.ndarray
Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
merging_probs : Optional[jnp.ndarray]
Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
If provided, tokens from different experts are weighted-summed.
If None, tokens are summed directly.
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch.
If provided, fused unpadding will be performed. This should be the pad_offsets
returned by token_dispatch when using padding.
Returns
-------
output : jnp.ndarray
Combined output tensor of shape [num_tokens, hidden_size].
"""
return _token_combine(inp, row_id_map, merging_probs)
return _token_combine(inp, row_id_map, merging_probs, pad_offsets)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
@jax.custom_vjp
def _token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
pad_offsets: Optional[jnp.ndarray],
) -> jnp.ndarray:
"""Internal token_combine with custom VJP."""
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs)
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
return output
......@@ -219,7 +375,20 @@ def _token_combine_fwd_rule(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]:
pad_offsets: Optional[jnp.ndarray],
) -> Tuple[
jnp.ndarray,
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
]:
"""Forward pass rule for token_combine."""
# Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
num_tokens = row_id_map.shape[0]
......@@ -227,21 +396,34 @@ def _token_combine_fwd_rule(
hidden_size = inp.shape[-1]
num_out_tokens = inp.shape[0]
# Call triton extension
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Call triton extension with or without unpadding
if pad_offsets is not None:
output, _ = unpermute_with_mask_map_and_unpad(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
else:
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Return (primal, residuals)
# Include inp in residuals for backward with merging_probs
residuals = (
row_id_map,
pad_offsets,
inp,
merging_probs,
num_tokens,
......@@ -253,13 +435,26 @@ def _token_combine_fwd_rule(
def _token_combine_bwd_rule(
row_id_map: jnp.ndarray,
residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int],
residuals: Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
g: jnp.ndarray,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_combine."""
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
"""Backward pass rule for token_combine.
Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets)
row_id_map and pad_offsets are integer arrays, so their gradients are None.
"""
(
row_id_map,
pad_offsets,
fwd_input,
merging_probs,
num_tokens,
......@@ -273,30 +468,63 @@ def _token_combine_bwd_rule(
if with_merging_probs:
# Use specialized backward kernel that properly scales by merging_probs
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if pad_offsets is not None:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
output_grad,
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# The backward kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
else:
# Simple case: just permute gradients back
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if pad_offsets is not None:
inp_grad, _ = permute_with_mask_map_and_pad(
output_grad,
row_id_map,
None,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# The permute kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
merging_probs_grad = None
return inp_grad, merging_probs_grad
# Return gradients for: inp, row_id_map, merging_probs, pad_offsets
# row_id_map and pad_offsets are integer arrays, so their gradients are None
return inp_grad, None, merging_probs_grad, None
_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
......
......@@ -27,8 +27,11 @@ from .utils import triton_call_lowering
__all__ = [
"make_row_id_map",
"permute_with_mask_map",
"permute_with_mask_map_and_pad",
"unpermute_with_mask_map",
"unpermute_with_mask_map_and_unpad",
"unpermute_bwd_with_merging_probs",
"unpermute_bwd_with_merging_probs_and_unpad",
"make_chunk_sort_map",
"sort_chunks_by_map",
]
......@@ -243,20 +246,21 @@ register_primitive(RowIdMapPass3Primitive)
class PermuteWithMaskMapPrimitive(BasePrimitive):
"""
Permute the input tensor based on the row_id_map.
Permute the input tensor based on the row_id_map, optionally with fused padding.
"""
name = "te_permute_with_mask_map_triton"
multiple_results = True
# scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False)
# but they need to be in the signature for the kernel call
# scale, permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False)
# pad_offsets can be shape (0,) when not doing padding, or (num_experts,) when padding
impl_static_args = (
5,
6,
7,
8,
9,
) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs
10,
11,
) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, with_pad
inner_primitive = None
outer_primitive = None
......@@ -267,16 +271,18 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
probs_aval,
scale_aval, # dummy, same shape as inp
permuted_scale_aval, # dummy, same shape as inp
pad_offsets_aval,
*,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
):
"""Shape/dtype inference for permute."""
del row_id_map_aval, scale_aval, permuted_scale_aval
del num_tokens, num_experts
del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval
del num_tokens, num_experts, with_pad
output_shape = (num_out_tokens, hidden_size)
output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
......@@ -295,11 +301,13 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
probs,
scale,
permuted_scale,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
):
"""Forward to inner primitive."""
assert PermuteWithMaskMapPrimitive.inner_primitive is not None
......@@ -309,11 +317,13 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
probs,
scale,
permuted_scale,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=with_pad,
)
@staticmethod
......@@ -324,12 +334,14 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
probs,
scale,
permuted_scale,
pad_offsets,
*,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
):
"""MLIR lowering using triton_call_lowering."""
del num_out_tokens
......@@ -367,6 +379,7 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
probs,
scale,
permuted_scale,
pad_offsets,
grid=grid,
constexprs={
"scale_hidden_dim": 0,
......@@ -387,6 +400,7 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
"hidden_size": hidden_size,
"PERMUTE_PROBS": with_probs,
"PERMUTE_SCALE": False,
"FUSION_PAD": with_pad,
"BLOCK_SIZE": block_size,
},
)
......@@ -403,11 +417,11 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
name = "te_unpermute_with_mask_map_triton"
multiple_results = True
impl_static_args = (
4,
5,
6,
7,
8,
9,
) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
inner_primitive = None
outer_primitive = None
......@@ -418,6 +432,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map_aval,
merging_probs_aval,
permuted_probs_aval,
pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False
*,
num_tokens,
num_experts,
......@@ -426,7 +441,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
with_probs,
):
"""Shape/dtype inference for unpermute."""
del row_id_map_aval, merging_probs_aval, with_merging_probs
del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval
output_shape = (num_tokens, hidden_size)
output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
......@@ -447,6 +462,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
......@@ -460,6 +476,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
hidden_size=hidden_size,
......@@ -474,6 +491,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
*,
num_tokens,
num_experts,
......@@ -505,6 +523,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
block_size = _get_min_block_size(_unpermute_kernel)
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
# Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False)
return triton_call_lowering(
ctx,
_unpermute_kernel,
......@@ -512,6 +531,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
grid=grid,
constexprs={
"stride_row_id_map_token": row_id_stride_token,
......@@ -530,6 +550,7 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"WITH_MERGING_PROBS": with_merging_probs,
"PERMUTE_PROBS": with_probs,
"FUSION_UNPAD": False,
"BLOCK_SIZE": block_size,
},
)
......@@ -538,6 +559,155 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
register_primitive(UnpermuteWithMaskMapPrimitive)
class UnpermuteWithMaskMapAndUnpadPrimitive(BasePrimitive):
    """
    Unpermute the input tensor based on the row_id_map with fused unpadding.

    Launches ``_unpermute_kernel`` with ``FUSION_UNPAD`` fixed to True and
    passes ``pad_offsets`` to it as the fifth operand.
    """

    name = "te_unpermute_with_mask_map_and_unpad_triton"
    multiple_results = True
    # Positions in ``impl`` of:
    # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
    impl_static_args = (5, 6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        merging_probs_aval,
        permuted_probs_aval,
        pad_offsets_aval,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Shape/dtype inference for unpermute with unpadding.

        Returns two avals: the unpermuted tensor of shape
        ``(num_tokens, hidden_size)`` in the dtype of ``inp``, and the
        unpermuted probs of shape ``(num_tokens, num_experts)`` in the dtype
        of ``permuted_probs`` (a ``(0,)`` placeholder when ``with_probs`` is
        false).
        """
        del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval

        output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)

        if not with_probs:
            # No probs requested: emit an empty placeholder as second result.
            return output_aval, jax.core.ShapedArray((0,), inp_aval.dtype)

        unpermuted_probs_aval = jax.core.ShapedArray(
            (num_tokens, num_experts), permuted_probs_aval.dtype
        )
        return output_aval, unpermuted_probs_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Forward to inner primitive."""
        primitive = UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive
        assert primitive is not None
        return primitive.bind(
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            pad_offsets,
            num_tokens=num_tokens,
            num_experts=num_experts,
            hidden_size=hidden_size,
            with_merging_probs=with_merging_probs,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        pad_offsets,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """MLIR lowering using triton_call_lowering."""
        # Grid - use minimum BLOCK_SIZE from autotune configs
        block_size = _get_min_block_size(_unpermute_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))

        # Strides are passed to the kernel as compile-time constants.
        # Merging-prob strides collapse to zero when merging probs are disabled.
        constexprs = {
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "stride_input_token": hidden_size,
            "stride_input_hidden": 1,
            "stride_output_token": hidden_size,
            "stride_output_hidden": 1,
            "stride_merging_probs_token": num_experts if with_merging_probs else 0,
            "stride_merging_probs_expert": 1 if with_merging_probs else 0,
            "stride_permuted_probs_token": 1,
            "stride_unpermuted_probs_token": num_experts,
            "stride_unpermuted_probs_expert": 1,
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
            "WITH_MERGING_PROBS": with_merging_probs,
            "PERMUTE_PROBS": with_probs,
            "FUSION_UNPAD": True,
            "BLOCK_SIZE": block_size,
        }

        return triton_call_lowering(
            ctx,
            _unpermute_kernel,
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            pad_offsets,
            grid=grid,
            constexprs=constexprs,
        )
register_primitive(UnpermuteWithMaskMapAndUnpadPrimitive)
class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
"""
Backward pass for unpermute with merging probabilities.
......@@ -547,7 +717,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
name = "te_unpermute_bwd_with_merging_probs_triton"
multiple_results = True
impl_static_args = (4, 5, 6, 7) # num_tokens, num_experts, num_out_tokens, hidden_size
impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size
inner_primitive = None
outer_primitive = None
......@@ -557,6 +727,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input_aval,
merging_probs_aval,
row_id_map_aval,
pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False
*,
num_tokens,
num_experts,
......@@ -564,7 +735,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
hidden_size,
):
"""Shape/dtype inference for unpermute backward with merging probs."""
del fwd_input_aval, row_id_map_aval
del fwd_input_aval, row_id_map_aval, pad_offsets_aval
# fwd_input_grad has same shape as fwd_input
fwd_input_grad_shape = (num_out_tokens, hidden_size)
......@@ -584,6 +755,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
......@@ -596,6 +768,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
......@@ -609,6 +782,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
*,
num_tokens,
num_experts,
......@@ -638,7 +812,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
# Get min block size from autotune configs for consistency
block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)
# Pass inputs in kernel argument order: fwd_output_grad, fwd_input, merging_probs, row_id_map
# Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False)
return triton_call_lowering(
ctx,
_unpermute_bwd_with_merging_probs_kernel,
......@@ -646,6 +820,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
grid=grid,
constexprs={
"stride_row_id_map_token": row_id_stride_token,
......@@ -663,6 +838,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
"num_experts": num_experts,
"hidden_size": hidden_size,
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"FUSION_UNPAD": False,
"BLOCK_SIZE": block_size,
},
)
......@@ -671,6 +847,145 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
# Register the unpermute backward (with merging probs) primitive.
# NOTE(review): presumably this is what populates `inner_primitive` /
# `outer_primitive` on the class -- confirm in register_primitive.
register_primitive(UnpermuteBwdWithMergingProbsPrimitive)
class UnpermuteBwdWithMergingProbsAndUnpadPrimitive(BasePrimitive):
    """
    Backward of unpermute with merging probabilities, with unpadding fused in.

    Produces two gradients: one for the forward input and one for merging_probs.
    It launches `_unpermute_bwd_with_merging_probs_kernel` specialised with
    `FUSION_UNPAD=True`, so `pad_offsets` is a real operand here.
    """

    name = "te_unpermute_bwd_with_merging_probs_and_unpad_triton"
    multiple_results = True
    impl_static_args = (5, 6, 7, 8)  # num_tokens, num_experts, num_out_tokens, hidden_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        fwd_output_grad_aval,
        fwd_input_aval,
        merging_probs_aval,
        row_id_map_aval,
        pad_offsets_aval,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Infer output shapes/dtypes for the fused unpermute backward.

        Returns a pair of abstract values:
        the forward-input gradient `(num_out_tokens, hidden_size)` with the dtype of
        `fwd_output_grad_aval`, and the merging_probs gradient
        `(num_tokens, num_experts)` with the dtype of `merging_probs_aval`.
        """
        # These avals do not influence the output shapes or dtypes.
        del fwd_input_aval, row_id_map_aval, pad_offsets_aval

        input_grad_aval = jax.core.ShapedArray(
            (num_out_tokens, hidden_size), fwd_output_grad_aval.dtype
        )
        probs_grad_aval = jax.core.ShapedArray(
            (num_tokens, num_experts), merging_probs_aval.dtype
        )
        return input_grad_aval, probs_grad_aval

    @staticmethod
    def impl(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Bind the inner primitive, turning the static positional args into keywords."""
        assert UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive is not None
        static_kwargs = {
            "num_tokens": num_tokens,
            "num_experts": num_experts,
            "num_out_tokens": num_out_tokens,
            "hidden_size": hidden_size,
        }
        return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive.bind(
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            pad_offsets,
            **static_kwargs,
        )

    @staticmethod
    def lowering(
        ctx,
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        pad_offsets,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Lower to a Triton launch of the merging-probs backward kernel with unpad enabled.

        Strides are derived from the static sizes (`hidden_size`, `num_experts`)
        rather than read from the operands.
        """
        # Not needed for the launch: no stride or grid dimension depends on it.
        del num_out_tokens

        # Smallest BLOCK_SIZE among the kernel's autotune configs, for consistency.
        min_block = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)

        kernel_constexprs = {
            # Each row_id_map row holds num_experts * 2 + 1 entries.
            "stride_row_id_map_token": num_experts * 2 + 1,
            "stride_row_id_map_expert": 1,
            "stride_fwd_output_grad_token": hidden_size,
            "stride_fwd_output_grad_hidden": 1,
            "stride_fwd_input_grad_token": hidden_size,
            "stride_fwd_input_grad_hidden": 1,
            "stride_fwd_input_token": hidden_size,
            "stride_fwd_input_hidden": 1,
            "stride_merging_probs_token": num_experts,
            "stride_merging_probs_expert": 1,
            "stride_merging_probs_grad_token": num_experts,
            "stride_merging_probs_grad_expert": 1,
            "num_experts": num_experts,
            "hidden_size": hidden_size,
            "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
            "FUSION_UNPAD": True,
            "BLOCK_SIZE": min_block,
        }

        return triton_call_lowering(
            ctx,
            _unpermute_bwd_with_merging_probs_kernel,
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            pad_offsets,
            # One program per token.
            grid=(num_tokens,),
            constexprs=kernel_constexprs,
        )
# Register the fused unpermute-backward + unpad primitive.
# NOTE(review): presumably this is what populates `inner_primitive` /
# `outer_primitive` on the class -- confirm in register_primitive.
register_primitive(UnpermuteBwdWithMergingProbsAndUnpadPrimitive)
def unpermute_bwd_with_merging_probs(
fwd_output_grad: jnp.ndarray,
row_id_map: jnp.ndarray,
......@@ -712,12 +1027,73 @@ def unpermute_bwd_with_merging_probs(
merging_probs_grad : jnp.ndarray
Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
"""
# Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map
# Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature)
dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
# Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets
return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
fwd_output_grad,
fwd_input,
merging_probs,
row_id_map,
dummy_pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
)
def unpermute_bwd_with_merging_probs_and_unpad(
fwd_output_grad: jnp.ndarray,
row_id_map: jnp.ndarray,
fwd_input: jnp.ndarray,
merging_probs: jnp.ndarray,
pad_offsets: jnp.ndarray,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Backward pass for unpermute with merging probabilities and fused unpadding.
This computes gradients for both the input tensor and merging_probs,
while handling padded outputs.
Parameters
----------
fwd_output_grad : jnp.ndarray
Gradient of the forward output of shape `[num_tokens, hidden_size]`.
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
fwd_input : jnp.ndarray
The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs : jnp.ndarray
The merging probabilities of shape `[num_tokens, num_experts]`.
pad_offsets : jnp.ndarray
Per-expert cumulative padding offsets of shape `[num_experts]`.
num_tokens : int
Number of tokens in the unpermuted tensor.
num_experts : int
Number of experts.
num_out_tokens : int
Number of tokens in the permuted tensor (including padding).
hidden_size : int
Hidden size.
Returns
-------
fwd_input_grad : jnp.ndarray
Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`.
merging_probs_grad : jnp.ndarray
Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
"""
return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.outer_primitive.bind(
fwd_output_grad,
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
......@@ -957,6 +1333,78 @@ def permute_with_mask_map(
"""
with_probs = probs is not None
# Handle None probs by creating dummy tensor
if not with_probs:
probs = jnp.zeros((0,), dtype=inp.dtype)
# Create dummy scale tensors (not used when PERMUTE_SCALE=False, but required by kernel signature)
dummy_scale = inp
dummy_permuted_scale = inp
# Create dummy pad_offsets (not used when FUSION_PAD=False, but required by kernel signature)
dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind(
inp,
row_id_map,
probs,
dummy_scale,
dummy_permuted_scale,
dummy_pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=False,
)
if not with_probs:
permuted_probs = None
return output, permuted_probs
def permute_with_mask_map_and_pad(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
pad_offsets: jnp.ndarray,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""
Permute the input tensor based on the row_id_map with fused padding.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
probs : Optional[jnp.ndarray]
The probabilities of the input tensor. If it is not None, it will be permuted.
pad_offsets : jnp.ndarray
Per-expert cumulative padding offsets of shape `[num_experts]`.
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
Number of experts in the input tensor.
num_out_tokens : int
Number of tokens in the permuted tensor (including padding).
hidden_size : int
Hidden size of the input tensor.
Returns
-------
output : jnp.ndarray
Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`.
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities if probs was provided, None otherwise.
"""
with_probs = probs is not None
# Handle None probs by creating dummy tensor
if not with_probs:
probs = jnp.zeros((0,), dtype=inp.dtype)
......@@ -971,11 +1419,13 @@ def permute_with_mask_map(
probs,
dummy_scale,
dummy_permuted_scale,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=True,
)
if not with_probs:
......@@ -1029,12 +1479,83 @@ def unpermute_with_mask_map(
merging_probs = jnp.zeros((0,), dtype=inp.dtype)
if not with_probs:
permuted_probs = jnp.zeros((0,), dtype=inp.dtype)
# Create dummy pad_offsets (not used when FUSION_UNPAD=False, but required by kernel signature)
dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
inp,
row_id_map,
merging_probs,
permuted_probs,
dummy_pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
hidden_size=hidden_size,
with_merging_probs=with_merging_probs,
with_probs=with_probs,
)
if not with_probs:
unpermuted_probs = None
return output, unpermuted_probs
def unpermute_with_mask_map_and_unpad(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
permuted_probs: Optional[jnp.ndarray],
pad_offsets: jnp.ndarray,
num_tokens: int,
num_experts: int,
hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""
Unpermute the input tensor based on the row_id_map with fused unpadding.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape `[num_out_tokens, hidden_size]` (including padding).
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
merging_probs : Optional[jnp.ndarray]
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
permuted_probs : Optional[jnp.ndarray]
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
pad_offsets : jnp.ndarray
Per-expert cumulative padding offsets of shape `[num_experts]`.
num_tokens : int
Number of tokens in the unpermuted tensor.
num_experts : int
Number of experts.
hidden_size : int
Hidden size of the tensor.
Returns
-------
output : jnp.ndarray
Unpermuted output tensor of shape `[num_tokens, hidden_size]`.
unpermuted_probs : Optional[jnp.ndarray]
Unpermuted probabilities if permuted_probs was provided, None otherwise.
"""
with_merging_probs = merging_probs is not None
with_probs = permuted_probs is not None
# Handle None inputs by creating dummy tensors
if not with_merging_probs:
merging_probs = jnp.zeros((0,), dtype=inp.dtype)
if not with_probs:
permuted_probs = jnp.zeros((0,), dtype=inp.dtype)
output, unpermuted_probs = UnpermuteWithMaskMapAndUnpadPrimitive.outer_primitive.bind(
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
num_tokens=num_tokens,
num_experts=num_experts,
hidden_size=hidden_size,
......
......@@ -142,17 +142,31 @@ def compile_triton(
)
# Create kernel object for JAX
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1, # cluster_dims
)
# From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str)
num_warps, # arg1: num_warps (int)
num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
compute_capability, # arg6: compute_capability (int)
)
else:
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1,
)
_TRITON_KERNEL_CACHE[cache_key] = kernel
return kernel
......
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_permute_with_probs,
moe_permute_and_pad_with_probs,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""MoE Permutaion API"""
"""MoE Permutation API"""
import warnings
from typing import Optional, Tuple
import torch
......@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
routing_map: torch.Tensor,
num_out_tokens: int,
probs: torch.Tensor,
pad_offsets: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function):
assert routing_map.is_cuda, "TransformerEngine needs CUDA."
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert inp.size(0) == routing_map.size(0), "Permute not possible"
num_tokens, hidden_size = inp.size()
......@@ -250,6 +253,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
probs,
fp8_scale,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
......@@ -292,7 +296,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
......@@ -307,12 +311,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, ctx.probs
return permuted_act_grad, None, None, ctx.probs, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
......@@ -321,13 +325,14 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
None,
permuted_probs_grad,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
return act_grad, None, None, probs_grad, None
class _moe_unpermute_mask_map(torch.autograd.Function):
......@@ -340,6 +345,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map: torch.Tensor,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
pad_offsets: Optional[torch.Tensor],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -358,6 +364,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert not isinstance(
inp, QuantizedTensor
......@@ -367,15 +375,16 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
merging_probs,
None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs)
ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
else:
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.num_permuted_tokens = inp.size(0)
......@@ -387,15 +396,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def backward(ctx, unpermuted_act_grad):
# pylint: disable=missing-function-docstring
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.merging_probs, None
return unpermuted_act_grad, None, ctx.merging_probs, None, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
if ctx.with_probs:
fwd_input, row_id_map, merging_probs = ctx.saved_tensors
fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
else:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
......@@ -441,6 +450,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -453,6 +463,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
None,
fp8_scale,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -497,7 +508,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if not ctx.needs_input_grad[2]:
probs_grad = None
return act_grad, None, probs_grad, None
return act_grad, None, probs_grad, None, None
def moe_permute(
......@@ -537,7 +548,9 @@ def moe_permute(
if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
output, row_id_map, _ = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, None, None
)
return output, row_id_map
raise ValueError("map_type should be one of 'mask' or 'index'")
......@@ -570,11 +583,67 @@ def moe_permute_with_probs(
By default, set to '-1', meaning no tokens are dropped.
"""
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, probs
inp, routing_map, num_out_tokens, probs, None
)
return output, permuted_probs, row_id_map
def moe_permute_and_pad_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map, padding each expert's
    group of tokens up to a multiple of `align_size` in the same step.
    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    tokens_per_expert : torch.Tensor
        Tensor of shape `[num_experts]` containing actual token counts per expert.
        It is moved to the device of `inp` if it lives elsewhere.
    align_size : int
        Each expert's token count is rounded up to a multiple of this value.
        Must be positive.

    Returns
    -------
    output : torch.Tensor
        The permuted (and padded) tokens.
    permuted_probs : torch.Tensor
        The permuted probabilities.
    row_id_map : torch.Tensor
        The row id map produced by the permutation.
    pad_offsets : Optional[torch.Tensor]
        For each expert, the total padding added to all preceding experts
        (first entry is 0). None when every count is already a multiple of
        `align_size`, i.e. no padding is required.
    target_tokens_per_expert : torch.Tensor
        The per-expert token counts after rounding up to `align_size`.
    """
    assert (
        tokens_per_expert is not None
    ), "tokens_per_expert must be provided to the fused permute padding function."
    assert align_size > 0, f"align_size must be positive, got {align_size}"

    # Keep the counts on the input's device so the arithmetic below stays on one device.
    if tokens_per_expert.device != inp.device:
        tokens_per_expert = tokens_per_expert.to(inp.device)

    # Round every expert's count up to the next multiple of align_size.
    aligned_counts = (torch.ceil(tokens_per_expert / align_size) * align_size).long()

    # No offsets at all when nothing needs padding; the permute then runs unfused.
    pad_offsets = None
    if not torch.equal(tokens_per_expert, aligned_counts):
        padding_per_expert = aligned_counts - tokens_per_expert
        running_padding = torch.cumsum(padding_per_expert, dim=0)
        # Shift right by one: expert i is offset by the padding of experts 0..i-1.
        leading_zero = torch.zeros(1, dtype=running_padding.dtype, device=inp.device)
        pad_offsets = torch.cat([leading_zero, running_padding[:-1]])

    # The padded total becomes the number of output tokens of the permutation.
    total_padded_tokens = aligned_counts.sum().item()
    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, total_padded_tokens, probs, pad_offsets
    )
    return output, permuted_probs, row_id_map, pad_offsets, aligned_counts
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -582,6 +651,7 @@ def moe_unpermute(
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: Optional[torch.Tensor] = None,
pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -605,6 +675,10 @@ def moe_unpermute(
Options are: 'mask', 'index'.
probs : torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.
"""
if probs is not None:
if merging_probs is not None:
......@@ -616,7 +690,9 @@ def moe_unpermute(
if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
return _moe_unpermute_mask_map.apply(
inp, row_id_map, merging_probs, restore_shape, pad_offsets
)
raise ValueError("map_type should be one of 'mask' or 'index'")
......
......@@ -123,6 +123,7 @@ def permute_with_mask_map(
row_id_map: torch.Tensor,
probs: torch.Tensor,
scale: torch.Tensor,
pad_offsets: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -142,6 +143,9 @@ def permute_with_mask_map(
The probabilities of the input tensor. If it is not None, it will be permuted.
scale : torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
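pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
If it is not None, padding is fused into the permutation, and the output and permuted_probs buffers are zero-initialized because the kernel does not write to the padding positions.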
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
......@@ -153,18 +157,18 @@ def permute_with_mask_map(
scale_hidden_dim : int
Hidden size of the scale tensor.
"""
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
if scale is not None:
permuted_scale = torch.empty(
(num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
)
else:
permuted_scale = None
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed,
# since the kernel doesn't write to padding positions.
alloc = torch.zeros if pad_offsets is not None else torch.empty
output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
permuted_probs = (
alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
)
permuted_scale = (
torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
if scale is not None
else None
)
# pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid](
......@@ -173,6 +177,7 @@ def permute_with_mask_map(
probs,
scale,
permuted_scale,
pad_offsets,
scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
......@@ -193,6 +198,7 @@ def permute_with_mask_map(
hidden_size,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
FUSION_PAD=pad_offsets is not None,
)
return output, permuted_scale, permuted_probs
......@@ -202,6 +208,7 @@ def unpermute_with_mask_map(
row_id_map: torch.Tensor,
merging_probs: Union[torch.Tensor, None],
permuted_probs: Union[torch.Tensor, None],
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
hidden_size: int,
......@@ -220,6 +227,9 @@ def unpermute_with_mask_map(
to reduce the unpermuted tokens.
permuted_probs : torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
If it is not None, it will remove the previously fused padding.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
......@@ -241,6 +251,7 @@ def unpermute_with_mask_map(
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
......@@ -259,6 +270,7 @@ def unpermute_with_mask_map(
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
FUSION_UNPAD=pad_offsets is not None,
)
return output, unpermuted_probs
......@@ -268,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
row_id_map: torch.Tensor,
fwd_input: torch.Tensor,
merging_probs: torch.Tensor,
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -286,6 +299,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs : torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
If it is not None, the activation gradient buffer is zero-initialized because the kernel does not write to the padding slots.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
......@@ -295,9 +311,11 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
hidden_size : int
Hidden size of the output tensor.
"""
act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
)
# Use zeros when pad_offsets is used because padding slots won't be written to
# by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros
# out the padding slots.
alloc = torch.zeros if pad_offsets is not None else torch.empty
act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda")
merging_probs_grad = torch.empty(
(num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
)
......@@ -307,6 +325,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0),
......@@ -324,6 +343,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
FUSION_UNPAD=pad_offsets is not None,
)
return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment