Unverified Commit 46c6ef31 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Jax primitives for permutation on single GPU (#2473)



* branch off of initial permutation jax-triton PR
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Set 0 as the size of dummy tensors to reduce memory usage.
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Implement primitives, wrapper, test for wrapper, edit triton binding to accommodate scalars
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change implementation of VJP functions to match correct pattern. Deduce some static scalar args from shapes of inputs. Accept B, S instead of num_tokens. Change test to use value_and_grad to test vjp funcs properly
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* formatting
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix pylint
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix test to compare to the correct reference impl. relax 1 tol for grad compare, fix lint the right way
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix test_permutation to use value_and_grad for reference impl, tighten tols, and add unpermute with probs for token combine bwd rule
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* added forgotten file in prev commit
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* merge with_probs to without_probs
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* add asserts and fix lint
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
Co-authored-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dbaa02d0
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for permutation Triton kernels and high-level APIs"""
import jax
import jax.numpy as jnp
import pytest
# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
sort_chunks_by_index,
)
from utils import assert_allclose
def reference_make_row_id_map(
    routing_map: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
) -> jnp.ndarray:
    """
    Pure-JAX reference implementation of make_row_id_map.

    Parameters
    ----------
    routing_map : jnp.ndarray
        Mask of shape [num_tokens, num_experts]; 1 = token routed to that
        expert, 0 = not routed.
    num_tokens : int
        Number of tokens.
    num_experts : int
        Number of experts.

    Returns
    -------
    row_id_map : jnp.ndarray
        Map of shape [num_tokens, num_experts * 2 + 1]. Per token: columns
        [0, num_experts) hold destination rows (sorted descending), columns
        [num_experts, 2 * num_experts) the matching expert ids, and the last
        column the routed-expert count. Unused slots stay -1.
    """
    result = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32)
    # Running per-expert count gives each token its rank inside the
    # expert's contiguous output segment.
    running_counts = jnp.cumsum(routing_map, axis=0)
    # Start offset of every expert's segment in the permuted output.
    totals = jnp.sum(routing_map, axis=0)
    offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(totals)[:-1]])
    for tok in range(num_tokens):
        experts_for_token = jnp.where(routing_map[tok] == 1)[0]
        count = len(experts_for_token)
        # Routed-expert count lives in the final column.
        result = result.at[tok, -1].set(count)
        destinations = []
        chosen = []
        for e in experts_for_token:
            # Destination row = segment start + (rank within segment).
            destinations.append(offsets[e] + running_counts[tok, e] - 1)
            chosen.append(e)
        if count > 0:
            # Negated keys -> descending order by destination row.
            order = jnp.argsort(-jnp.array(destinations))
            dest_sorted = jnp.array(destinations)[order]
            expert_sorted = jnp.array(chosen)[order]
            for slot in range(count):
                result = result.at[tok, slot].set(dest_sorted[slot])
                result = result.at[tok, num_experts + slot].set(expert_sorted[slot])
    return result
def _reference_permute_impl(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
probs: jnp.ndarray,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
) -> tuple:
"""
Internal helper for reference permutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
probs : jnp.ndarray
The probabilities of the input tensor.
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
Number of experts.
num_out_tokens : int
Number of tokens in the permuted tensor.
hidden_size : int
Hidden size of the input tensor.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
permuted_probs : jnp.ndarray
Permuted probabilities if probs was provided, None otherwise.
"""
output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype)
for token_idx in range(num_tokens):
n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range()
for i in range(n_routed):
# Don't use int() here - JAX can index with traced values,
# and int() breaks autodiff gradient tracking
dest_row = row_id_map[token_idx, i]
expert_idx = row_id_map[token_idx, num_experts + i]
# Get probability for this expert
if probs is not None:
if probs.ndim == 1:
prob = probs[token_idx]
else:
prob = probs[token_idx, expert_idx]
# Match kernel behavior: if prob == 0.0, zero out the output (padding indicator)
if prob == 0.0:
output = output.at[dest_row].set(0.0)
else:
output = output.at[dest_row].set(inp[token_idx])
permuted_probs = permuted_probs.at[dest_row].set(prob)
else:
output = output.at[dest_row].set(inp[token_idx])
return output, permuted_probs
def _reference_unpermute_impl(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: jnp.ndarray,
permuted_probs: jnp.ndarray,
num_tokens: int,
num_experts: int,
hidden_size: int,
) -> tuple:
"""
Internal helper for reference unpermutation implementation.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [num_out_tokens, hidden_size].
row_id_map : jnp.ndarray
The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1].
merging_probs : jnp.ndarray
The merging probabilities for weighted reduction.
permuted_probs : jnp.ndarray
The permuted probabilities.
num_tokens : int
Number of tokens.
num_experts : int
Number of experts.
hidden_size : int
Hidden size.
Returns
-------
output : jnp.ndarray
Unpermuted output tensor of shape [num_tokens, hidden_size].
unpermuted_probs : jnp.ndarray
Unpermuted probabilities if permuted_probs was provided, None otherwise.
"""
output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
unpermuted_probs = (
None
if permuted_probs is None
else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype)
)
for token_idx in range(num_tokens):
n_routed = int(row_id_map[token_idx, -1]) # int() needed for Python range()
for i in range(n_routed):
# Don't use int() here - JAX can index with traced values,
# and int() breaks autodiff gradient tracking
src_row = row_id_map[token_idx, i]
expert_idx = row_id_map[token_idx, num_experts + i]
if merging_probs is not None:
weight = merging_probs[token_idx, expert_idx]
output = output.at[token_idx].add(inp[src_row] * weight)
else:
output = output.at[token_idx].add(inp[src_row])
if permuted_probs is not None:
unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set(
permuted_probs[src_row]
)
return output, unpermuted_probs
def reference_token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: jnp.ndarray = None,
) -> tuple:
    """
    Pure-JAX reference implementation of token_dispatch.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        0/1 routing mask of shape [num_tokens, num_experts].
    num_out_tokens : int
        Number of rows in the permuted output.
    probs : jnp.ndarray, optional
        Routing probabilities of shape [num_tokens, num_experts].

    Returns
    -------
    output : jnp.ndarray
        Permuted tensor of shape [num_out_tokens, hidden_size].
    permuted_probs : jnp.ndarray or None
        Permuted probabilities of shape [num_out_tokens], or None.
    row_id_map : jnp.ndarray
        The row_id_map used for the permutation.
    """
    # All sizes are recoverable from the two array arguments.
    tokens, experts = routing_map.shape
    hidden = inp.shape[1]
    row_id_map = reference_make_row_id_map(routing_map, tokens, experts)
    permuted, permuted_probs = _reference_permute_impl(
        inp, row_id_map, probs, tokens, experts, num_out_tokens, hidden
    )
    # Expose the map so callers can feed it to the combine step.
    return permuted, permuted_probs, row_id_map
def reference_token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: jnp.ndarray,
) -> jnp.ndarray:
    """
    Pure-JAX reference implementation of token_combine.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Token-to-expert map of shape [num_tokens, num_experts * 2 + 1].
    merging_probs : jnp.ndarray
        Per-(token, expert) weights for the reduction, or None for plain sum.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted tensor of shape [num_tokens, hidden_size].
    """
    # The map's own shape encodes both token and expert counts:
    # width == num_experts * 2 + 1.
    tokens, width = row_id_map.shape
    experts = (width - 1) // 2
    combined, _ = _reference_unpermute_impl(
        inp, row_id_map, merging_probs, None, tokens, experts, inp.shape[1]
    )
    return combined
def reference_make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
    num_splits: int,
) -> jnp.ndarray:
    """
    Pure-JAX reference implementation of make_chunk_sort_map.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        Chunk sizes of shape [num_splits,].
    sorted_indices : jnp.ndarray
        Desired chunk order of shape [num_splits,].
    num_tokens : int
        Total number of tokens (rows).
    num_splits : int
        Number of chunks.

    Returns
    -------
    row_id_map : jnp.ndarray
        Per-row destination index of shape [num_tokens,].
    """
    row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32)
    # Prefix sums give each chunk's [start, end) span in the source.
    boundaries = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
    write_pos = 0
    for chunk in sorted_indices:
        start = boundaries[chunk]
        size = boundaries[chunk + 1] - start
        # Rows of this chunk land contiguously at the current write cursor.
        for offset in range(size):
            row_id_map = row_id_map.at[start + offset].set(write_pos + offset)
        write_pos += size
    return row_id_map
def reference_sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: jnp.ndarray,
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
) -> tuple:
    """
    Pure-JAX reference implementation of sort_chunks_by_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Per-row destination index of shape [num_tokens,].
    probs : jnp.ndarray
        Per-row probabilities, or None.
    num_tokens : int
        Number of rows.
    hidden_size : int
        Hidden dimension.
    is_forward : bool
        True to apply the map (scatter), False to invert it (gather).

    Returns
    -------
    output : jnp.ndarray
        Reordered tensor of shape [num_tokens, hidden_size].
    permuted_probs : jnp.ndarray
        Reordered probabilities when probs is given, else None.
    """
    output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype)
    permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype)
    for row in range(num_tokens):
        # Keep the mapped index traced (no int()) so autodiff flows.
        mapped = row_id_map[row]
        if is_forward:
            # Forward scatter: source row `row` lands at `mapped`.
            output = output.at[mapped].set(inp[row])
            if probs is not None:
                permuted_probs = permuted_probs.at[mapped].set(probs[row])
        else:
            # Backward gather: output row `row` comes from `mapped`.
            output = output.at[row].set(inp[mapped])
            if probs is not None:
                permuted_probs = permuted_probs.at[row].set(probs[mapped])
    return output, permuted_probs
class TestHighLevelPermutationAPI:
    """Test high-level permutation APIs (token_dispatch, token_combine, etc.)

    These tests compare the high-level APIs against reference implementations
    to verify correctness of both forward and backward passes.
    """

    @staticmethod
    def generate_routing_map(
        num_tokens: int,
        num_experts: int,
        tokens_per_expert: int = 2,
        key: jax.Array = None,
    ):
        """Generate random routing map for testing.

        Each token is routed to exactly ``tokens_per_expert`` distinct experts,
        chosen uniformly without replacement. Returns a 0/1 int32 mask of
        shape [num_tokens, num_experts].
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
        for token_idx in range(num_tokens):
            # Fresh subkey per token so expert choices are independent.
            key, subkey = jax.random.split(key)
            expert_indices = jax.random.choice(
                subkey, num_experts, shape=(tokens_per_expert,), replace=False
            )
            routing_map = routing_map.at[token_idx, expert_indices].set(1)
        return routing_map

    # =========================================================================
    # token_dispatch tests
    # =========================================================================
    @pytest.mark.parametrize(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        [
            (32, 8, 256, 2),
            (64, 16, 512, 3),
        ],
    )
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    def test_token_dispatch(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype):
        """Test token_dispatch forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        # Every 1 in the routing map produces exactly one permuted output row.
        num_out_tokens = int(jnp.sum(routing_map))
        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )

        # Define loss functions; sum-of-squares exercises the VJP of the output
        def loss_fn(x):
            output, _, _ = token_dispatch(x, routing_map, num_out_tokens)
            return jnp.sum(output**2)

        def ref_loss_fn(x):
            output, _, _ = reference_token_dispatch(x, routing_map, num_out_tokens)
            return jnp.sum(output**2)

        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
        # Compare forward outputs
        output, _, _ = token_dispatch(inp, routing_map, num_out_tokens)
        ref_output, _, _ = reference_token_dispatch(inp, routing_map, num_out_tokens)
        assert_allclose(output, ref_output)
        # Compare loss and gradient
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
    # token_dispatch with probs tests
    # =========================================================================
    @pytest.mark.parametrize(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        [
            (32, 8, 256, 2),
            (64, 16, 512, 3),
        ],
    )
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    def test_token_dispatch_with_probs(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype
    ):
        """Test token_dispatch with probs forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Generate input data and probs
        key, inp_key, prob_key = jax.random.split(key, 3)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        probs = jax.random.uniform(
            prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
        )

        # Define loss function that uses token_dispatch with probs
        # We compute gradients w.r.t. both inp and probs
        def loss_fn(x, p):
            output, permuted_probs, _ = token_dispatch(x, routing_map, num_out_tokens, probs=p)
            return jnp.sum(output**2) + jnp.sum(permuted_probs**2)

        def ref_loss_fn(x, p):
            output, permuted_probs, _ = reference_token_dispatch(
                x, routing_map, num_out_tokens, probs=p
            )
            return jnp.sum(output**2) + jnp.sum(permuted_probs**2)

        # argnums=(0, 1) -> gradients for both the tokens and the probs
        loss_val, (inp_grad, probs_grad) = jax.value_and_grad(loss_fn, argnums=(0, 1))(inp, probs)
        ref_loss_val, (ref_inp_grad, ref_probs_grad) = jax.value_and_grad(
            ref_loss_fn, argnums=(0, 1)
        )(inp, probs)
        output, permuted_probs, _ = token_dispatch(inp, routing_map, num_out_tokens, probs=probs)
        ref_output, ref_permuted_probs, _ = reference_token_dispatch(
            inp, routing_map, num_out_tokens, probs=probs
        )
        # Compare forward outputs
        assert_allclose(output, ref_output)
        assert_allclose(permuted_probs, ref_permuted_probs)
        # Compare loss and gradients
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(inp_grad, ref_inp_grad)
        assert_allclose(probs_grad, ref_probs_grad)

    # =========================================================================
    # token_combine tests
    # =========================================================================
    @pytest.mark.parametrize(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        [
            (32, 8, 256, 2),
            (64, 16, 512, 3),
        ],
    )
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize("with_merging_probs", [True, False])
    def test_token_combine(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_merging_probs
    ):
        """Test token_combine forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Get row_id_map from reference_token_dispatch (the dispatched values
        # themselves are irrelevant here, only the map is reused)
        key, dummy_key = jax.random.split(key)
        dummy_inp = jax.random.uniform(
            dummy_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        _, _, row_id_map = reference_token_dispatch(dummy_inp, routing_map, num_out_tokens)
        # Generate input data (from expert outputs)
        key, inp_key, merge_key = jax.random.split(key, 3)
        inp = jax.random.uniform(
            inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        if with_merging_probs:
            merging_probs = jax.random.uniform(
                merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0
            )
            # Normalize per token; epsilon guards against a zero row sum
            merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8)
        else:
            merging_probs = None

        # Define loss functions
        def loss_fn(x):
            output = token_combine(x, row_id_map, merging_probs)
            return jnp.sum(output**2)

        def ref_loss_fn(x):
            output = reference_token_combine(x, row_id_map, merging_probs)
            return jnp.sum(output**2)

        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
        # Compare forward outputs
        output = token_combine(inp, row_id_map, merging_probs)
        ref_output = reference_token_combine(inp, row_id_map, merging_probs)
        assert_allclose(output, ref_output)
        # Compare loss and gradient
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
    # sort_chunks_by_index tests
    # =========================================================================
    @pytest.mark.parametrize(
        "num_splits,total_tokens,hidden_size",
        [
            (4, 128, 256),
            (8, 256, 512),
        ],
    )
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    def test_sort_chunks_by_index(self, num_splits, total_tokens, hidden_size, dtype):
        """Test sort_chunks_by_index forward and backward pass against reference"""
        key = jax.random.PRNGKey(42)
        # Generate random split sizes; the last chunk absorbs the remainder so
        # the sizes always sum to total_tokens
        key, size_key = jax.random.split(key)
        split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits)
        split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1]))
        # Generate sorted indices
        key, sort_key = jax.random.split(key)
        sorted_indices = jax.random.permutation(sort_key, num_splits)
        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        row_id_map = reference_make_chunk_sort_map(
            split_sizes, sorted_indices, total_tokens, num_splits
        )

        # Define loss functions
        def loss_fn(x):
            output, _ = sort_chunks_by_index(x, split_sizes, sorted_indices)
            return jnp.sum(output**2)

        def ref_loss_fn(x):
            output, _ = reference_sort_chunks_by_map(
                x, row_id_map, None, total_tokens, hidden_size, is_forward=True
            )
            return jnp.sum(output**2)

        loss_val, computed_grad = jax.value_and_grad(loss_fn)(inp)
        ref_loss_val, ref_grad = jax.value_and_grad(ref_loss_fn)(inp)
        # Compare forward outputs
        output, _ = sort_chunks_by_index(inp, split_sizes, sorted_indices)
        ref_output, _ = reference_sort_chunks_by_map(
            inp, row_id_map, None, total_tokens, hidden_size, is_forward=True
        )
        assert_allclose(output, ref_output)
        # Compare loss and gradient
        assert_allclose(loss_val, ref_loss_val)
        assert_allclose(computed_grad, ref_grad)

    # =========================================================================
    # Round-trip tests (token_dispatch -> expert processing -> token_combine)
    # =========================================================================
    @pytest.mark.parametrize(
        "num_tokens,num_experts,hidden_size,tokens_per_expert",
        [
            (32, 8, 256, 2),
            (64, 16, 512, 3),
        ],
    )
    @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
    def test_dispatch_combine_roundtrip(
        self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype
    ):
        """Test that token_dispatch followed by token_combine recovers original input"""
        key = jax.random.PRNGKey(42)
        # Generate routing map
        routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key)
        num_out_tokens = int(jnp.sum(routing_map))
        # Generate input data
        key, inp_key = jax.random.split(key)
        inp = jax.random.uniform(
            inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
        )
        # Create uniform merging probs (equal weight for all routed experts);
        # jnp.maximum guards against division by zero for unrouted tokens
        merging_probs = routing_map.astype(dtype) / jnp.maximum(
            jnp.sum(routing_map, axis=1, keepdims=True), 1.0
        )
        # Dispatch tokens to experts (returns output, permuted_probs, row_id_map)
        dispatched, _, row_id_map = token_dispatch(inp, routing_map, num_out_tokens)
        # Combine tokens back (with uniform merging) (new signature)
        combined = token_combine(dispatched, row_id_map, merging_probs)
        # Compare with original input: averaging identical replicas of each
        # token must reproduce the token exactly (up to dtype tolerance)
        assert_allclose(combined, inp)
...@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive): ...@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
transpose_batch_sequence, transpose_batch_sequence,
): ):
""" """
amax calcuation abstract amax calculation abstract
""" """
del amax_scope, transpose_batch_sequence del amax_scope, transpose_batch_sequence
...@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive): ...@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive):
flatten_axis, flatten_axis,
): ):
""" """
amax calcuation implementation amax calculation implementation
""" """
assert RHTAmaxCalculationPrimitive.inner_primitive is not None assert RHTAmaxCalculationPrimitive.inner_primitive is not None
( (
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MoE Permutation API for JAX.
This module provides high-level token dispatch and combine operations for
Mixture of Experts (MoE) models with proper automatic differentiation support.
Token Dispatch (Permute):
- Forward: Permute tokens according to routing map (scatter to experts)
- Backward: Unpermute gradients (gather from experts)
Token Combine (Unpermute):
- Forward: Unpermute tokens and merge with weights (gather from experts)
- Backward: Permute gradients (scatter to experts)
"""
from functools import partial
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from transformer_engine.jax.triton_extensions.permutation import (
make_row_id_map,
permute_with_mask_map,
unpermute_with_mask_map,
unpermute_bwd_with_merging_probs,
make_chunk_sort_map,
sort_chunks_by_map,
)
__all__ = [
"token_dispatch",
"token_combine",
"sort_chunks_by_index",
]
def token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
    """Scatter tokens to their assigned experts (MoE permute forward).

    The row_id_map is derived internally from ``routing_map``; callers only
    supply the routing mask and the total number of permuted rows.

    Parameters
    ----------
    inp : jnp.ndarray
        Tokens of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
    routing_map : jnp.ndarray
        0/1 mask of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]; 1 means the token goes to that expert.
    num_out_tokens : int
        Number of rows in the permuted output. Should equal the total count of
        ones in routing_map and is passed explicitly so shapes are static
        under JIT.
    probs : Optional[jnp.ndarray]
        Optional routing probabilities with the same leading shape as
        routing_map; when given, the permuted probabilities are also returned.

    Returns
    -------
    output : jnp.ndarray
        Permuted tokens of shape [num_out_tokens, hidden_size].
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities of shape [num_out_tokens], or None when probs
        was not provided.
    row_id_map : jnp.ndarray
        Map of shape [num_tokens, num_experts * 2 + 1] for token_combine.
    """
    # Argument order of the VJP-wrapped function puts the differentiable
    # inputs (inp, probs) at positions 0 and 2.
    return _token_dispatch(inp, routing_map, probs, num_out_tokens)
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
    """Internal token_dispatch with custom VJP."""
    # Outside of differentiation, run the forward rule and drop the residuals.
    primals, _unused_residuals = _token_dispatch_fwd_rule(
        inp, routing_map, probs, num_out_tokens
    )
    return primals
