Unverified Commit a652730f authored by Teddy Do, committed by GitHub

[JAX] Custom partitioning for Permutation primitives (#2591)



* initial impl, not tested
Signed-off-by: tdophung <tdophung@nvidia.com>

* consolidate different unpermute primitives with with_pad and with_merging_probs booleans. Implement partitioning for all permutation primitives
Signed-off-by: tdophung <tdophung@nvidia.com>

* Add distributed test for non-padding permutation
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix issues in distributed test for padding permutation. Make common kernel zero-initialize output permuted scales, permuted probs, and output tokens
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert zeroing in triton common kernel as it is a race condition. Instead, add an extra input buffer (aliased with the output) to the inner primitive of permutation on the JAX side to pass in zero-initialized buffers created with jnp.zeros
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix utils to handle input-output aliasing in autotuned kernels
Signed-off-by: tdophung <tdophung@nvidia.com>

* Clean up comments, and add more comments explaining input-output aliasing in utils
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and greptile comment
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix issues that lint fixing introduced
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6a34b657
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for distributed/sharded execution of MoE permutation primitives.
Testing Strategy:
=================
MoE permutation is data-dependent - the destination index for each token depends
on how many tokens before it are routed to the same expert. This means:
1. We CANNOT compare sharded output against global reference directly
2. Instead, we verify that each GPU's LOCAL output is correct according to its
LOCAL routing (which produces LOCAL row_id_map with LOCAL indices)
For data-parallel MoE without expert parallelism:
- Each GPU has ALL experts replicated
- Each GPU processes a subset of tokens (sharded on token/batch dimension)
- Each GPU computes its own local row_id_map from its local routing_map slice
- Each GPU's output is local and doesn't need to match global output
These tests verify:
1. Local token_dispatch: sharded input -> local row_id_map -> local permute (forward + backward)
2. Local roundtrip: dispatch + combine recovers original input (forward + backward)
"""
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
# High-level API with VJP support
from transformer_engine.jax.permutation import (
token_dispatch,
token_combine,
)
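# Quick orientation (illustrative sketch, shapes only): a dispatch/combine
# roundtrip with merging probs that sum to 1 per token recovers the input, e.g.
#   rm = jnp.eye(4, dtype=jnp.int32)  # 4 tokens, 4 experts, topk=1
#   out, _, rid_map, _, _ = token_dispatch(x, rm, 4)  # x: (4, hidden)
#   x_back = token_combine(out, rid_map, rm.astype(x.dtype))  # x_back == x
# The tests below verify exactly this, per shard, under custom partitioning.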
# Reference implementations from test_permutation.py
from test_permutation import (
reference_make_row_id_map,
_reference_permute_impl,
_reference_unpermute_impl,
reference_token_combine,
)
# Dispatch/combine test cases: (num_tokens, num_experts, hidden_size, topk)
# topk = number of experts each token is routed to
# Includes small, medium-large, and largest stress test cases.
ALL_DISPATCH_COMBINE_CASES = [
(128, 4, 64, 2),
(4096, 32, 1280, 2),
(4096, 256, 4096, 6),
]
DISPATCH_COMBINE_CASES = {
"L0": ALL_DISPATCH_COMBINE_CASES[0:1],
"L2": ALL_DISPATCH_COMBINE_CASES,
}
# Dispatch/combine with padding test cases: (num_tokens, num_experts, hidden_size, topk, align_size)
ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 4, 64, 2, 8),
(4096, 32, 1280, 2, 128),
(4096, 256, 4096, 6, 16),
]
DISPATCH_COMBINE_PADDING_CASES = {
"L0": ALL_DISPATCH_COMBINE_PADDING_CASES[0:1],
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}
# Dtypes for testing
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
"L0": [jnp.float32],
"L2": ALL_DTYPES,
}
class TestDistributedPermutation:
"""Test distributed/sharded execution of MoE permutation primitives.
These tests validate that custom partitioning produces correct LOCAL results
when inputs are sharded across multiple devices.
Key insight: With data-parallel MoE, each GPU independently processes its
local tokens. The row_id_map is generated locally and contains LOCAL indices.
We verify correctness by comparing each shard's output against the reference
implementation run on that shard's local data.
"""
@staticmethod
def compute_padded_output_size(
num_tokens: int,
num_experts: int,
topk: int,
align_size: int,
num_dp_devices: int,
) -> int:
"""Compute global_num_out_tokens for distributed padding tests.
Each device processes local_num_tokens tokens. We compute the worst-case
padded output size per device, then multiply by num_dp_devices to get
a global size that ensures global / num_dp >= local_worst.
"""
local_num_tokens = num_tokens // num_dp_devices
local_raw_out = local_num_tokens * topk
local_worst = ((local_raw_out + num_experts * (align_size - 1)) // align_size) * align_size
return local_worst * num_dp_devices
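# Worked example (illustrative, assuming num_dp_devices=2) for the smallest
# padding case (128, 4, 64, 2, 8): local_num_tokens = 64, local_raw_out = 128,
# local_worst = ((128 + 4 * 7) // 8) * 8 = (156 // 8) * 8 = 152,
# so global_num_out_tokens = 152 * 2 = 304.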
@staticmethod
def generate_routing_map(
num_tokens: int,
num_experts: int,
topk: int = 2, # Number of experts each token is routed to (max 1s per row).
key: jax.Array | None = None,
):
if key is None:
key = jax.random.PRNGKey(0)
routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32)
for token_idx in range(num_tokens):
key, subkey = jax.random.split(key)
expert_indices = jax.random.choice(subkey, num_experts, shape=(topk,), replace=False)
routing_map = routing_map.at[token_idx, expert_indices].set(1)
return routing_map
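# A vectorized alternative sketch (not used by these tests; same shapes assumed):
# sample per-token uniform noise and take top-k indices, avoiding the Python loop.
#
# def generate_routing_map_vectorized(num_tokens, num_experts, topk, key):
#     noise = jax.random.uniform(key, (num_tokens, num_experts))
#     _, idx = jax.lax.top_k(noise, topk)  # topk distinct experts per token
#     rows = jnp.arange(num_tokens)[:, None]
#     return jnp.zeros((num_tokens, num_experts), jnp.int32).at[rows, idx].set(1)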
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_token_dispatch(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_shardy,
):
"""
Test token_dispatch with sharded inputs.
Verifies that sharded execution produces the same result as chunk-wise
reference execution. The sharded primitive:
1. Receives global num_out_tokens (partition function divides it)
2. Each GPU operates on its local shard independently
3. Results are gathered (concatenated) across GPUs
Output ordering: [GPU0_expert0, GPU0_expert1, ... | GPU1_expert0, ...]
The reference processes each chunk independently and concatenates,
matching the sharded execution's output ordering.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Shard on token (batch) dimension
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
# Compute num_out_tokens as concrete values
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
# Local num_out_tokens is used for reference implementation
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
global_num_out_tokens = num_tokens * topk
local_num_tokens = num_tokens // num_dp_devices
local_num_out_tokens = local_num_tokens * topk
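# e.g. (illustrative) num_tokens=128, topk=2 on 2 DP devices:
# global_num_out_tokens = 256, local_num_tokens = 64, local_num_out_tokens = 128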
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
probs_sharding = NamedSharding(mesh, sharded_pspec)
# Shard the inputs
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
probs_sharded = jax.device_put(probs, probs_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def target_dispatch(x, rm, p):
# Pass global num_out_tokens - partition function divides it
out, perm_probs, rid_map, _, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p
)
return out, perm_probs, rid_map
# Reference: process each GPU's shard independently, then concatenate
# This matches how the sharded primitive operates:
# - Each GPU processes its local shard
# - Results are gathered (concatenated) across GPUs
# Output ordering: [GPU0_exp0, GPU0_exp1, ... | GPU1_exp0, GPU1_exp1, ...]
inp_shards = jnp.reshape(inp, (num_dp_devices, local_num_tokens, hidden_size))
routing_shards = jnp.reshape(
routing_map, (num_dp_devices, local_num_tokens, num_experts)
)
probs_shards = jnp.reshape(probs, (num_dp_devices, local_num_tokens, num_experts))
ref_outputs = []
ref_perm_probs_list = []
ref_rid_maps = []
for i in range(num_dp_devices):
shard_rid_map = reference_make_row_id_map(routing_shards[i])
shard_out, shard_perm_probs = _reference_permute_impl(
inp_shards[i], shard_rid_map, probs_shards[i], local_num_out_tokens
)
ref_outputs.append(shard_out)
ref_perm_probs_list.append(shard_perm_probs)
ref_rid_maps.append(shard_rid_map)
# Concatenate like all_gather would
ref_out = jnp.concatenate(ref_outputs, axis=0)
ref_perm_probs = jnp.concatenate(ref_perm_probs_list, axis=0)
ref_rid_map = jnp.concatenate(ref_rid_maps, axis=0)
# Run target on sharded inputs
target_out, target_perm_probs, target_rid_map = target_dispatch(
inp_sharded, routing_sharded, probs_sharded
)
# Compare forward outputs
assert_allclose(jax.device_get(target_out), ref_out, dtype=dtype)
assert_allclose(jax.device_get(target_perm_probs), ref_perm_probs, dtype=dtype)
# Verify row_id_map n_routed column matches routing_map sum
target_rid_map_np = jax.device_get(target_rid_map)
assert jnp.array_equal(
target_rid_map_np[:, -1], ref_rid_map[:, -1]
), "n_routed column mismatch"
# Sanity checks
target_out_np = jax.device_get(target_out)
target_perm_probs_np = jax.device_get(target_perm_probs)
assert not np.any(np.isnan(target_out_np)), "Output contains NaN"
assert not np.any(np.isnan(target_perm_probs_np)), "Permuted probs contain NaN"
assert np.all(target_perm_probs_np >= 0), "Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def target_loss(x, rm, p):
out, perm_probs, _, _, _ = token_dispatch(x, rm, global_num_out_tokens, probs=p)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
# Reference loss: process chunks independently and sum
def ref_chunk_loss(inp_chunk, routing_chunk, probs_chunk):
rid_map = reference_make_row_id_map(routing_chunk)
out, perm_probs = _reference_permute_impl(
inp_chunk, rid_map, probs_chunk, local_num_out_tokens
)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
target_grad_fn = jax.jit(jax.grad(target_loss, argnums=(0, 2)))
ref_chunk_grad_fn = jax.jit(jax.grad(ref_chunk_loss, argnums=(0, 2)))
target_inp_grad, target_probs_grad = target_grad_fn(
inp_sharded, routing_sharded, probs_sharded
)
# Compute reference gradients per chunk, then concatenate
ref_inp_grads = []
ref_probs_grads = []
for i in range(num_dp_devices):
chunk_inp_grad, chunk_probs_grad = ref_chunk_grad_fn(
inp_shards[i], routing_shards[i], probs_shards[i]
)
ref_inp_grads.append(chunk_inp_grad)
ref_probs_grads.append(chunk_probs_grad)
ref_inp_grad = jnp.concatenate(ref_inp_grads, axis=0)
ref_probs_grad = jnp.concatenate(ref_probs_grads, axis=0)
assert_allclose(jax.device_get(target_inp_grad), ref_inp_grad, dtype=dtype)
assert_allclose(jax.device_get(target_probs_grad), ref_probs_grad, dtype=dtype)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk",
DISPATCH_COMBINE_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_roundtrip(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_shardy,
):
"""
Test roundtrip: token_dispatch followed by token_combine with sharded inputs.
Each GPU:
1. Gets a shard of the input and routing_map
2. Performs local dispatch (permute)
3. Performs local combine (unpermute)
4. With uniform merging probs, should recover original input
Tests both forward pass and backward pass (gradient should be 2*x).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key = jax.random.split(key, 2)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
# Uniform merging probs for perfect roundtrip
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
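# Why the roundtrip is exact (sketch): dispatch copies each token x into its
# topk expert rows; combine then returns sum_e(p_e * x) = x * sum_e(p_e) = x,
# since the uniform merging probs above sum to 1 over each token's routed experts.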
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
# Compute num_out_tokens as concrete value
# Global num_out_tokens is passed to token_dispatch (partition function divides it)
global_num_out_tokens = num_tokens * topk
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
merging_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def roundtrip(x, rm, mprobs):
dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens)
return token_combine(dispatched, rid_map, mprobs)
roundtrip_out = roundtrip(inp_sharded, routing_sharded, merging_sharded)
# Should recover original input
assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def roundtrip_loss(x, rm, mprobs):
dispatched, _, rid_map, _, _ = token_dispatch(x, rm, global_num_out_tokens)
combined = token_combine(dispatched, rid_map, mprobs)
return jnp.sum(combined**2)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn = jax.jit(jax.grad(roundtrip_loss, argnums=0))
computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded)
expected_grad = 2.0 * inp_sharded
assert_allclose(
jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk,align_size",
DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_token_dispatch_with_padding(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
align_size,
dtype,
use_shardy,
):
"""
Test token_dispatch with padding using sharded inputs.
Tests both forward pass (output values) and backward pass (gradients).
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate global inputs
key, inp_key, prob_key = jax.random.split(key, 3)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
probs = jax.random.uniform(
prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.1, maxval=1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens = self.compute_padded_output_size(
num_tokens, num_experts, topk, align_size, num_dp_devices
)
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
probs_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
probs_sharded = jax.device_put(probs, probs_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def dispatch_with_padding(x, rm, p):
out, perm_probs, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p, align_size=align_size
)
return out, perm_probs, rid_map, pad_offsets
out, perm_probs, rid_map, pad_offsets = dispatch_with_padding(
inp_sharded, routing_sharded, probs_sharded
)
# Sanity checks
out_np = jax.device_get(out)
perm_probs_np = jax.device_get(perm_probs)
assert not np.any(np.isnan(out_np)), "Output contains NaN"
assert not np.any(np.isnan(perm_probs_np)), "Permuted probs contain NaN"
assert np.all(perm_probs_np >= 0), "Permuted probs contain negative values"
# ================================================================
# Backward pass test (gradients)
# ================================================================
def loss_with_padding(x, rm, p):
out, perm_probs, _, _, _ = token_dispatch(
x, rm, global_num_out_tokens, probs=p, align_size=align_size
)
return jnp.sum(out**2) + jnp.sum(perm_probs**2)
grad_fn = jax.jit(jax.grad(loss_with_padding, argnums=(0, 2)))
inp_grad, probs_grad = grad_fn(inp_sharded, routing_sharded, probs_sharded)
# Gradients should not contain NaN
assert not np.any(np.isnan(jax.device_get(inp_grad))), "Input gradient contains NaN"
assert not np.any(np.isnan(jax.device_get(probs_grad))), "Probs gradient contains NaN"
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest_parametrize_wrapper(
"num_tokens,num_experts,hidden_size,topk,align_size",
DISPATCH_COMBINE_PADDING_CASES,
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_shardy", [False, True])
def test_local_roundtrip_with_padding(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
num_tokens,
num_experts,
hidden_size,
topk,
align_size,
dtype,
use_shardy,
):
"""
Test roundtrip with padding/alignment using sharded inputs.
With uniform merging probs, should recover original input.
Tests both forward pass and backward pass.
"""
jax.config.update("jax_use_shardy_partitioner", use_shardy)
key = jax.random.PRNGKey(42)
# Generate inputs
key, inp_key = jax.random.split(key, 2)
inp = jax.random.uniform(
inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0
)
routing_map = self.generate_routing_map(num_tokens, num_experts, topk, key)
# Uniform merging probs
uniform_merging_probs = routing_map.astype(dtype) / jnp.maximum(
jnp.sum(routing_map, axis=1, keepdims=True), 1.0
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
dp_axis = mesh_resource.dp_resource
sharded_pspec = PartitionSpec(dp_axis, None)
num_dp_devices = mesh.shape[dp_axis] if dp_axis else 1
# For padding + sharding, we need to account for per-shard padding overhead.
# Each shard needs E*(A-1) extra space for worst-case padding.
# Compute global_num_out_tokens such that global / num_dp >= local_worst.
global_num_out_tokens = self.compute_padded_output_size(
num_tokens, num_experts, topk, align_size, num_dp_devices
)
with mesh:
inp_sharding = NamedSharding(mesh, sharded_pspec)
routing_sharding = NamedSharding(mesh, sharded_pspec)
merging_sharding = NamedSharding(mesh, sharded_pspec)
inp_sharded = jax.device_put(inp, inp_sharding)
routing_sharded = jax.device_put(routing_map, routing_sharding)
merging_sharded = jax.device_put(uniform_merging_probs, merging_sharding)
# ================================================================
# Forward pass test
# ================================================================
@jax.jit
def roundtrip_with_padding(x, rm, mprobs):
dispatched, _, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, align_size=align_size
)
return token_combine(dispatched, rid_map, mprobs, pad_offsets)
roundtrip_out = roundtrip_with_padding(inp_sharded, routing_sharded, merging_sharded)
# Should recover original input
assert_allclose(jax.device_get(roundtrip_out), jax.device_get(inp_sharded), dtype=dtype)
# ================================================================
# Backward pass test (gradients)
# ================================================================
def roundtrip_loss_with_padding(x, rm, mprobs):
dispatched, _, rid_map, pad_offsets, _ = token_dispatch(
x, rm, global_num_out_tokens, align_size=align_size
)
combined = token_combine(dispatched, rid_map, mprobs, pad_offsets)
return jnp.sum(combined**2)
# With uniform merging probs, roundtrip is identity, so gradient should be 2*x
grad_fn = jax.jit(jax.grad(roundtrip_loss_with_padding, argnums=0))
computed_grad = grad_fn(inp_sharded, routing_sharded, merging_sharded)
expected_grad = 2.0 * inp_sharded
assert_allclose(
jax.device_get(computed_grad), jax.device_get(expected_grad), dtype=dtype
)
......@@ -201,8 +201,15 @@ def _permute_kernel(
scale_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# Pre-allocated output buffers for JAX input_output_aliases.
# These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
# In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
output_buf_ptr, # pylint: disable=unused-argument
permuted_probs_buf_ptr, # pylint: disable=unused-argument
# sizes
scale_hidden_dim,
num_tokens, # pylint: disable=unused-argument
num_out_tokens, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -228,12 +235,17 @@ def _permute_kernel(
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller
# to ensure padding positions contain zeros.
# PyTorch: allocate the output buffers with torch.zeros()
# JAX: pass pre-zeroed buffers, aliased to the outputs via input_output_aliases
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
......@@ -306,6 +318,10 @@ def _unpermute_kernel(
merging_probs_ptr,
permuted_probs_ptr,
pad_offsets_ptr,
# Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern)
# These are unused in the unpermute kernel but maintain consistency with the permute kernel.
output_buf_ptr, # pylint: disable=unused-argument
unpermuted_probs_buf_ptr, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......
......@@ -137,7 +137,7 @@ def token_dispatch(
)
@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
......@@ -240,6 +240,7 @@ def _token_dispatch_fwd_rule(
num_experts,
worst_case_out_tokens,
hidden_size,
align_size=align_size,
)
else:
# No padding
......@@ -268,7 +269,6 @@ def _token_dispatch_fwd_rule(
def _token_dispatch_bwd_rule(
_routing_map: jnp.ndarray,
_num_out_tokens: int,
_worst_case_out_tokens: int,
_align_size: Optional[int],
......@@ -281,8 +281,12 @@ def _token_dispatch_bwd_rule(
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch."""
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch.
Returns gradients for (inp, routing_map, probs).
routing_map gradient is None since it's a discrete routing decision.
"""
row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads
......@@ -309,7 +313,9 @@ def _token_dispatch_bwd_rule(
hidden_size,
)
return inp_grad, probs_grad if with_probs else None
# Return gradients for (inp, routing_map, probs)
# routing_map is non-differentiable (discrete routing), so return None
return inp_grad, None, probs_grad if with_probs else None
_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
......@@ -497,6 +503,8 @@ def _token_combine_bwd_rule(
else:
# Simple case: just permute gradients back
if pad_offsets is not None:
# Note: align_size uses default (128) since buffer sizes are already
# determined from forward pass (stored in residuals as num_out_tokens)
inp_grad, _ = permute_with_mask_map_and_pad(
output_grad,
row_id_map,
......@@ -506,6 +514,7 @@ def _token_combine_bwd_rule(
num_experts,
num_out_tokens,
hidden_size,
align_size=128, # Default, sizes already computed in forward
)
# The permute kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
......
......@@ -8,9 +8,13 @@ from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
import triton
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
from transformer_engine.jax.cpp_extensions.misc import get_padded_spec, NamedSharding
from transformer_engine.jax.sharding import get_mesh_axis_size
from transformer_engine.common.triton.permutation import (
_row_id_map_pass_1_kernel,
_row_id_map_pass_2_kernel,
......@@ -93,7 +97,6 @@ class RowIdMapPass1Primitive(BasePrimitive):
@staticmethod
def lowering(ctx, routing_map, *, num_tokens, num_experts, block_size):
"""MLIR lowering using triton_call_lowering."""
# Compute strides
routing_stride_token = num_experts
routing_stride_expert = 1
row_id_stride_token = num_experts * 2 + 1
......@@ -101,11 +104,10 @@ class RowIdMapPass1Primitive(BasePrimitive):
grid = (num_experts, triton.cdiv(num_tokens, block_size))
# All scalar arguments must be passed as constexprs
return triton_call_lowering(
ctx,
_row_id_map_pass_1_kernel,
routing_map, # Only tensor arguments here
routing_map,
grid=grid,
constexprs={
"num_tokens": num_tokens,
......@@ -117,6 +119,76 @@ class RowIdMapPass1Primitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(
num_tokens, num_experts, block_size, mesh, arg_infos, result_infos
):
"""Infer output sharding from input sharding."""
del num_tokens, num_experts, block_size, result_infos
routing_map_spec = get_padded_spec(arg_infos[0])
# row_id_map has same token dimension sharding as routing_map
# Shape: (num_tokens, num_experts * 2 + 1)
row_id_map_sharding = NamedSharding(
mesh,
PartitionSpec(routing_map_spec[0], None),
desc="RowIdMapPass1.row_id_map_sharding",
)
# Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
desc="RowIdMapPass1.workspace_sharding",
)
return [row_id_map_sharding, workspace_sharding]
@staticmethod
def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos):
"""Row id map 1st pass partition."""
del num_tokens, result_infos
routing_map_spec = get_padded_spec(arg_infos[0])
# Input sharding
arg_shardings = (arg_infos[0].sharding,)
# Output shardings
row_id_map_sharding = NamedSharding(
mesh,
PartitionSpec(routing_map_spec[0], None),
desc="RowIdMapPass1.row_id_map_sharding",
)
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
desc="RowIdMapPass1.workspace_sharding",
)
out_shardings = [row_id_map_sharding, workspace_sharding]
def sharded_impl(routing_map):
# Each shard processes its local tokens
local_num_tokens = routing_map.shape[0]
return RowIdMapPass1Primitive.impl(
routing_map,
num_tokens=local_num_tokens,
num_experts=num_experts,
block_size=block_size,
)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, result_types):
"""Shardy sharding rule for this primitive."""
del num_tokens, num_experts, block_size, mesh, value_types, result_types
prefix = "RowIdMapPass1"
# routing_map shape: (num_tokens, num_experts)
input_spec = (f"{prefix}_tokens", f"{prefix}_experts")
# row_id_map shape: (num_tokens, num_experts * 2 + 1)
# Note: row_id_cols != experts since it's num_experts * 2 + 1
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
# workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks")
return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec))
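# How the rule binds (illustrative, assuming routing_map is sharded ("dp", None)):
# Shardy maps the RowIdMapPass1_tokens factor to "dp", so row_id_map comes out
# ("dp", None); the RowIdMapPass1_ws_blocks factor appears in no operand, so the
# workspace stays replicated along that dimension.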
register_primitive(RowIdMapPass1Primitive)
......@@ -185,6 +257,69 @@ class RowIdMapPass2Primitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(
num_tokens, num_experts, block_size, mesh, arg_infos, result_infos
):
"""Infer output sharding from input sharding."""
del num_tokens, num_experts, block_size, result_infos
row_id_map_spec = get_padded_spec(arg_infos[0])
# Output has same sharding as input (in-place operation)
row_id_map_sharding = NamedSharding(
mesh,
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass2.row_id_map_sharding",
)
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
desc="RowIdMapPass2.workspace_sharding",
)
return [row_id_map_sharding, workspace_sharding]
@staticmethod
def partition(num_tokens, num_experts, block_size, mesh, arg_infos, result_infos):
"""Partition the primitive for distributed execution."""
del num_tokens, result_infos
row_id_map_spec = get_padded_spec(arg_infos[0])
# Input shardings
arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding)
# Output shardings (same as inputs for in-place operation)
row_id_map_sharding = NamedSharding(
mesh,
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass2.row_id_map_sharding",
)
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
desc="RowIdMapPass2.workspace_sharding",
)
out_shardings = [row_id_map_sharding, workspace_sharding]
def sharded_impl(row_id_map, workspace):
local_num_tokens = row_id_map.shape[0]
return RowIdMapPass2Primitive.impl(
row_id_map,
workspace,
num_tokens=local_num_tokens,
num_experts=num_experts,
block_size=block_size,
)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(num_tokens, num_experts, block_size, mesh, value_types, result_types):
"""Shardy sharding rule for this primitive."""
del num_tokens, num_experts, block_size, mesh, value_types, result_types
prefix = "RowIdMapPass2"
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")
return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec))
register_primitive(RowIdMapPass2Primitive)
......@@ -240,6 +375,52 @@ class RowIdMapPass3Primitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(num_tokens, num_experts, mesh, arg_infos, result_infos):
"""Infer output sharding from input sharding."""
del num_tokens, num_experts, result_infos
row_id_map_spec = get_padded_spec(arg_infos[0])
# Output has same sharding as input (in-place operation)
return NamedSharding(
mesh,
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass3.row_id_map_sharding",
)
@staticmethod
def partition(num_tokens, num_experts, mesh, arg_infos, result_infos):
"""Partition the primitive for distributed execution."""
del num_tokens, result_infos
row_id_map_spec = get_padded_spec(arg_infos[0])
# Input sharding
arg_shardings = (arg_infos[0].sharding,)
# Output sharding (same as input for in-place operation)
out_sharding = NamedSharding(
mesh,
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass3.row_id_map_sharding",
)
def sharded_impl(row_id_map):
local_num_tokens = row_id_map.shape[0]
return RowIdMapPass3Primitive.impl(
row_id_map,
num_tokens=local_num_tokens,
num_experts=num_experts,
)
return mesh, sharded_impl, out_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(num_tokens, num_experts, mesh, value_types, result_types):
"""Shardy sharding rule for this primitive."""
del num_tokens, num_experts, mesh, value_types, result_types
prefix = "RowIdMapPass3"
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
return SdyShardingRule((row_id_map_spec,), (row_id_map_spec,))
register_primitive(RowIdMapPass3Primitive)
......@@ -251,8 +432,12 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
name = "te_permute_with_mask_map_triton"
multiple_results = True
# scale, permuted_scale are dummy inputs (not used when PERMUTE_SCALE=False)
# pad_offsets can be shape (0,) when not doing padding, or (num_experts,) when padding
# Outer primitive has 6 tensor inputs: inp, row_id_map, probs, scale, permuted_scale, pad_offsets
# Static args for outer primitive: num_tokens, num_experts, num_out_tokens, hidden_size,
# with_probs, with_pad, align_size
# Inner primitive adds output_buf, permuted_probs_buf
# impl_static_args is for the outer primitive's impl() which has 6 tensor inputs.
impl_static_args = (
6,
7,
......@@ -260,7 +445,8 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
9,
10,
11,
) # num_tokens, num_experts, num_out_tokens, hidden_size, with_probs, with_pad
12,
)
inner_primitive = None
outer_primitive = None
......@@ -272,6 +458,8 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
scale_aval, # dummy, same shape as inp
permuted_scale_aval, # dummy, same shape as inp
pad_offsets_aval,
output_buf_aval=None, # Pre-zeroed output buffer (inner primitive only)
permuted_probs_buf_aval=None, # Pre-zeroed permuted_probs buffer (inner primitive only)
*,
num_tokens,
num_experts,
......@@ -279,10 +467,12 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
hidden_size,
with_probs,
with_pad,
align_size,
):
"""Shape/dtype inference for permute."""
del row_id_map_aval, scale_aval, permuted_scale_aval, pad_offsets_aval
del num_tokens, num_experts, with_pad
del num_tokens, num_experts, with_pad, align_size
del output_buf_aval, permuted_probs_buf_aval # Used for input_output_aliases only
output_shape = (num_out_tokens, hidden_size)
output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
......@@ -308,9 +498,29 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
hidden_size,
with_probs,
with_pad,
align_size, # align_size is only used for sharding, but must be passed since abstract() requires it
):
"""Forward to inner primitive."""
assert PermuteWithMaskMapPrimitive.inner_primitive is not None
# Create pre-zeroed output buffers for the inner primitive.
# When with_pad=True, this ensures padding positions contain zeros.
# These buffers are aliased to the outputs via input_output_aliases in the lowering.
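# (Per the commit history above, zeroing inside the Triton kernel was reverted
# as a race across program instances; the zeros are instead materialized here
# with jnp.zeros and handed to the kernel through input_output_aliases.)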
if with_pad:
output_buf = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
if with_probs:
permuted_probs_buf = jnp.zeros((num_out_tokens,), dtype=probs.dtype)
else:
permuted_probs_buf = jnp.zeros((0,), dtype=inp.dtype)
else:
# When not padding, use empty buffers (kernel ignores them, lowering skips aliasing)
output_buf = jnp.empty((num_out_tokens, hidden_size), dtype=inp.dtype)
if with_probs:
permuted_probs_buf = jnp.empty((num_out_tokens,), dtype=probs.dtype)
else:
permuted_probs_buf = jnp.empty((0,), dtype=inp.dtype)
return PermuteWithMaskMapPrimitive.inner_primitive.bind(
inp,
row_id_map,
......@@ -318,12 +528,15 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
scale,
permuted_scale,
pad_offsets,
output_buf,
permuted_probs_buf,
num_tokens=num_tokens,
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=with_pad,
align_size=align_size,
)
@staticmethod
......@@ -335,6 +548,8 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
scale,
permuted_scale,
pad_offsets,
output_buf, # Pre-zeroed output buffer (for input_output_aliases)
permuted_probs_buf, # Pre-zeroed permuted_probs buffer (for input_output_aliases)
*,
num_tokens,
num_experts,
......@@ -342,9 +557,10 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
hidden_size,
with_probs,
with_pad,
align_size,
):
"""MLIR lowering using triton_call_lowering."""
del num_out_tokens
del align_size
inp_stride_token = hidden_size
inp_stride_hidden = 1
output_stride_token = hidden_size
......@@ -371,6 +587,18 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
block_size = _get_min_block_size(_permute_kernel)
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
# Use input_output_aliases to alias pre-zeroed buffers to outputs.
# This ensures padding positions contain zeros since the kernel only writes valid positions.
# Input indices: 0=inp, 1=row_id_map, 2=probs, 3=scale, 4=permuted_scale,
# 5=pad_offsets, 6=output_buf, 7=permuted_probs_buf
# Output indices: 0=output, 1=permuted_probs
if with_pad:
input_output_aliases = {6: 0}
if with_probs:
input_output_aliases[7] = 1
else:
input_output_aliases = None
return triton_call_lowering(
ctx,
_permute_kernel,
......@@ -380,9 +608,14 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
scale,
permuted_scale,
pad_offsets,
output_buf,
permuted_probs_buf,
grid=grid,
input_output_aliases=input_output_aliases,
constexprs={
"scale_hidden_dim": 0,
"num_tokens": num_tokens,
"num_out_tokens": num_out_tokens,
"stride_row_id_map_token": row_id_stride_token,
"stride_row_id_map_expert": row_id_stride_expert,
"stride_input_token": inp_stride_token,
......@@ -405,24 +638,242 @@ class PermuteWithMaskMapPrimitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
align_size,
mesh,
arg_infos,
result_infos,
):
"""Infer output sharding from input sharding.
For batch-dimension partitioning:
- Input (num_tokens, hidden_size) is sharded on token dim
- Output (num_out_tokens, hidden_size) gets same token dim sharding
- Permuted probs (num_out_tokens,) gets same token dim sharding
"""
del align_size # Used only in partition
del num_tokens, num_experts, num_out_tokens, hidden_size, with_pad, result_infos
inp_spec = get_padded_spec(arg_infos[0])
# Output has same sharding pattern: (token_shard, None)
output_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0], None),
desc="PermuteWithMaskMap.output_sharding",
)
if with_probs:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0]),
desc="PermuteWithMaskMap.permuted_probs_sharding",
)
else:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="PermuteWithMaskMap.permuted_probs_sharding_empty",
)
return [output_sharding, permuted_probs_sharding]
@staticmethod
def partition(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
align_size,
mesh,
arg_infos,
result_infos,
):
"""Partition the primitive for distributed execution.
For batch-dimension partitioning, each GPU processes its local tokens
independently. The row_id_map contains local destination indices,
so no inter-GPU communication is needed.
"""
del num_tokens, result_infos
inp_spec = get_padded_spec(arg_infos[0])
# Input shardings - preserve original shardings
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
# Output shardings
output_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0], None),
desc="PermuteWithMaskMap.output_sharding",
)
if with_probs:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0]),
desc="PermuteWithMaskMap.permuted_probs_sharding",
)
else:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="PermuteWithMaskMap.permuted_probs_sharding_empty",
)
out_shardings = [output_sharding, permuted_probs_sharding]
# Get number of data parallel devices from the batch sharding axis
batch_axis = inp_spec[0]
if batch_axis is not None:
num_dp_devices = get_mesh_axis_size(batch_axis, mesh)
else:
num_dp_devices = 1
def sharded_impl(inp, row_id_map, probs, scale, permuted_scale, pad_offsets):
# Each shard processes its local tokens independently (data parallelism)
local_num_tokens = inp.shape[0]
# =========================================================================
# MoE Permutation Sharding (data parallelism, no expert parallelism)
# =========================================================================
# Each GPU has ALL experts and processes its local batch of tokens.
#
# TopK bounds output: each token goes to at most topK experts, so:
# global_num_out_tokens = global_num_in_tokens * topK
# local_num_out_tokens = local_num_in_tokens * topK
# = global_num_out_tokens / num_dp_devices
#
# E = num_experts
# A = align_size for padding to group gemm size in cuBLAS
# With padding (align_size != 128, which is the default/no-op value):
# The global num_out_tokens passed here is already worst_case_out_tokens.
# We need to recalculate local worst-case from local raw tokens.
# local_raw_out_tokens = global_raw_out_tokens / num_dp_devices
# local_worst_case = ((local_raw_out + E*(A-1)) // A) * A
#
# Local permute produces output ordered by expert: [E0 | E1 | ... | EN]
# where each expert section contains tokens routed to that expert.
#
# Global assembly (if needed) should be done outside this primitive.
# =========================================================================
# Output size calculation
# =========================================================================
# For both padding and non-padding cases, use simple division.
# The global num_out_tokens is already the worst-case buffer size.
#
# IMPORTANT for padding + sharding:
# Padding overhead is per-shard (each shard needs E*(A-1) extra space).
# The caller must account for this by passing a sufficiently large
# global num_out_tokens such that: global_worst / num_dp >= local_worst
# where local_worst = ((local_raw + E*(A-1)) // A) * A
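# Illustrative check (assumed numbers): the (128, 4, 64, 2, 8) case on 2 devices
# passes global num_out_tokens = 304, so local_num_out_tokens = 304 // 2 = 152,
# matching the per-shard worst case ((64 * 2 + 4 * 7) // 8) * 8 = 152.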
local_num_out_tokens = num_out_tokens // num_dp_devices
# Local permute - output stays sharded on this GPU
local_output, local_permuted_probs = PermuteWithMaskMapPrimitive.impl(
inp,
row_id_map,
probs,
scale,
permuted_scale,
pad_offsets,
num_tokens=local_num_tokens,
num_experts=num_experts,
num_out_tokens=local_num_out_tokens,
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=with_pad,
align_size=align_size,
)
return local_output, local_permuted_probs
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_probs,
with_pad,
align_size,
mesh,
value_types,
result_types,
):
"""Shardy sharding rule for this primitive."""
del (
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
align_size,
mesh,
value_types,
result_types,
)
prefix = "PermuteWithMaskMap"
# inp: (num_tokens, hidden_size)
inp_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
# row_id_map: (num_tokens, num_experts * 2 + 1)
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
# probs: (num_tokens, num_experts) or (0,)
probs_spec = (
(f"{prefix}_tokens", f"{prefix}_experts") if with_probs else (f"{prefix}_empty",)
)
# scale: (num_tokens, hidden_size) - same shape as inp, permuted together
scale_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
# permuted_scale: (num_out_tokens, hidden_size) - same shape as output
permuted_scale_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
# pad_offsets: (num_experts,) or (0,) - uses same experts factor as probs
pad_offsets_spec = (f"{prefix}_experts",) if with_pad else (f"{prefix}_pad_empty",)
# output: (num_out_tokens, hidden_size)
output_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
# permuted_probs: (num_out_tokens,) or (0,)
permuted_probs_spec = (f"{prefix}_out_tokens",) if with_probs else (f"{prefix}_empty2",)
return SdyShardingRule(
(
inp_spec,
row_id_map_spec,
probs_spec,
scale_spec,
permuted_scale_spec,
pad_offsets_spec,
),
(output_spec, permuted_probs_spec),
)
register_primitive(PermuteWithMaskMapPrimitive)
class UnpermuteWithMaskMapPrimitive(BasePrimitive):
"""
Unpermute the input tensor based on the row_id_map.
Unpermute the input tensor based on the row_id_map, optionally with fused unpadding.
"""
name = "te_unpermute_with_mask_map_triton"
multiple_results = True
# Outer primitive has 5 tensor inputs: inp, row_id_map, merging_probs, permuted_probs, pad_offsets
# Static args for outer primitive: num_tokens, num_experts, hidden_size,
# with_merging_probs, with_probs, with_unpad
# Inner primitive adds output_buf, unpermuted_probs_buf
impl_static_args = (
5,
6,
7,
8,
9,
) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
10,
)
inner_primitive = None
outer_primitive = None
......@@ -432,16 +883,20 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
row_id_map_aval,
merging_probs_aval,
permuted_probs_aval,
pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False
pad_offsets_aval,
output_buf_aval=None, # Dummy (inner primitive only)
unpermuted_probs_buf_aval=None, # Dummy (inner primitive only)
*,
num_tokens,
num_experts,
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
):
"""Shape/dtype inference for unpermute."""
del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval
del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval, with_unpad
del output_buf_aval, unpermuted_probs_buf_aval
output_shape = (num_tokens, hidden_size)
output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
......@@ -468,20 +923,33 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
):
"""Forward to inner primitive."""
assert UnpermuteWithMaskMapPrimitive.inner_primitive is not None
# Create dummy buffers for kernel signature consistency with _permute_kernel.
# These are not used for pre-zeroing since unpermute writes to all output positions.
output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype)
if with_probs:
unpermuted_probs_buf = jnp.empty((num_tokens, num_experts), dtype=permuted_probs.dtype)
else:
unpermuted_probs_buf = jnp.empty((0,), dtype=inp.dtype)
return UnpermuteWithMaskMapPrimitive.inner_primitive.bind(
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
output_buf,
unpermuted_probs_buf,
num_tokens=num_tokens,
num_experts=num_experts,
hidden_size=hidden_size,
with_merging_probs=with_merging_probs,
with_probs=with_probs,
with_unpad=with_unpad,
)
@staticmethod
......@@ -492,12 +960,15 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
merging_probs,
permuted_probs,
pad_offsets,
output_buf, # Dummy for kernel signature consistency
unpermuted_probs_buf, # Dummy for kernel signature consistency
*,
num_tokens,
num_experts,
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
):
"""MLIR lowering using triton_call_lowering."""
# Compute strides
......@@ -523,7 +994,6 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
block_size = _get_min_block_size(_unpermute_kernel)
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
# Pass all 5 inputs including pad_offsets (even though FUSION_UNPAD=False)
return triton_call_lowering(
ctx,
_unpermute_kernel,
......@@ -532,6 +1002,8 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
merging_probs,
permuted_probs,
pad_offsets,
output_buf,
unpermuted_probs_buf,
grid=grid,
constexprs={
"stride_row_id_map_token": row_id_stride_token,
......@@ -550,174 +1022,170 @@ class UnpermuteWithMaskMapPrimitive(BasePrimitive):
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"WITH_MERGING_PROBS": with_merging_probs,
"PERMUTE_PROBS": with_probs,
"FUSION_UNPAD": False,
"FUSION_UNPAD": with_unpad,
"BLOCK_SIZE": block_size,
},
)
register_primitive(UnpermuteWithMaskMapPrimitive)
class UnpermuteWithMaskMapAndUnpadPrimitive(BasePrimitive):
"""
Unpermute the input tensor based on the row_id_map with fused unpadding.
"""
name = "te_unpermute_with_mask_map_and_unpad_triton"
multiple_results = True
impl_static_args = (
5,
6,
7,
8,
9,
) # num_tokens, num_experts, hidden_size, with_merging_probs, with_probs
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
inp_aval,
row_id_map_aval,
merging_probs_aval,
permuted_probs_aval,
pad_offsets_aval,
*,
def infer_sharding_from_operands(
num_tokens,
num_experts,
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
mesh,
arg_infos,
result_infos,
):
"""Shape/dtype inference for unpermute with unpadding."""
del row_id_map_aval, merging_probs_aval, with_merging_probs, pad_offsets_aval
output_shape = (num_tokens, hidden_size)
output_aval = jax.core.ShapedArray(output_shape, inp_aval.dtype)
"""Infer output sharding from input sharding.
For batch-dimension partitioning:
- row_id_map (num_tokens, num_experts*2+1) is sharded on token dim
- Output (num_tokens, hidden_size) gets same token dim sharding
"""
del num_tokens, num_experts, hidden_size, with_merging_probs, with_unpad, result_infos
row_id_map_spec = get_padded_spec(arg_infos[1])
# Output has same token dimension sharding as row_id_map
output_sharding = NamedSharding(
mesh,
PartitionSpec(row_id_map_spec[0], None),
desc="UnpermuteWithMaskMap.output_sharding",
)
if with_probs:
unpermuted_probs_shape = (num_tokens, num_experts)
unpermuted_probs_aval = jax.core.ShapedArray(
unpermuted_probs_shape, permuted_probs_aval.dtype
unpermuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(row_id_map_spec[0], None),
desc="UnpermuteWithMaskMap.unpermuted_probs_sharding",
)
else:
unpermuted_probs_aval = jax.core.ShapedArray((0,), inp_aval.dtype)
return output_aval, unpermuted_probs_aval
unpermuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty",
)
return [output_sharding, unpermuted_probs_sharding]
@staticmethod
def impl(
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
def partition(
num_tokens,
num_experts,
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
mesh,
arg_infos,
result_infos,
):
"""Forward to inner primitive."""
assert UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive is not None
return UnpermuteWithMaskMapAndUnpadPrimitive.inner_primitive.bind(
"""Partition the primitive for distributed execution."""
del num_tokens, result_infos
row_id_map_spec = get_padded_spec(arg_infos[1])
# Input shardings - preserve original shardings
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
# Output shardings
output_sharding = NamedSharding(
mesh,
PartitionSpec(row_id_map_spec[0], None),
desc="UnpermuteWithMaskMap.output_sharding",
)
if with_probs:
unpermuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(row_id_map_spec[0], None),
desc="UnpermuteWithMaskMap.unpermuted_probs_sharding",
)
else:
unpermuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="UnpermuteWithMaskMap.unpermuted_probs_sharding_empty",
)
out_shardings = [output_sharding, unpermuted_probs_sharding]
def sharded_impl(inp, row_id_map, merging_probs, permuted_probs, pad_offsets):
# Each shard processes its local tokens
local_num_tokens = row_id_map.shape[0]
return UnpermuteWithMaskMapPrimitive.impl(
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
num_tokens=num_tokens,
num_tokens=local_num_tokens,
num_experts=num_experts,
hidden_size=hidden_size,
hidden_size=hidden_size, # hidden_size is not sharded
with_merging_probs=with_merging_probs,
with_probs=with_probs,
with_unpad=with_unpad,
)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def lowering(
ctx,
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
*,
def shardy_sharding_rule(
num_tokens,
num_experts,
hidden_size,
with_merging_probs,
with_probs,
with_unpad,
mesh,
value_types,
result_types,
):
"""MLIR lowering using triton_call_lowering."""
# Compute strides
inp_stride_token = hidden_size
inp_stride_hidden = 1
output_stride_token = hidden_size
output_stride_hidden = 1
row_id_stride_token = num_experts * 2 + 1
row_id_stride_expert = 1
if with_merging_probs:
merging_probs_stride_token = num_experts
merging_probs_stride_expert = 1
else:
merging_probs_stride_token = 0
merging_probs_stride_expert = 0
permuted_probs_stride_token = 1
unpermuted_probs_stride_token = num_experts
unpermuted_probs_stride_expert = 1
# Grid - use minimum BLOCK_SIZE from autotune configs
block_size = _get_min_block_size(_unpermute_kernel)
grid = (num_tokens, triton.cdiv(hidden_size, block_size))
"""Shardy sharding rule for this primitive."""
del num_tokens, num_experts, hidden_size, mesh, value_types, result_types
prefix = "UnpermuteWithMaskMap"
# inp: (num_out_tokens, hidden_size)
inp_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
# row_id_map: (num_tokens, num_experts * 2 + 1)
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
# merging_probs: (num_tokens, num_experts) or (0,)
merging_probs_spec = (
(f"{prefix}_tokens", f"{prefix}_experts")
if with_merging_probs
else (f"{prefix}_empty",)
)
# permuted_probs: (num_out_tokens,) or (0,)
permuted_probs_spec = (f"{prefix}_out_tokens",) if with_probs else (f"{prefix}_empty2",)
# pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise
pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",)
# output: (num_tokens, hidden_size)
output_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
# unpermuted_probs: (num_tokens, num_experts) or (0,)
unpermuted_probs_spec = (
(f"{prefix}_tokens", f"{prefix}_experts") if with_probs else (f"{prefix}_empty3",)
)
return triton_call_lowering(
ctx,
_unpermute_kernel,
inp,
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
grid=grid,
constexprs={
"stride_row_id_map_token": row_id_stride_token,
"stride_row_id_map_expert": row_id_stride_expert,
"stride_input_token": inp_stride_token,
"stride_input_hidden": inp_stride_hidden,
"stride_output_token": output_stride_token,
"stride_output_hidden": output_stride_hidden,
"stride_merging_probs_token": merging_probs_stride_token,
"stride_merging_probs_expert": merging_probs_stride_expert,
"stride_permuted_probs_token": permuted_probs_stride_token,
"stride_unpermuted_probs_token": unpermuted_probs_stride_token,
"stride_unpermuted_probs_expert": unpermuted_probs_stride_expert,
"num_experts": num_experts,
"hidden_size": hidden_size,
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"WITH_MERGING_PROBS": with_merging_probs,
"PERMUTE_PROBS": with_probs,
"FUSION_UNPAD": True,
"BLOCK_SIZE": block_size,
},
return SdyShardingRule(
(inp_spec, row_id_map_spec, merging_probs_spec, permuted_probs_spec, pad_offsets_spec),
(output_spec, unpermuted_probs_spec),
)
register_primitive(UnpermuteWithMaskMapAndUnpadPrimitive)
register_primitive(UnpermuteWithMaskMapPrimitive)
class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
"""
Backward pass for unpermute with merging probabilities.
Backward pass for unpermute with merging probabilities, optionally with fused unpadding.
This kernel computes gradients for both the input and merging_probs.
"""
name = "te_unpermute_bwd_with_merging_probs_triton"
multiple_results = True
impl_static_args = (5, 6, 7, 8) # num_tokens, num_experts, num_out_tokens, hidden_size
impl_static_args = (
5,
6,
7,
8,
9,
) # num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad
inner_primitive = None
outer_primitive = None
......@@ -727,15 +1195,16 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
fwd_input_aval,
merging_probs_aval,
row_id_map_aval,
pad_offsets_aval, # dummy, not used when FUSION_UNPAD=False
pad_offsets_aval,
*,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
):
"""Shape/dtype inference for unpermute backward with merging probs."""
del fwd_input_aval, row_id_map_aval, pad_offsets_aval, with_unpad
# fwd_input_grad has same shape as fwd_input
fwd_input_grad_shape = (num_out_tokens, hidden_size)
@@ -760,6 +1229,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
):
"""Forward to inner primitive."""
assert UnpermuteBwdWithMergingProbsPrimitive.inner_primitive is not None
@@ -773,6 +1243,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_unpad=with_unpad,
)
@staticmethod
@@ -788,6 +1259,7 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
):
"""MLIR lowering using triton_call_lowering."""
del num_out_tokens
@@ -812,7 +1284,6 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
# Get min block size from autotune configs for consistency
block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)
return triton_call_lowering(
ctx,
_unpermute_bwd_with_merging_probs_kernel,
@@ -838,152 +1309,126 @@ class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive):
"num_experts": num_experts,
"hidden_size": hidden_size,
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"FUSION_UNPAD": False,
"FUSION_UNPAD": with_unpad,
"BLOCK_SIZE": block_size,
},
)
@staticmethod
def infer_sharding_from_operands(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
mesh,
arg_infos,
result_infos,
):
"""Shape/dtype inference for unpermute backward with merging probs and unpadding."""
del fwd_input_aval, row_id_map_aval, pad_offsets_aval
# fwd_input_grad has same shape as fwd_input
fwd_input_grad_shape = (num_out_tokens, hidden_size)
fwd_input_grad_aval = jax.core.ShapedArray(fwd_input_grad_shape, fwd_output_grad_aval.dtype)
# merging_probs_grad has same shape as merging_probs
merging_probs_grad_shape = (num_tokens, num_experts)
merging_probs_grad_aval = jax.core.ShapedArray(
merging_probs_grad_shape, merging_probs_aval.dtype
"""Infer output sharding from input sharding."""
del num_tokens, num_experts, num_out_tokens, hidden_size, with_unpad, result_infos
fwd_output_grad_spec = get_padded_spec(arg_infos[0])
merging_probs_spec = get_padded_spec(arg_infos[2])
# fwd_input_grad has same token sharding as fwd_output_grad
fwd_input_grad_sharding = NamedSharding(
mesh,
PartitionSpec(fwd_output_grad_spec[0], None),
desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding",
)
# merging_probs_grad has same sharding as merging_probs
merging_probs_grad_sharding = NamedSharding(
mesh,
PartitionSpec(merging_probs_spec[0], None),
desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding",
)
return [fwd_input_grad_sharding, merging_probs_grad_sharding]
@staticmethod
def partition(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
mesh,
arg_infos,
result_infos,
):
"""Forward to inner primitive."""
assert UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive is not None
return UnpermuteBwdWithMergingProbsAndUnpadPrimitive.inner_primitive.bind(
"""Partition the primitive for distributed execution."""
del num_tokens, num_out_tokens, result_infos
fwd_output_grad_spec = get_padded_spec(arg_infos[0])
merging_probs_spec = get_padded_spec(arg_infos[2])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
fwd_input_grad_sharding = NamedSharding(
mesh,
PartitionSpec(fwd_output_grad_spec[0], None),
desc="UnpermuteBwdWithMergingProbs.fwd_input_grad_sharding",
)
merging_probs_grad_sharding = NamedSharding(
mesh,
PartitionSpec(merging_probs_spec[0], None),
desc="UnpermuteBwdWithMergingProbs.merging_probs_grad_sharding",
)
out_shardings = [fwd_input_grad_sharding, merging_probs_grad_sharding]
def sharded_impl(fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets):
local_num_tokens = row_id_map.shape[0]
# NOTE: local_num_out_tokens is obtained from the actual tensor shape,
# which reflects the data-dependent output size from the forward pass.
local_num_out_tokens = fwd_input.shape[0]
return UnpermuteBwdWithMergingProbsPrimitive.impl(
fwd_output_grad,
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
num_tokens=local_num_tokens,
num_experts=num_experts,
num_out_tokens=local_num_out_tokens,
hidden_size=hidden_size,  # hidden_size is not sharded
with_unpad=with_unpad,
)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
with_unpad,
mesh,
value_types,
result_types,
):
"""MLIR lowering using triton_call_lowering."""
del num_out_tokens
# Compute strides
row_id_stride_token = num_experts * 2 + 1
row_id_stride_expert = 1
fwd_output_grad_stride_token = hidden_size
fwd_output_grad_stride_hidden = 1
fwd_input_grad_stride_token = hidden_size
fwd_input_grad_stride_hidden = 1
fwd_input_stride_token = hidden_size
fwd_input_stride_hidden = 1
merging_probs_stride_token = num_experts
merging_probs_stride_expert = 1
merging_probs_grad_stride_token = num_experts
merging_probs_grad_stride_expert = 1
# Grid - one program per token
grid = (num_tokens,)
# Get min block size from autotune configs for consistency
block_size = _get_min_block_size(_unpermute_bwd_with_merging_probs_kernel)
return triton_call_lowering(
ctx,
_unpermute_bwd_with_merging_probs_kernel,
fwd_output_grad,
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
grid=grid,
constexprs={
"stride_row_id_map_token": row_id_stride_token,
"stride_row_id_map_expert": row_id_stride_expert,
"stride_fwd_output_grad_token": fwd_output_grad_stride_token,
"stride_fwd_output_grad_hidden": fwd_output_grad_stride_hidden,
"stride_fwd_input_grad_token": fwd_input_grad_stride_token,
"stride_fwd_input_grad_hidden": fwd_input_grad_stride_hidden,
"stride_fwd_input_token": fwd_input_stride_token,
"stride_fwd_input_hidden": fwd_input_stride_hidden,
"stride_merging_probs_token": merging_probs_stride_token,
"stride_merging_probs_expert": merging_probs_stride_expert,
"stride_merging_probs_grad_token": merging_probs_grad_stride_token,
"stride_merging_probs_grad_expert": merging_probs_grad_stride_expert,
"num_experts": num_experts,
"hidden_size": hidden_size,
"PROBS_LOAD_WIDTH": triton.next_power_of_2(num_experts),
"FUSION_UNPAD": True,
"BLOCK_SIZE": block_size,
},
"""Shardy sharding rule for this primitive."""
del num_tokens, num_experts, num_out_tokens, hidden_size, mesh, value_types, result_types
prefix = "UnpermuteBwdWithMergingProbs"
fwd_output_grad_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
fwd_input_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
merging_probs_spec = (f"{prefix}_tokens", f"{prefix}_experts")
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
# pad_offsets: (num_experts,) when with_unpad=True, or dummy (0,) otherwise
pad_offsets_spec = (f"{prefix}_experts",) if with_unpad else (f"{prefix}_pad_empty",)
fwd_input_grad_spec = (f"{prefix}_out_tokens", f"{prefix}_hidden")
merging_probs_grad_spec = (f"{prefix}_tokens", f"{prefix}_experts")
return SdyShardingRule(
(
fwd_output_grad_spec,
fwd_input_spec,
merging_probs_spec,
row_id_map_spec,
pad_offsets_spec,
),
(fwd_input_grad_spec, merging_probs_grad_spec),
)
register_primitive(UnpermuteBwdWithMergingProbsPrimitive)
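# Note on the partition() contract (paraphrasing jax.experimental.custom_partitioning,
# which BasePrimitive is assumed to build on): partition() returns
# (mesh, sharded_impl, out_shardings, arg_shardings), and JAX then calls
# sharded_impl on each device's *local* shards. That is why sharded_impl above
# rederives num_tokens and num_out_tokens from the shard shapes
# (row_id_map.shape[0] and fwd_input.shape[0]) instead of reusing the global
# static values, which describe the unsharded arrays.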
def unpermute_bwd_with_merging_probs(
@@ -1027,7 +1472,7 @@ def unpermute_bwd_with_merging_probs(
merging_probs_grad : jnp.ndarray
Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
"""
# Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature)
dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
# Pass arguments in kernel order: fwd_output_grad, fwd_input, merging_probs, row_id_map, pad_offsets
return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
@@ -1040,6 +1485,7 @@ def unpermute_bwd_with_merging_probs(
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_unpad=False,
)
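# Illustrative call (a sketch only; tensor names are placeholders and the
# argument order mirrors the bind above):
#
#   fwd_input_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
#       fwd_output_grad,  # [num_tokens, hidden_size]
#       fwd_input,        # [num_out_tokens, hidden_size]
#       merging_probs,    # [num_tokens, num_experts]
#       row_id_map,       # [num_tokens, num_experts * 2 + 1]
#       num_tokens=num_tokens,
#       num_experts=num_experts,
#       num_out_tokens=num_out_tokens,
#       hidden_size=hidden_size,
#   )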
@@ -1088,7 +1534,7 @@ def unpermute_bwd_with_merging_probs_and_unpad(
merging_probs_grad : jnp.ndarray
Gradient w.r.t. merging_probs of shape `[num_tokens, num_experts]`.
"""
return UnpermuteBwdWithMergingProbsPrimitive.outer_primitive.bind(
fwd_output_grad,
fwd_input,
merging_probs,
@@ -1098,6 +1544,7 @@ def unpermute_bwd_with_merging_probs_and_unpad(
num_experts=num_experts,
num_out_tokens=num_out_tokens,
hidden_size=hidden_size,
with_unpad=True,
)
@@ -1147,6 +1594,54 @@ class MakeChunkSortMapPrimitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(num_tokens, num_splits, mesh, arg_infos, result_infos):
"""Infer output sharding from input sharding."""
del num_tokens, num_splits, result_infos, arg_infos
# row_id_map is replicated since split_sizes and sorted_indices are typically small
return NamedSharding(
mesh,
PartitionSpec(None),
desc="MakeChunkSortMap.row_id_map_sharding",
)
@staticmethod
def partition(num_tokens, num_splits, mesh, arg_infos, result_infos):
"""Partition the primitive for distributed execution."""
del result_infos
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="MakeChunkSortMap.row_id_map_sharding",
)
def sharded_impl(split_sizes, sorted_indices):
return MakeChunkSortMapPrimitive.impl(
split_sizes,
sorted_indices,
num_tokens=num_tokens,
num_splits=num_splits,
)
return mesh, sharded_impl, out_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(num_tokens, num_splits, mesh, value_types, result_types):
"""Shardy sharding rule for this primitive."""
del num_tokens, num_splits, mesh, value_types, result_types
prefix = "MakeChunkSortMap"
split_sizes_spec = (f"{prefix}_splits",)
sorted_indices_spec = (f"{prefix}_splits",)
row_id_map_spec = (f"{prefix}_tokens",)
return SdyShardingRule(
(split_sizes_spec, sorted_indices_spec),
(row_id_map_spec,),
)
register_primitive(MakeChunkSortMapPrimitive)
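# Illustrative reading (hypothetical values; one plausible interpretation of the
# kernel): split_sizes = [2, 1, 3] partitions 6 tokens into chunks [0:2], [2:3],
# [3:6], and sorted_indices = [2, 0, 1] reorders those chunks, with row_id_map
# recording the per-token mapping that reordering induces. Because split_sizes
# and sorted_indices have only one entry per split, replicating the resulting
# map across devices, as the shardings above do, costs almost nothing.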
@@ -1228,6 +1723,91 @@ class SortChunksByMapPrimitive(BasePrimitive):
},
)
@staticmethod
def infer_sharding_from_operands(
num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos
):
"""Infer output sharding from input sharding."""
del num_tokens, hidden_size, is_forward, result_infos
inp_spec = get_padded_spec(arg_infos[0])
output_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0], None),
desc="SortChunksByMap.output_sharding",
)
if with_probs:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0]),
desc="SortChunksByMap.permuted_probs_sharding",
)
else:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="SortChunksByMap.permuted_probs_sharding_empty",
)
return [output_sharding, permuted_probs_sharding]
@staticmethod
def partition(num_tokens, hidden_size, is_forward, with_probs, mesh, arg_infos, result_infos):
"""Partition the primitive for distributed execution."""
del num_tokens, result_infos
inp_spec = get_padded_spec(arg_infos[0])
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
output_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0], None),
desc="SortChunksByMap.output_sharding",
)
if with_probs:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(inp_spec[0]),
desc="SortChunksByMap.permuted_probs_sharding",
)
else:
permuted_probs_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="SortChunksByMap.permuted_probs_sharding_empty",
)
out_shardings = [output_sharding, permuted_probs_sharding]
def sharded_impl(inp, row_id_map, probs):
local_num_tokens = inp.shape[0]
return SortChunksByMapPrimitive.impl(
inp,
row_id_map,
probs,
num_tokens=local_num_tokens,
hidden_size=hidden_size, # hidden_size is not sharded
is_forward=is_forward,
with_probs=with_probs,
)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
num_tokens, hidden_size, is_forward, with_probs, mesh, value_types, result_types
):
"""Shardy sharding rule for this primitive."""
del num_tokens, hidden_size, is_forward, mesh, value_types, result_types
prefix = "SortChunksByMap"
inp_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
row_id_map_spec = (f"{prefix}_tokens",)
probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty",)
output_spec = (f"{prefix}_tokens", f"{prefix}_hidden")
permuted_probs_spec = (f"{prefix}_tokens",) if with_probs else (f"{prefix}_empty2",)
return SdyShardingRule(
(inp_spec, row_id_map_spec, probs_spec),
(output_spec, permuted_probs_spec),
)
register_primitive(SortChunksByMapPrimitive)
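# The shardings above simply propagate the input's leading (token) dimension
# sharding to both outputs. A minimal sketch of what callers see (assumed 1-D
# data-parallel mesh):
#
#   spec = PartitionSpec("dp", None)  # tokens split over "dp", hidden replicated
#
# output and, when with_probs is set, permuted_probs then inherit the "dp" split
# on their token dimension, so no resharding is inserted around the kernel.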
@@ -1356,6 +1936,7 @@ def permute_with_mask_map(
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=False,
align_size=128, # Default value, no-op for non-padding case
)
if not with_probs:
@@ -1373,6 +1954,7 @@ def permute_with_mask_map_and_pad(
num_experts: int,
num_out_tokens: int,
hidden_size: int,
align_size: int = 128,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""
Permute the input tensor based on the row_id_map with fused padding.
@@ -1395,13 +1977,18 @@ def permute_with_mask_map_and_pad(
Number of tokens in the permuted tensor (including padding).
hidden_size : int
Hidden size of the input tensor.
align_size : int
Alignment size for padding (default: 128). Used for distributed sharding
to correctly compute local buffer sizes.
Returns
-------
output : jnp.ndarray
Permuted and padded output tensor of shape `[num_out_tokens, hidden_size]`.
Padding positions are zero-filled.
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities if probs was provided, None otherwise.
Padding positions are zero-filled.
"""
with_probs = probs is not None
@@ -1426,8 +2013,14 @@ def permute_with_mask_map_and_pad(
hidden_size=hidden_size,
with_probs=with_probs,
with_pad=True,
align_size=align_size,
)
# Note: Zero-filling of padding positions is handled by pre-zeroing the output
# buffers in impl() using jnp.zeros(), then aliasing them to the kernel's outputs
# via input_output_aliases. The kernel only writes to valid positions, leaving
# padding positions at zero.
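# A sketch of that pre-zeroing pattern (assumed shapes/dtypes, not the literal
# impl() code):
#
#   out_buf = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype)
#   probs_buf = jnp.zeros((num_out_tokens,), dtype=probs.dtype)
#
# Both buffers are bound as extra operands and tied to the outputs through
# input_output_aliases, so rows the kernel never touches remain zero.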
if not with_probs:
permuted_probs = None
@@ -1479,7 +2072,7 @@ def unpermute_with_mask_map(
merging_probs = jnp.zeros((0,), dtype=inp.dtype)
if not with_probs:
permuted_probs = jnp.zeros((0,), dtype=inp.dtype)
# Create dummy pad_offsets (not used when with_unpad=False, but required by kernel signature)
dummy_pad_offsets = jnp.zeros((0,), dtype=jnp.int32)
output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
@@ -1493,6 +2086,7 @@ def unpermute_with_mask_map(
hidden_size=hidden_size,
with_merging_probs=with_merging_probs,
with_probs=with_probs,
with_unpad=False,
)
if not with_probs:
@@ -1550,7 +2144,7 @@ def unpermute_with_mask_map_and_unpad(
if not with_probs:
permuted_probs = jnp.zeros((0,), dtype=inp.dtype)
output, unpermuted_probs = UnpermuteWithMaskMapPrimitive.outer_primitive.bind(
inp,
row_id_map,
merging_probs,
@@ -1561,6 +2155,7 @@ def unpermute_with_mask_map_and_unpad(
hidden_size=hidden_size,
with_merging_probs=with_merging_probs,
with_probs=with_probs,
with_unpad=True,
)
if not with_probs:
@@ -409,7 +409,8 @@ def triton_call_lowering(
kernel_constexprs = constexprs if constexprs is not None else {}
# Handle autotuned kernels - compile all configs
if isinstance(kernel_fn, autotuner.Autotuner):
is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
if is_autotuned:
# Compile all configs for runtime selection
kernel_calls = []
actual_kernel_fn = kernel_fn.fn
@@ -450,24 +451,23 @@ def triton_call_lowering(
kernel_calls.append((config_call, str(config)))
# Create autotuned kernel call
# IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
#
# Background:
# 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
# output can reuse an input's buffer. XLA may or may not honor this.
# 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
# save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
#
# The problem: the save phase (triton_kernels.cc:632) only saves if
# buffers[input_idx] == buffers[output_idx], but the restore phase
# (triton_kernels.cc:697-700) unconditionally iterates over all aliases and
# tries to access input_copies[input_idx]. If XLA didn't actually alias the
# buffers, input_copies[input_idx] doesn't exist, creating an empty vector whose
# .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the restore memcpy.
#
# WAR: Don't pass aliases to TritonAutotunedKernelCall.
kernel_call = gpu_triton.TritonAutotunedKernelCall(
f"{actual_kernel_fn.__name__}_autotuned",
kernel_calls,
(),  # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
)
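# For reference (hypothetical indices): input_output_aliases maps operand index
# to result index, e.g. {5: 0, 6: 1} would tie pre-zeroed operands 5 and 6 to
# outputs 0 and 1. The hint is still forwarded to XLA through
# operand_output_aliases below; it is only withheld from the autotuner here.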
else:
@@ -498,15 +498,17 @@ def triton_call_lowering(
serialized_metadata = b""
call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)
if input_output_aliases:
ffi_operand_output_aliases = input_output_aliases
else:
ffi_operand_output_aliases = None
# Use JAX FFI lowering with compressed protobuf
rule = jax.ffi.ffi_lowering(
"triton_kernel_call", # Custom call target registered in gpu_triton.py
api_version=2,
backend_config=zlib.compress(call_proto),
operand_output_aliases=ffi_operand_output_aliases,
)
return rule(ctx, *array_args)
@@ -157,8 +157,8 @@ def permute_with_mask_map(
scale_hidden_dim : int
Hidden size of the scale tensor.
"""
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed.
# The kernel writes only to valid positions, leaving padding positions at zero.
alloc = torch.zeros if pad_offsets is not None else torch.empty
output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
permuted_probs = (
@@ -178,7 +178,13 @@ def permute_with_mask_map(
scale,
permuted_scale,
pad_offsets,
# Pass output buffers as input parameters (for JAX input_output_aliases compatibility).
# In PyTorch, these point to the same memory as the output pointers below.
output,
permuted_probs,
scale_hidden_dim,
num_tokens,
num_out_tokens,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
@@ -252,6 +258,10 @@ def unpermute_with_mask_map(
merging_probs,
permuted_probs,
pad_offsets,
# Dummy buffer parameters for kernel signature consistency with _permute_kernel.
# These are unused in unpermute but maintain consistent interface.
output, # output_buf_ptr (unused, passed for signature consistency)
unpermuted_probs, # unpermuted_probs_buf_ptr (unused, passed for signature consistency)
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),