Unverified Commit 46c6ef31 authored by Teddy Do, committed by GitHub

JAX primitives for permutation on a single GPU (#2473)



* branch off of initial permutation jax-triton PR
Signed-off-by: tdophung <tdophung@nvidia.com>

* Set 0 as the size of dummy tensors to reduce memory usage.
Signed-off-by: tdophung <tdophung@nvidia.com>

* Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation
Signed-off-by: tdophung <tdophung@nvidia.com>

* Implement primitives, wrapper, test for wrapper; edit triton binding to accommodate scalars
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change implementation of VJP functions to match the correct pattern. Deduce some static scalar args from shapes of inputs. Accept B, S instead of num_tokens. Change test to use value_and_grad to test VJP functions properly
Signed-off-by: tdophung <tdophung@nvidia.com>

* formatting
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix pylint
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix test to compare against the correct reference implementation, relax one tolerance for grad compare, fix lint the right way
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix test_permutation to use value_and_grad for reference impl, tighten tols, and add unpermute with probs for token combine bwd rule
Signed-off-by: tdophung <tdophung@nvidia.com>

* add file forgotten in the previous commit
Signed-off-by: tdophung <tdophung@nvidia.com>

* format
Signed-off-by: tdophung <tdophung@nvidia.com>

* merge with_probs into without_probs
Signed-off-by: tdophung <tdophung@nvidia.com>

* add asserts and fix lint
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dbaa02d0
...@@ -73,7 +73,7 @@ class AmaxCalculationPrimitive(BasePrimitive):
         transpose_batch_sequence,
     ):
         """
-        amax calcuation abstract
+        amax calculation abstract
         """
         del amax_scope, transpose_batch_sequence
...@@ -251,7 +251,7 @@ class RHTAmaxCalculationPrimitive(BasePrimitive):
         flatten_axis,
     ):
         """
-        amax calcuation implementation
+        amax calculation implementation
         """
         assert RHTAmaxCalculationPrimitive.inner_primitive is not None
         (
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MoE Permutation API for JAX.
This module provides high-level token dispatch and combine operations for
Mixture of Experts (MoE) models with proper automatic differentiation support.
Token Dispatch (Permute):
- Forward: Permute tokens according to routing map (scatter to experts)
- Backward: Unpermute gradients (gather from experts)
Token Combine (Unpermute):
- Forward: Unpermute tokens and merge with weights (gather from experts)
- Backward: Permute gradients (scatter to experts)
"""
from functools import partial
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from transformer_engine.jax.triton_extensions.permutation import (
make_row_id_map,
permute_with_mask_map,
unpermute_with_mask_map,
unpermute_bwd_with_merging_probs,
make_chunk_sort_map,
sort_chunks_by_map,
)
__all__ = [
"token_dispatch",
"token_combine",
"sort_chunks_by_index",
]
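
# Typical round trip (a sketch with assumed shapes; `run_experts` stands in
# for the user's expert computation and is not part of this module):
#
#   x = jnp.zeros((batch, seq, hidden), jnp.bfloat16)
#   routing_map = ...                      # [batch, seq, num_experts] 0/1 mask
#   num_out_tokens = batch * seq * top_k   # one output row per routed copy
#   dispatched, permuted_probs, row_id_map = token_dispatch(
#       x, routing_map, num_out_tokens, probs=probs
#   )
#   expert_out = run_experts(dispatched)   # [num_out_tokens, hidden]
#   y = token_combine(expert_out, row_id_map, merging_probs=probs)
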
def token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
num_out_tokens: int,
probs: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
"""
Dispatch tokens to experts based on routing map.
This is the forward pass of the MoE permutation. Tokens are scattered
to their designated experts according to the routing map. The row_id_map
is computed internally from the routing_map.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
routing_map : jnp.ndarray
Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
Values: 1 = routed, 0 = not routed.
    num_out_tokens : int
        The number of output tokens after permutation. This should equal
        routing_map.sum() (the total number of routed token copies) and must
        be provided as a static value for JIT compatibility.
probs : Optional[jnp.ndarray]
Optional routing probabilities of shape [batch, sequence, num_experts] or
[num_tokens, num_experts]. If provided, permuted_probs will be returned.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities of shape [num_out_tokens], or None if probs was not provided.
row_id_map : jnp.ndarray
Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
"""
return _token_dispatch(inp, routing_map, probs, num_out_tokens)
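
# Gradient sketch: _token_dispatch carries a custom VJP, so dispatch can sit
# inside a jitted, differentiated loss (assumed shapes; the scalar loss is
# illustrative only):
#
#   def loss_fn(x, probs):
#       out, p_out, _ = token_dispatch(x, routing_map, num_out_tokens, probs)
#       return jnp.sum(out) + jnp.sum(p_out)
#
#   val, (dx, dprobs) = jax.value_and_grad(loss_fn, argnums=(0, 1))(x, probs)
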
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
"""Internal token_dispatch with custom VJP."""
(output, permuted_probs, row_id_map), _ = _token_dispatch_fwd_rule(
inp, routing_map, probs, num_out_tokens
)
return output, permuted_probs, row_id_map
def _token_dispatch_fwd_rule(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
) -> Tuple[
Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
Tuple[jnp.ndarray, int, int, int, bool],
]:
"""Forward pass rule for token_dispatch."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
assert routing_map.ndim in [2, 3], f"routing_map must be 2D or 3D, got {routing_map.ndim}D"
# Infer dimensions from input shapes
num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
hidden_size = inp.shape[-1]
num_experts = routing_map.shape[-1]
# Verify consistency between inp and routing_map
routing_num_tokens = (
routing_map.shape[0] * routing_map.shape[1]
if routing_map.ndim == 3
else routing_map.shape[0]
)
assert num_tokens == routing_num_tokens, (
f"Token count mismatch: inp has {num_tokens} tokens, "
f"routing_map has {routing_num_tokens} tokens"
)
# Always compute row_id_map internally from routing_map
row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
with_probs = probs is not None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# Return (primals, residuals)
# Include with_probs flag to know how to handle backward pass
residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs)
return (output, permuted_probs, row_id_map), residuals
def _token_dispatch_bwd_rule(
_routing_map: jnp.ndarray,
_num_out_tokens: int,
residuals: Tuple[jnp.ndarray, int, int, int, bool],
g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch."""
row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _ = g # Ignore row_id_map gradient
# Backward: unpermute gradients (gather from experts back to tokens)
inp_grad, probs_grad = unpermute_with_mask_map(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
num_tokens,
num_experts,
hidden_size,
)
return inp_grad, probs_grad if with_probs else None
_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
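
# Wiring note (a reading aid, not new behavior): routing_map (argnum 1) and
# num_out_tokens (argnum 3) are declared non-differentiable above, so JAX
# passes them to the backward rule first, followed by the residuals saved in
# the forward rule and then the output cotangents.
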
# =============================================================================
# Token Combine (Unpermute) with VJP
# =============================================================================
def token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""
Combine tokens from experts back to original token positions.
This is the forward pass of MoE unpermutation. Tokens are gathered from
experts and merged (optionally weighted by merging_probs).
Parameters
----------
inp : jnp.ndarray
Input tensor from experts of shape [num_out_tokens, hidden_size].
row_id_map : jnp.ndarray
Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
merging_probs : Optional[jnp.ndarray]
Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
If provided, tokens from different experts are weighted-summed.
If None, tokens are summed directly.
Returns
-------
output : jnp.ndarray
Combined output tensor of shape [num_tokens, hidden_size].
"""
return _token_combine(inp, row_id_map, merging_probs)
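
# Combine modes in miniature (illustrative values): assuming one token is
# routed to experts {0, 2} whose outputs for that token are y0 and y2, with
# merging probabilities 0.7 and 0.3 at those slots:
#
#   token_combine(y, row_id_map)                       # row -> y0 + y2
#   token_combine(y, row_id_map, merging_probs=probs)  # row -> 0.7*y0 + 0.3*y2
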
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
) -> jnp.ndarray:
"""Internal token_combine with custom VJP."""
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs)
return output
def _token_combine_fwd_rule(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]:
"""Forward pass rule for token_combine."""
# Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
num_tokens = row_id_map.shape[0]
num_experts = (row_id_map.shape[1] - 1) // 2
hidden_size = inp.shape[-1]
num_out_tokens = inp.shape[0]
# Call triton extension
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Return (primal, residuals)
# Include inp in residuals for backward with merging_probs
residuals = (
row_id_map,
inp,
merging_probs,
num_tokens,
num_experts,
hidden_size,
num_out_tokens,
)
return output, residuals
def _token_combine_bwd_rule(
    _row_id_map: jnp.ndarray,
    residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int],
    g: jnp.ndarray,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Backward pass rule for token_combine."""
    # The nondiff row_id_map is also saved in the residuals; use that copy
    # instead of shadowing the parameter.
    (
        row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
hidden_size,
num_out_tokens,
) = residuals
output_grad = g
with_merging_probs = merging_probs is not None
if with_merging_probs:
# Use specialized backward kernel that properly scales by merging_probs
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
else:
# Simple case: just permute gradients back
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
merging_probs_grad = None
return inp_grad, merging_probs_grad
_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
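
# Backward-path sketch (assumed shapes): with merging_probs supplied, the rule
# above yields gradients for both the expert outputs and the merging weights,
# which a test can exercise via value_and_grad:
#
#   def combine_loss(expert_out, merging_probs):
#       return jnp.sum(token_combine(expert_out, row_id_map, merging_probs))
#
#   val, (d_out, d_probs) = jax.value_and_grad(
#       combine_loss, argnums=(0, 1)
#   )(expert_out, merging_probs)
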
# =============================================================================
# Chunk Sort with VJP
# =============================================================================
def sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Sort chunks of tokens according to sorted indices.
Parameters
----------
inp : jnp.ndarray
Input tensor of shape [batch, sequence, hidden_size] or [num_tokens, hidden_size].
split_sizes : jnp.ndarray
Sizes of each chunk of shape [num_splits].
sorted_indices : jnp.ndarray
Permutation indices for chunks of shape [num_splits].
Returns
-------
output : jnp.ndarray
Sorted output tensor of shape [num_tokens, hidden_size].
row_id_map : jnp.ndarray
Row ID map for reversing the sort.
"""
return _sort_chunks_by_index(inp, split_sizes, sorted_indices)
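
# Semantics sketch with concrete numbers (assuming sorted_indices gives the
# order in which chunks appear in the output): with split_sizes = [2, 3, 1]
# the six rows form chunks [a0 a1 | b0 b1 b2 | c0], and sorted_indices =
# [2, 0, 1] would yield [c0 | a0 a1 | b0 b1 b2]:
#
#   x = jnp.arange(6.0)[:, None]  # [6, 1], rows 0..5
#   out, rmap = sort_chunks_by_index(
#       x, jnp.array([2, 3, 1]), jnp.array([2, 0, 1])
#   )
#   # expected: out[:, 0] == [5, 0, 1, 2, 3, 4]
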
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def _sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Internal sort_chunks_by_index with custom VJP."""
(output, row_id_map), _ = _sort_chunks_by_index_fwd_rule(inp, split_sizes, sorted_indices)
return output, row_id_map
def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
# Infer dimensions from input shape
num_tokens = inp.shape[0] * inp.shape[1] if inp.ndim == 3 else inp.shape[0]
hidden_size = inp.shape[-1]
num_splits = split_sizes.shape[0]
row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, num_tokens, num_splits)
output, _ = sort_chunks_by_map(
inp,
row_id_map,
None, # No probs
num_tokens,
hidden_size,
is_forward=True,
)
# Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size)
return (output, row_id_map), residuals
def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray,
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals
output_grad, _ = g
# Backward: reverse the sort
inp_grad, _ = sort_chunks_by_map(
output_grad,
row_id_map,
None,
num_tokens,
hidden_size,
is_forward=False,
)
return (inp_grad,)
_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
...@@ -20,6 +20,10 @@ Usage:
     @staticmethod
     def lowering(ctx, x, **kwargs):
         return triton_call_lowering(ctx, my_kernel, x, ...)
+
+    # Use permutation functions
+    from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map
 """

 from .utils import *
+from .permutation import *
...@@ -176,7 +176,9 @@ def triton_call_lowering(
         *array_args: Input arrays (from ctx)
         grid: Grid dimensions (int or tuple)
         input_output_aliases: Mapping of input to output aliases
-        constexprs: Compile-time constants for the kernel
+        constexprs: Compile-time constants for the kernel. This includes both
+            tl.constexpr arguments AND scalar runtime arguments (like
+            num_tokens, strides) that are known at JAX trace time.

     Returns:
         MLIR lowering result
...@@ -189,8 +191,10 @@ def triton_call_lowering(
         return triton_call_lowering(
             ctx, my_kernel, x,
             grid=(triton.cdiv(n, block_size),),
-            n_elements=n,
-            BLOCK_SIZE=block_size
+            constexprs={
+                "n_elements": n,  # scalar arg (not tl.constexpr in kernel)
+                "BLOCK_SIZE": block_size,  # tl.constexpr arg
+            },
         )
     """
     # Get compute capability using gpu_triton
...@@ -203,9 +207,13 @@ def triton_call_lowering(
     else:
         arg_names = kernel_fn.arg_names

-    # Build signature for inputs + outputs
+    # Build signature for tensor arguments only (inputs + outputs)
+    # Scalar arguments should be passed via constexprs and will be
+    # specialized into the kernel at compile time
     all_avals = list(ctx.avals_in) + list(ctx.avals_out)
-    signature = {arg_names[i]: get_triton_dtype(aval) for i, aval in enumerate(all_avals)}
+    constexpr_names = set(constexprs.keys()) if constexprs else set()
+    tensor_arg_names = [n for n in arg_names if n not in constexpr_names]
+    signature = {n: get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)}

     # Normalize grid to 3D
     if isinstance(grid, int):
...
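
# Sketch of the calling pattern this change enables (hypothetical kernel and
# names; `n_elements` is a plain scalar parameter of the kernel, not a
# tl.constexpr, but it is known at trace time so it travels via `constexprs`
# and is specialized out of the tensor signature):
#
#   @triton.jit
#   def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
#       pid = tl.program_id(0)
#       offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
#       mask = offs < n_elements
#       tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)
#
#   # In the primitive's lowering, only x and the output contribute to the
#   # Triton signature; the scalars ride in constexprs:
#   #   constexprs={"n_elements": n, "BLOCK_SIZE": 128}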