def _token_dispatch_fwd_rule(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_out_tokens: int,
) -> Tuple[
    Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
    Tuple[jnp.ndarray, int, int, int, bool],
]:
    """Forward pass rule for token_dispatch."""
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
    assert routing_map.ndim in [2, 3], f"routing_map must be 2D or 3D, got {routing_map.ndim}D"
    # Collapse an optional [batch, sequence] leading pair into one token axis.
    if inp.ndim == 3:
        num_tokens = inp.shape[0] * inp.shape[1]
    else:
        num_tokens = inp.shape[0]
    hidden_size = inp.shape[-1]
    num_experts = routing_map.shape[-1]
    # inp and routing_map must describe the same set of tokens.
    if routing_map.ndim == 3:
        routing_num_tokens = routing_map.shape[0] * routing_map.shape[1]
    else:
        routing_num_tokens = routing_map.shape[0]
    assert num_tokens == routing_num_tokens, (
        f"Token count mismatch: inp has {num_tokens} tokens, "
        f"routing_map has {routing_num_tokens} tokens"
    )
    # The row map is always rebuilt here rather than accepted from callers.
    row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
    with_probs = probs is not None
    output, permuted_probs = permute_with_mask_map(
        inp,
        row_id_map,
        probs,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    )
    # with_probs is stashed so the backward knows whether a probs cotangent
    # must be routed back.
    residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs)
    return (output, permuted_probs, row_id_map), residuals
def _token_dispatch_bwd_rule(
    _routing_map: jnp.ndarray,
    _num_out_tokens: int,
    residuals: Tuple[jnp.ndarray, int, int, int, bool],
    g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Backward pass rule for token_dispatch."""
    row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals
    # The cotangent w.r.t. row_id_map (an integer output) is discarded.
    output_grad, permuted_probs_grad, _ = g
    # Dispatch scattered tokens out, so its transpose gathers them back.
    inp_grad, probs_grad = unpermute_with_mask_map(
        output_grad,
        row_id_map,
        None,  # no merging probs in this direction
        permuted_probs_grad if with_probs else None,
        num_tokens,
        num_experts,
        hidden_size,
    )
    if not with_probs:
        probs_grad = None
    return inp_grad, probs_grad


_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
# =============================================================================
# Token Combine (Unpermute) with VJP
# =============================================================================
def token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Gather expert outputs back to original token slots (MoE unpermute).

    Parameters
    ----------
    inp : jnp.ndarray
        Expert outputs of shape [num_out_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Map produced by token_dispatch, shape [num_tokens, num_experts * 2 + 1].
    merging_probs : Optional[jnp.ndarray]
        Per-(token, expert) weights of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]. When given, contributions from different
        experts are weighted before summation; otherwise they are summed
        directly.

    Returns
    -------
    output : jnp.ndarray
        Combined tokens of shape [num_tokens, hidden_size].
    """
    return _token_combine(inp, row_id_map, merging_probs)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
) -> jnp.ndarray:
    """Internal token_combine with custom VJP."""
    # Outside of differentiation, run the forward rule and drop the residuals.
    primal_out, _unused_residuals = _token_combine_fwd_rule(inp, row_id_map, merging_probs)
    return primal_out
def _token_combine_fwd_rule(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]:
    """Forward pass rule for token_combine."""
    # row_id_map is laid out as [num_tokens, num_experts * 2 + 1], so both
    # token and expert counts are recoverable from its shape alone.
    num_tokens, map_width = row_id_map.shape
    num_experts = (map_width - 1) // 2
    num_out_tokens = inp.shape[0]
    hidden_size = inp.shape[-1]
    # Gather rows back per token; no permuted-probs stream in this direction.
    output, _ = unpermute_with_mask_map(
        inp,
        row_id_map,
        merging_probs,
        None,
        num_tokens,
        num_experts,
        hidden_size,
    )
    # inp is saved because the merging-probs backward needs the forward input.
    residuals = (
        row_id_map,
        inp,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    )
    return output, residuals
def _token_combine_bwd_rule(
    row_id_map: jnp.ndarray,
    residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int],
    g: jnp.ndarray,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Backward pass rule for token_combine."""
    # The residual row_id_map shadows the nondiff positional argument.
    (
        row_id_map,
        fwd_input,
        merging_probs,
        num_tokens,
        num_experts,
        hidden_size,
        num_out_tokens,
    ) = residuals
    if merging_probs is None:
        # Unweighted combine: the cotangent is simply permuted back out.
        inp_grad, _ = permute_with_mask_map(
            g,
            row_id_map,
            None,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        return inp_grad, None
    # Weighted combine: dedicated kernel scales by merging_probs and also
    # produces the gradient w.r.t. the merging probabilities.
    inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
        g,
        row_id_map,
        fwd_input,
        merging_probs,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    )
    return inp_grad, merging_probs_grad


_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
# =============================================================================
# Chunk Sort with VJP
# =============================================================================
def sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Reorder chunks of tokens following a chunk-level permutation.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape [batch, sequence, hidden_size] or
        [num_tokens, hidden_size].
    split_sizes : jnp.ndarray
        Sizes of each chunk, shape [num_splits].
    sorted_indices : jnp.ndarray
        Permutation indices for the chunks, shape [num_splits].

    Returns
    -------
    output : jnp.ndarray
        Sorted output tensor of shape [num_tokens, hidden_size].
    row_id_map : jnp.ndarray
        Row ID map that can be used to reverse the sort.
    """
    # Delegate to the custom-VJP wrapper so gradients flow through the sort.
    output, row_id_map = _sort_chunks_by_index(inp, split_sizes, sorted_indices)
    return output, row_id_map
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def _sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Internal sort_chunks_by_index with custom VJP.

    NOTE(review): ``split_sizes``/``sorted_indices`` are jnp arrays but are
    declared in ``nondiff_argnums``; JAX documentation recommends
    ``nondiff_argnums`` only for static (hashable, non-array) arguments, and
    passing traced arrays this way can misbehave under ``jit``. Confirm this
    is intended, or move them into the residuals instead.
    """
    # The primal computation is shared with the fwd rule; residuals are discarded.
    (output, row_id_map), _ = _sort_chunks_by_index_fwd_rule(inp, split_sizes, sorted_indices)
    return output, row_id_map
def _sort_chunks_by_index_fwd_rule(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
    """Forward pass rule for sort_chunks_by_index."""
    # Only flat [num_tokens, hidden] or batched [batch, seq, hidden] inputs are supported.
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
    hidden_size = inp.shape[-1]
    if inp.ndim == 3:
        num_tokens = inp.shape[0] * inp.shape[1]
    else:
        num_tokens = inp.shape[0]
    num_splits = split_sizes.shape[0]
    # Build the chunk-sort map once; the backward pass reuses it via residuals.
    row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, num_tokens, num_splits)
    output, _ = sort_chunks_by_map(
        inp,
        row_id_map,
        None,  # No probs
        num_tokens,
        hidden_size,
        is_forward=True,
    )
    residuals = (row_id_map, num_tokens, hidden_size)
    # Primals are (sorted output, map); the map is also returned to callers.
    return (output, row_id_map), residuals
def _sort_chunks_by_index_bwd_rule(
    _split_sizes: jnp.ndarray,
    _sorted_indices: jnp.ndarray,
    residuals: Tuple[jnp.ndarray, int, int],
    g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
    """Backward pass rule for sort_chunks_by_index: apply the inverse sort."""
    row_id_map, num_tokens, hidden_size = residuals
    # Only the cotangent of the sorted output matters; the cotangent of the
    # integer row_id_map output is discarded.
    out_grad = g[0]
    inp_grad, _ = sort_chunks_by_map(
        out_grad,
        row_id_map,
        None,
        num_tokens,
        hidden_size,
        is_forward=False,  # reverse direction undoes the forward sort
    )
    # One gradient for the single differentiable argument (inp).
    return (inp_grad,)
# Wire the forward/backward rules into the custom-VJP wrapper.
_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
    @staticmethod
    def lowering(ctx, x, **kwargs):
        return triton_call_lowering(ctx, my_kernel, x, ...)

# Use permutation functions
from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map
"""

from .utils import *
from .permutation import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for permutation in MOE using Triton kernels."""
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import triton
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
from transformer_engine.common.triton.permutation import (
_row_id_map_pass_1_kernel,
_row_id_map_pass_2_kernel,
_row_id_map_pass_3_kernel,
_permute_kernel,
_unpermute_kernel,
_unpermute_bwd_with_merging_probs_kernel,
_make_chunk_sort_map_kernel,
_sort_chunks_by_map_kernel,
)
from .utils import triton_call_lowering
__all__ = [
"make_row_id_map",
"permute_with_mask_map",
"unpermute_with_mask_map",
"unpermute_bwd_with_merging_probs",
"make_chunk_sort_map",
"sort_chunks_by_map",
]
# Triton block size used by the row_id_map generation passes (see make_row_id_map).
DEFAULT_BLOCK_SIZE = 1024
def _get_min_block_size(kernel, default=128):
if hasattr(kernel, "configs"):
return min(config.kwargs.get("BLOCK_SIZE", default) for config in kernel.configs)
return default
class RowIdMapPass1Primitive(BasePrimitive):
    """
    Pass 1 of row_id_map generation: block cumsum.

    For each expert, compute the cumsum of every block_size tokens. Each
    (expert, token-block) grid cell produces one partial sum in the
    workspace tensor, which pass 2 combines into a global cumsum.
    """

    name = "te_row_id_map_pass1_triton"
    multiple_results = True
    impl_static_args = (1, 2, 3)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 1.

        Returns avals for:
        - row_id_map: int32 ``[num_tokens, num_experts * 2 + 1]``
        - workspace: int32 ``[num_experts, cdiv(num_tokens, block_size)]``
        """
        assert routing_map_aval.shape == (
            num_tokens,
            num_experts,
        ), f"routing_map shape mismatch: expected ({num_tokens}, {num_experts})"
        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        # The workspace holds one partial sum per (expert, token-block) cell of
        # the launch grid, so it must be sized with the actual block_size.
        # Hard-coding DEFAULT_BLOCK_SIZE here would be wrong whenever a caller
        # passes a different block_size than the default.
        workspace_shape = (
            num_experts,
            triton.cdiv(num_tokens, block_size),
        )
        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
            jax.core.ShapedArray(workspace_shape, jnp.int32),
        )

    @staticmethod
    def impl(routing_map, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass1Primitive.inner_primitive is not None
        return RowIdMapPass1Primitive.inner_primitive.bind(
            routing_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering."""
        # Strides assume contiguous row-major tensors.
        routing_stride_token = num_experts
        routing_stride_expert = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        # One program per (expert, token-block).
        grid = (num_experts, triton.cdiv(num_tokens, block_size))
        # All scalar arguments must be passed as constexprs
        return triton_call_lowering(
            ctx,
            _row_id_map_pass_1_kernel,
            routing_map,  # Only tensor arguments here
            grid=grid,
            constexprs={
                "num_tokens": num_tokens,
                "stride_routing_map_token": routing_stride_token,
                "stride_routing_map_expert": routing_stride_expert,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(RowIdMapPass1Primitive)
class RowIdMapPass2Primitive(BasePrimitive):
    """
    Pass 2 of row_id_map generation: cumsum all and process the mask.

    Combines the per-block partial sums produced by pass 1 into global
    offsets, updating both row_id_map and workspace in place.
    """

    name = "te_row_id_map_pass2_triton"
    multiple_results = True
    impl_static_args = (2, 3, 4)  # num_tokens, num_experts, block_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 2 (in-place operation).

        Both outputs alias the inputs (see ``input_output_aliases`` in
        ``lowering``), so their avals simply echo the input avals. Echoing
        also keeps the workspace shape consistent with whatever block_size
        pass 1 was launched with, instead of hard-coding DEFAULT_BLOCK_SIZE.
        """
        del num_tokens, num_experts, block_size
        return (
            jax.core.ShapedArray(row_id_map_aval.shape, row_id_map_aval.dtype),
            jax.core.ShapedArray(workspace_aval.shape, workspace_aval.dtype),
        )

    @staticmethod
    def impl(row_id_map, workspace, num_tokens, num_experts, block_size):
        """Forward to inner primitive."""
        assert RowIdMapPass2Primitive.inner_primitive is not None
        return RowIdMapPass2Primitive.inner_primitive.bind(
            row_id_map,
            workspace,
            num_tokens=num_tokens,
            num_experts=num_experts,
            block_size=block_size,
        )

    @staticmethod
    def lowering(ctx, row_id_map, workspace, *, num_tokens, num_experts, block_size):
        """MLIR lowering using triton_call_lowering."""
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        # One program per (expert, token-block), matching pass 1's grid.
        grid = (num_experts, triton.cdiv(num_tokens, block_size))
        # Triton loads need a power-of-2 width covering the whole workspace.
        workspace_load_width = triton.next_power_of_2(
            num_experts * triton.cdiv(num_tokens, block_size)
        )
        return triton_call_lowering(
            ctx,
            _row_id_map_pass_2_kernel,
            row_id_map,
            workspace,
            grid=grid,
            input_output_aliases={0: 0, 1: 1},  # update both tensors in place
            constexprs={
                "num_tokens": num_tokens,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "WORKSPACE_LOAD_WIDTH": workspace_load_width,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(RowIdMapPass2Primitive)
class RowIdMapPass3Primitive(BasePrimitive):
    """
    Pass 3 of row_id_map generation: make the row_id_map from sparse to dense structure.
    """

    name = "te_row_id_map_pass3_triton"
    multiple_results = False
    impl_static_args = (1, 2)  # num_tokens, num_experts
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(row_id_map_aval, *, num_tokens, num_experts):
        """Shape/dtype inference for pass 3 (in-place operation).

        The output aliases the input (see ``input_output_aliases`` in
        ``lowering``), so the aval mirrors the row_id_map layout of
        ``[num_tokens, num_experts * 2 + 1]`` int32.
        """
        del row_id_map_aval
        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        return jax.core.ShapedArray(row_id_map_shape, jnp.int32)

    @staticmethod
    def impl(row_id_map, num_tokens, num_experts):
        """Forward to inner primitive."""
        assert RowIdMapPass3Primitive.inner_primitive is not None
        return RowIdMapPass3Primitive.inner_primitive.bind(
            row_id_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
        )

    @staticmethod
    def lowering(ctx, row_id_map, *, num_tokens, num_experts):
        """MLIR lowering using triton_call_lowering."""
        # Strides assume a contiguous row-major row_id_map.
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        # One program per token; each program densifies that token's row.
        grid = (num_tokens,)
        # Triton loads need a power-of-2 width covering all experts.
        load_size = triton.next_power_of_2(num_experts)
        return triton_call_lowering(
            ctx,
            _row_id_map_pass_3_kernel,
            row_id_map,
            grid=grid,
            input_output_aliases={0: 0},  # operate on row_id_map in place
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "num_experts": num_experts,
                "LOAD_SIZE": load_size,
            },
        )


register_primitive(RowIdMapPass3Primitive)
class PermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Permute the input tensor based on the row_id_map.

    Produces an output of shape ``[num_out_tokens, hidden_size]`` and,
    optionally, the probabilities permuted alongside the tokens.
    """

    name = "te_permute_with_mask_map_triton"
    multiple_results = True
    # scale and permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False)
    # but they need to be in the signature for the kernel call
    impl_static_args = (
        5,
        6,
        7,
        8,
        9,
    )  # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        probs_aval,
        scale_aval,  # dummy, same shape as inp
        permuted_scale_aval,  # dummy, same shape as inp
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """Shape/dtype inference for permute.

        Returns avals for the permuted output ``[num_out_tokens, hidden_size]``
        and for the permuted probs ``[num_out_tokens]`` (a zero-size
        placeholder when ``with_probs`` is False).
        """
        del row_id_map_aval, scale_aval, permuted_scale_aval
        del num_tokens, num_experts
        output_shape = (num_out_tokens, hidden_size)
        output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
        if with_probs:
            permuted_probs_aval = jax.core.ShapedArray((num_out_tokens,), probs_aval.dtype)
        else:
            # Zero-size placeholder keeps the number of results fixed
            # without allocating real memory.
            permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)
        return output_aval, permuted_probs_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """Forward to inner primitive."""
        assert PermuteWithMaskMapPrimitive.inner_primitive is not None
        return PermuteWithMaskMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        probs,
        scale,
        permuted_scale,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
        with_probs,
    ):
        """MLIR lowering using triton_call_lowering."""
        # Output shape is encoded in the abstract rule; not needed here.
        del num_out_tokens
        # Strides assume contiguous row-major tensors.
        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        permuted_probs_stride_token = 1
        if with_probs:
            # Check if probs is 2D [num_tokens, num_experts] or 1D [num_tokens]
            probs_aval = ctx.avals_in[2]
            if len(probs_aval.shape) > 1:
                probs_stride_token = num_experts
                probs_stride_expert = 1
            else:
                probs_stride_token = 1
                probs_stride_expert = 1
        else:
            # probs is a dummy tensor; strides are presumably never read when
            # PERMUTE_PROBS is False (kernel source not visible here — verify).
            probs_stride_token = 0
            probs_stride_expert = 0
        # Grid function equivalent: (num_tokens, cdiv(hidden_size, BLOCK_SIZE))
        # Use minimum BLOCK_SIZE from autotune configs to ensure grid covers all elements
        block_size = _get_min_block_size(_permute_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))
        return triton_call_lowering(
            ctx,
            _permute_kernel,
            inp,
            row_id_map,
            probs,
            scale,
            permuted_scale,
            grid=grid,
            constexprs={
                # Scale-related entries are placeholders: PERMUTE_SCALE is
                # False, so the scale tensors should never be dereferenced.
                "scale_hidden_dim": 0,
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_probs_token": probs_stride_token,
                "stride_probs_expert": probs_stride_expert,
                "stride_scale_token": hidden_size,
                "stride_scale_hidden": 1,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "stride_permuted_scale_token": hidden_size,
                "stride_permuted_scale_hidden": 1,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                "PERMUTE_PROBS": with_probs,
                "PERMUTE_SCALE": False,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(PermuteWithMaskMapPrimitive)
class UnpermuteWithMaskMapPrimitive(BasePrimitive):
    """
    Unpermute the input tensor based on the row_id_map.

    Produces an output of shape ``[num_tokens, hidden_size]``; merging_probs,
    when enabled, are used as weights while reducing the unpermuted tokens.
    """

    name = "te_unpermute_with_mask_map_triton"
    multiple_results = True
    impl_static_args = (
        4,
        5,
        6,
        7,
        8,
    )  # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval,
        row_id_map_aval,
        merging_probs_aval,
        permuted_probs_aval,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Shape/dtype inference for unpermute.

        Returns avals for the unpermuted output ``[num_tokens, hidden_size]``
        and for the unpermuted probs ``[num_tokens, num_experts]`` (a
        zero-size placeholder when ``with_probs`` is False).
        """
        del row_id_map_aval, merging_probs_aval, with_merging_probs
        output_shape = (num_tokens, hidden_size)
        output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
        if with_probs:
            unpermuted_probs_shape = (num_tokens, num_experts)
            unpermuted_probs_aval = jax.core.ShapedArray(
                unpermuted_probs_shape, permuted_probs_aval.dtype
            )
        else:
            # Zero-size placeholder keeps the number of results fixed.
            unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)
        return output_aval, unpermuted_probs_aval

    @staticmethod
    def impl(
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """Forward to inner primitive."""
        assert UnpermuteWithMaskMapPrimitive.inner_primitive is not None
        return UnpermuteWithMaskMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            num_tokens=num_tokens,
            num_experts=num_experts,
            hidden_size=hidden_size,
            with_merging_probs=with_merging_probs,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(
        ctx,
        inp,
        row_id_map,
        merging_probs,
        permuted_probs,
        *,
        num_tokens,
        num_experts,
        hidden_size,
        with_merging_probs,
        with_probs,
    ):
        """MLIR lowering using triton_call_lowering."""
        # Compute strides (contiguous row-major tensors).
        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        if with_merging_probs:
            merging_probs_stride_token = num_experts
            merging_probs_stride_expert = 1
        else:
            # merging_probs is a dummy tensor; zero strides are presumably
            # never read when WITH_MERGING_PROBS is False — verify in kernel.
            merging_probs_stride_token = 0
            merging_probs_stride_expert = 0
        permuted_probs_stride_token = 1
        unpermuted_probs_stride_token = num_experts
        unpermuted_probs_stride_expert = 1
        # Grid - use minimum BLOCK_SIZE from autotune configs
        block_size = _get_min_block_size(_unpermute_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))
        return triton_call_lowering(
            ctx,
            _unpermute_kernel,
            inp,
            row_id_map,
            merging_probs,
            permuted_probs,
            grid=grid,
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_merging_probs_token": merging_probs_stride_token,
                "stride_merging_probs_expert": merging_probs_stride_expert,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "stride_unpermuted_probs_token": unpermuted_probs_stride_token,
                "stride_unpermuted_probs_expert": unpermuted_probs_stride_expert,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                # Power-of-2 load width covering all experts.
                "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
                "WITH_MERGING_PROBS": with_merging_probs,
                "PERMUTE_PROBS": with_probs,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(UnpermuteWithMaskMapPrimitive)
class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
    """
    Backward pass for unpermute with merging probabilities.

    This kernel computes gradients for both the input and merging_probs.
    """

    name = "te_unpermute_bwd_with_merging_probs_triton"
    multiple_results = True
    impl_static_args = (4, 5, 6, 7)  # num_tokens, num_experts, num_out_tokens, hidden_size
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        fwd_output_grad_aval,
        fwd_input_aval,
        merging_probs_aval,
        row_id_map_aval,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Shape/dtype inference for unpermute backward with merging probs."""
        del fwd_input_aval, row_id_map_aval
        # fwd_input_grad has same shape as fwd_input
        fwd_input_grad_shape = (num_out_tokens, hidden_size)
        fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype)
        # merging_probs_grad has same shape as merging_probs
        merging_probs_grad_shape = (num_tokens, num_experts)
        merging_probs_grad_aval = jax.core.ShapedArray(
            merging_probs_grad_shape, merging_probs_aval.dtype
        )
        return fwd_input_grad_aval, merging_probs_grad_aval

    @staticmethod
    def impl(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """Forward to inner primitive."""
        assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None
        return UnpermuteBwdWithMergingProbsPrimitive.inner_primitive.bind(
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            num_tokens=num_tokens,
            num_experts=num_experts,
            num_out_tokens=num_out_tokens,
            hidden_size=hidden_size,
        )

    @staticmethod
    def lowering(
        ctx,
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        *,
        num_tokens,
        num_experts,
        num_out_tokens,
        hidden_size,
    ):
        """MLIR lowering using triton_call_lowering."""
        # Output shape is encoded in the abstract rule; not needed here.
        del num_out_tokens
        # Compute strides (contiguous row-major tensors).
        row_id_stride_token = num_experts * 2 + 1
        row_id_stride_expert = 1
        fwd_output_grad_stride_token = hidden_size
        fwd_output_grad_stride_hidden = 1
        fwd_input_grad_stride_token = hidden_size
        fwd_input_grad_stride_hidden = 1
        fwd_input_stride_token = hidden_size
        fwd_input_stride_hidden = 1
        merging_probs_stride_token = num_experts
        merging_probs_stride_expert = 1
        merging_probs_grad_stride_token = num_experts
        merging_probs_grad_stride_expert = 1
        # Grid - one program per token
        grid = (num_tokens,)
        # Get min block size from autotune configs for consistency
        block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)
        # Pass inputs in kernel argument order: fwd_output_grad, fwd_input, merging_probs, row_id_map
        return triton_call_lowering(
            ctx,
            _unpermute_bwd_with_merging_probs_kernel,
            fwd_output_grad,
            fwd_input,
            merging_probs,
            row_id_map,
            grid=grid,
            constexprs={
                "stride_row_id_map_token": row_id_stride_token,
                "stride_row_id_map_expert": row_id_stride_expert,
                "stride_fwd_output_grad_token": fwd_output_grad_stride_token,
                "stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden,
                "stride_fwd_input_grad_token": fwd_input_grad_stride_token,
                "stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden,
                "stride_fwd_input_token": fwd_input_stride_token,
                "stride_fwd_input_hidden": fwd_input_stride_hidden,
                "stride_merging_probs_token": merging_probs_stride_token,
                "stride_merging_probs_expert": merging_probs_stride_expert,
                "stride_merging_probs_grad_token": merging_probs_grad_stride_token,
                "stride_merging_probs_grad_expert": merging_probs_grad_stride_expert,
                "num_experts": num_experts,
                "hidden_size": hidden_size,
                # Power-of-2 load width covering all experts.
                "PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(UnpermuteBwdWithMergingProbsPrimitive)
def unpermute_bwd_with_merging_probs(
    fwd_output_grad: jnp.ndarray,
    row_id_map: jnp.ndarray,
    fwd_input: jnp.ndarray,
    merging_probs: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Backward pass for unpermute with merging probabilities.

    This computes gradients for both the input tensor and merging_probs.

    Parameters
    ----------
    fwd_output_grad : jnp.ndarray
        Gradient of the forward output of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
    fwd_input : jnp.ndarray
        The input tensor from the forward pass of shape `[num_out_tokens, hidden_size]`.
    merging_probs : jnp.ndarray
        The merging probabilities of shape `[num_tokens, num_experts]`.
    num_tokens : int
        Number of tokens in the unpermuted tensor.
    num_experts : int
        Number of experts.
    num_out_tokens : int
        Number of tokens in the permuted tensor.
    hidden_size : int
        Hidden size.

    Returns
    -------
    fwd_input_grad : jnp.ndarray
        Gradient w.r.t. the input tensor of shape `[num_out_tokens, hidden_size]`.
    merging_probs_grad : jnp.ndarray
        Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
    """
    # Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map
    # NOTE: this intentionally differs from this function's parameter order,
    # which keeps row_id_map second for consistency with the other wrappers.
    return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
        fwd_output_grad,
        fwd_input,
        merging_probs,
        row_id_map,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
    )
class MakeChunkSortMapPrimitive(BasePrimitive):
    """
    Make a row_id_map for chunk sort.
    """

    name = "te_make_chunk_sort_map_triton"
    multiple_results = False
    impl_static_args = (2, 3)  # num_tokens, num_splits
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(split_sizes_aval, sorted_indices_aval, *, num_tokens, num_splits):
        """Shape/dtype inference: an int32 map of shape ``[num_tokens]``."""
        del sorted_indices_aval
        assert split_sizes_aval.shape == (num_splits,)
        return jax.core.ShapedArray((num_tokens,), jnp.int32)

    @staticmethod
    def impl(split_sizes, sorted_indices, num_tokens, num_splits):
        """Forward to inner primitive."""
        assert MakeChunkSortMapPrimitive.inner_primitive is not None
        return MakeChunkSortMapPrimitive.inner_primitive.bind(
            split_sizes,
            sorted_indices,
            num_tokens=num_tokens,
            num_splits=num_splits,
        )

    @staticmethod
    def lowering(ctx, split_sizes, sorted_indices, *, num_tokens, num_splits):
        """MLIR lowering using triton_call_lowering."""
        # One program per token.
        grid = (num_tokens,)
        return triton_call_lowering(
            ctx,
            _make_chunk_sort_map_kernel,
            split_sizes,
            sorted_indices,
            grid=grid,
            constexprs={
                "num_splits": num_splits,
                # Power-of-2 load width covering all split entries.
                "IDX_LOAD_WIDTH": triton.next_power_of_2(num_splits),
            },
        )


register_primitive(MakeChunkSortMapPrimitive)
class SortChunksByMapPrimitive(BasePrimitive):
    """
    Sort chunks with row_id_map.
    """

    name = "te_sort_chunks_by_map_triton"
    multiple_results = True
    impl_static_args = (3, 4, 5, 6)  # num_tokens, hidden_size, is_forward, with_probs
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs
    ):
        """Shape/dtype inference.

        The sort is a row permutation, so the output keeps the input shape
        ``[num_tokens, hidden_size]``. The second result is the permuted probs
        (a zero-size placeholder when ``with_probs`` is False).
        """
        del row_id_map_aval, is_forward
        output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)
        if with_probs:
            permuted_probs_aval = jax.core.ShapedArray((num_tokens,), probs_aval.dtype)
        else:
            # Zero-size placeholder keeps the number of results fixed.
            permuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)
        return output_aval, permuted_probs_aval

    @staticmethod
    def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs):
        """Forward to inner primitive."""
        assert SortChunksByMapPrimitive.inner_primitive is not None
        return SortChunksByMapPrimitive.inner_primitive.bind(
            inp,
            row_id_map,
            probs,
            num_tokens=num_tokens,
            hidden_size=hidden_size,
            is_forward=is_forward,
            with_probs=with_probs,
        )

    @staticmethod
    def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs):
        """MLIR lowering using triton_call_lowering."""
        # Compute strides (contiguous row-major tensors).
        inp_stride_token = hidden_size
        inp_stride_hidden = 1
        output_stride_token = hidden_size
        output_stride_hidden = 1
        probs_stride_token = 1
        permuted_probs_stride_token = 1
        # Grid - use minimum BLOCK_SIZE from autotune configs
        block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
        grid = (num_tokens, triton.cdiv(hidden_size, block_size))
        return triton_call_lowering(
            ctx,
            _sort_chunks_by_map_kernel,
            inp,
            row_id_map,
            probs,
            grid=grid,
            constexprs={
                "stride_input_token": inp_stride_token,
                "stride_input_hidden": inp_stride_hidden,
                "stride_output_token": output_stride_token,
                "stride_output_hidden": output_stride_hidden,
                "stride_probs_token": probs_stride_token,
                "stride_permuted_probs_token": permuted_probs_stride_token,
                "hidden_size": hidden_size,
                "PERMUTE_PROBS": with_probs,
                # FORWARD selects the sort direction; False applies the inverse.
                "FORWARD": is_forward,
                "BLOCK_SIZE": block_size,
            },
        )


register_primitive(SortChunksByMapPrimitive)
def make_row_id_map(
    routing_map: jnp.ndarray,
    num_tokens: int,
    num_experts: int,
) -> jnp.ndarray:
    """
    Prepare the row_id_map for the permutation.

    This function chains 3 Triton kernel passes together.

    Parameters
    ----------
    routing_map : jnp.ndarray
        Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
        which experts are routed to which tokens. The values in it: 1 means the token is routed to
        this expert and 0 means not.
    num_tokens : int
        Number of tokens in the input tensor.
    num_experts : int
        Number of experts in the input tensor.

    Returns
    -------
    row_id_map : jnp.ndarray
        The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
        For each token, the last item is the number of experts that are routed (n_routed).
        The first n_routed items are the destination row indices in the permuted tokens.
        The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
        to the first n_routed row indices above.
    """
    block_size = DEFAULT_BLOCK_SIZE
    # Pass 1: Block cumsum (partial per-block sums go into workspace_tensor)
    row_id_map_pass1, workspace_tensor = RowIdMapPass1Primitive.outer_primitive.bind(
        routing_map,
        num_tokens=num_tokens,
        num_experts=num_experts,
        block_size=block_size,
    )
    # Pass 2: Cumsum all and process the mask
    row_id_map_pass2, _ = RowIdMapPass2Primitive.outer_primitive.bind(
        row_id_map_pass1,
        workspace_tensor,
        num_tokens=num_tokens,
        num_experts=num_experts,
        block_size=block_size,
    )
    # Initialize columns [num_experts:] to -1 since Pass 1/2 only wrote to [0:num_experts]
    # Reference implementation expects -1 for invalid entries
    # (functional .at[].set() creates an updated copy, as JAX arrays are immutable)
    row_id_map = row_id_map_pass2.at[:, num_experts:].set(-1)
    # Pass 3: Make the row_id_map from sparse to dense structure
    row_id_map = RowIdMapPass3Primitive.outer_primitive.bind(
        row_id_map,
        num_tokens=num_tokens,
        num_experts=num_experts,
    )
    return row_id_map
def permute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Scatter tokens into the permuted layout described by ``row_id_map``.

    Parameters
    ----------
    inp : jnp.ndarray
        Tokens to permute, shape ``[num_tokens, hidden_size]``.
    row_id_map : jnp.ndarray
        Token-to-expert map of shape ``[num_tokens, num_experts * 2 + 1]``
        (see ``make_row_id_map``).
    probs : Optional[jnp.ndarray]
        Probabilities of the input tokens; permuted alongside them when given.
    num_tokens : int
        Number of rows in ``inp``.
    num_experts : int
        Number of experts.
    num_out_tokens : int
        Number of rows in the permuted output.
    hidden_size : int
        Hidden dimension of ``inp``.

    Returns
    -------
    output : jnp.ndarray
        Permuted tokens, shape ``[num_out_tokens, hidden_size]``.
    permuted_probs : Optional[jnp.ndarray]
        Permuted probabilities when ``probs`` was given, otherwise None.
    """
    has_probs = probs is not None
    # The primitive requires a concrete tensor for probs even when unused.
    probs_arg = probs if has_probs else jnp.zeros((0,), dtype=inp.dtype)
    # scale/permuted_scale are required by the kernel signature but ignored
    # (PERMUTE_SCALE=False); alias the input instead of allocating anything.
    output, permuted_probs = PermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs_arg,
        inp,
        inp,
        num_tokens=num_tokens,
        num_experts=num_experts,
        num_out_tokens=num_out_tokens,
        hidden_size=hidden_size,
        with_probs=has_probs,
    )
    return output, (permuted_probs if has_probs else None)
def unpermute_with_mask_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray],
    permuted_probs: Optional[jnp.ndarray],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Gather permuted tokens back to their original rows using ``row_id_map``.

    Parameters
    ----------
    inp : jnp.ndarray
        Permuted tokens, shape ``[num_out_tokens, hidden_size]``.
    row_id_map : jnp.ndarray
        Token-to-expert map of shape ``[num_tokens, num_experts * 2 + 1]``.
    merging_probs : Optional[jnp.ndarray]
        Optional weights applied while reducing each token's expert copies.
    permuted_probs : Optional[jnp.ndarray]
        Optional permuted probabilities, unpermuted alongside the tokens.
    num_tokens : int
        Number of rows in the unpermuted result.
    num_experts : int
        Number of experts.
    hidden_size : int
        Hidden dimension.

    Returns
    -------
    output : jnp.ndarray
        Unpermuted tokens, shape ``[num_tokens, hidden_size]``.
    unpermuted_probs : Optional[jnp.ndarray]
        Unpermuted probabilities when ``permuted_probs`` was given, else None.
    """
    has_merging_probs = merging_probs is not None
    has_probs = permuted_probs is not None
    # Substitute zero-size placeholders for any missing optional tensors;
    # the primitive requires concrete arrays in every slot.
    merging_arg = merging_probs if has_merging_probs else jnp.zeros((0,), dtype=inp.dtype)
    probs_arg = permuted_probs if has_probs else jnp.zeros((0,), dtype=inp.dtype)
    output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        merging_arg,
        probs_arg,
        num_tokens=num_tokens,
        num_experts=num_experts,
        hidden_size=hidden_size,
        with_merging_probs=has_merging_probs,
        with_probs=has_probs,
    )
    return output, (unpermuted_probs if has_probs else None)
def make_chunk_sort_map(
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
    num_tokens: int,
    num_splits: int,
) -> jnp.ndarray:
    """
    Build the per-token row map used to sort chunks of tokens.

    Parameters
    ----------
    split_sizes : jnp.ndarray
        Sizes of the chunks, shape ``[num_splits,]``.
    sorted_indices : jnp.ndarray
        Target order of the chunks, shape ``[num_splits,]``.
    num_tokens : int
        Total number of tokens across all chunks.
    num_splits : int
        Number of entries in ``split_sizes`` and ``sorted_indices``.

    Returns
    -------
    row_id_map : jnp.ndarray
        Row map for chunk sorting, shape ``[num_tokens,]``.
    """
    row_id_map = MakeChunkSortMapPrimitive.outer_primitive.bind(
        split_sizes,
        sorted_indices,
        num_tokens=num_tokens,
        num_splits=num_splits,
    )
    return row_id_map
def sort_chunks_by_map(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    probs: Optional[jnp.ndarray],
    num_tokens: int,
    hidden_size: int,
    is_forward: bool,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """
    Reorder token chunks according to a precomputed row_id_map.

    Parameters
    ----------
    inp : jnp.ndarray
        Input tensor of shape `[num_tokens, hidden_size]`.
    row_id_map : jnp.ndarray
        Per-token destination map of shape `[num_tokens,]`.
    probs : Optional[jnp.ndarray]
        Optional per-token probabilities; reordered alongside ``inp``
        when provided.
    num_tokens : int
        Number of tokens in ``inp``.
    hidden_size : int
        Hidden dimension of ``inp``.
    is_forward : bool
        Whether this sort is the forward or the backward direction.

    Returns
    -------
    output : jnp.ndarray
        Reordered tensor of shape `[num_tokens, hidden_size]`.
    permuted_probs : Optional[jnp.ndarray]
        Reordered probabilities when ``probs`` was given, else None.
    """
    probs_provided = probs is not None
    if probs is None:
        # The primitive takes a fixed argument list; substitute a zero-size
        # placeholder so no real memory is spent when probs are absent.
        probs = jnp.zeros((0,), dtype=inp.dtype)
    output, sorted_probs = SortChunksByMapPrimitive.outer_primitive.bind(
        inp,
        row_id_map,
        probs,
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        is_forward=is_forward,
        with_probs=probs_provided,
    )
    # Mirror the input contract on the way out: no probs in, no probs out.
    return output, (sorted_probs if probs_provided else None)
...@@ -176,7 +176,9 @@ def triton_call_lowering( ...@@ -176,7 +176,9 @@ def triton_call_lowering(
*array_args: Input arrays (from ctx) *array_args: Input arrays (from ctx)
grid: Grid dimensions (int or tuple) grid: Grid dimensions (int or tuple)
input_output_aliases: Mapping of input to output aliases input_output_aliases: Mapping of input to output aliases
constexprs: Compile-time constants for the kernel constexprs: Compile-time constants for the kernel. This includes both
tl.constexpr arguments AND scalar runtime arguments (like
num_tokens, strides) that are known at JAX trace time.
Returns: Returns:
MLIR lowering result MLIR lowering result
...@@ -189,8 +191,10 @@ def triton_call_lowering( ...@@ -189,8 +191,10 @@ def triton_call_lowering(
return triton_call_lowering( return triton_call_lowering(
ctx, my_kernel, x, ctx, my_kernel, x,
grid=(triton.cdiv(n, block_size),), grid=(triton.cdiv(n, block_size),),
n_elements=n, constexprs={
BLOCK_SIZE=block_size "n_elements": n, # scalar arg (not tl.constexpr in kernel)
"BLOCK_SIZE": block_size, # tl.constexpr arg
},
) )
""" """
# Get compute capability using gpu_triton # Get compute capability using gpu_triton
...@@ -203,9 +207,13 @@ def triton_call_lowering( ...@@ -203,9 +207,13 @@ def triton_call_lowering(
else: else:
arg_names = kernel_fn.arg_names arg_names = kernel_fn.arg_names
# Build signature for inputs + outputs # Build signature for tensor arguments only (inputs + outputs)
# Scalar arguments should be passed via constexprs and will be
# specialized into the kernel at compile time
all_avals = list(ctx.avals_in) + list(ctx.avals_out) all_avals = list(ctx.avals_in) + list(ctx.avals_out)
signature = {arg_names[i]: get_triton_dtype(aval) for i, aval in enumerate(all_avals)} constexpr_names = set(constexprs.keys()) if constexprs else set()
tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}
# Normalize grid to 3D # Normalize grid to 3D
if isinstance(grid, int): if isinstance(grid, int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment