Unverified Commit 5ba01faa authored by xiaoxi-wangfj's avatar xiaoxi-wangfj Committed by GitHub
Browse files

[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization (#1921)



* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization

1. Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`,
  which remove the explicit padding/unpadding of MoE experts, improving performance and reducing peak GPU memory usage.
2.Add tests of fused permute/pad and unpermute/unpad.
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_merging_probs
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [PyTorch]format code
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* [Common]perf expert_idx loaded once
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* fix: pad_offsets can be None
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>

* add padding + merging probs bwd support. Not tested
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Fix garbage initialized act grad
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* all test passing for jax permutation + pad
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* change tokens_per_experts APIs to num_out_tokens with conservative allocation of worst case padding for output buffer
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* change test permutation to reduce test time
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* triggering PR refresh
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Remove some test cases from the PyTorch side. Add a separate token_dispatch test for sanity in case combine accidentally undoes an error from dispatch in the roundtrip test. Add distinction between L0 and L2 in test cases in jax
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* remove chance for inefficiency in moving between CPU and GPU, remove redundant primitive using a new static bool for padding, add assert for align size
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix lint in jax
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* account for jax versions both newer and older than 0.8.2. Adjusted gpu triton binding accordingly
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* format code
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix typo
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatarxiaoxi-wangfj <690912414@qq.com>
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatartdophung <tdophung@nvidia.com>
parent 97a09c29
This diff is collapsed.
This diff is collapsed.
......@@ -200,6 +200,7 @@ def _permute_kernel(
probs_ptr,
scale_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# sizes
scale_hidden_dim,
# strides
......@@ -224,8 +225,11 @@ def _permute_kernel(
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
......@@ -246,6 +250,15 @@ def _permute_kernel(
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
if FUSION_PAD or PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_PAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
......@@ -253,11 +266,6 @@ def _permute_kernel(
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
......@@ -297,6 +305,7 @@ def _unpermute_kernel(
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -318,10 +327,12 @@ def _unpermute_kernel(
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
......@@ -348,15 +359,19 @@ def _unpermute_kernel(
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
if FUSION_UNPAD or WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
src_row = src_row + pad_off
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
......@@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
fwd_input_ptr,
merging_probs_ptr,
row_id_map_ptr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
......@@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
......
......@@ -25,8 +25,11 @@ import jax.numpy as jnp
from transformer_engine.jax.triton_extensions.permutation import (
make_row_id_map,
permute_with_mask_map,
permute_with_mask_map_and_pad,
unpermute_with_mask_map,
unpermute_with_mask_map_and_unpad,
unpermute_bwd_with_merging_probs,
unpermute_bwd_with_merging_probs_and_unpad,
make_chunk_sort_map,
sort_chunks_by_map,
)
......@@ -43,7 +46,14 @@ def token_dispatch(
routing_map: jnp.ndarray,
num_out_tokens: int,
probs: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
align_size: Optional[int] = None,
) -> Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
]:
"""
Dispatch tokens to experts based on routing map.
......@@ -51,6 +61,10 @@ def token_dispatch(
to their designated experts according to the routing map. The row_id_map
is computed internally from the routing_map.
Optionally supports fused padding for alignment when `align_size` is provided.
This is useful for efficient matrix multiplications that require aligned tensor
dimensions. The padding is computed internally from the routing_map.
Parameters
----------
inp : jnp.ndarray
......@@ -59,36 +73,99 @@ def token_dispatch(
Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
Values: 1 = routed, 0 = not routed.
num_out_tokens : int
The number of output tokens after permutation. This should equal the sum of
routing_map and must be provided explicitly for JIT compatibility.
The number of output tokens after permutation (before padding). For the dropless
case, this should be equal to the sum of routing_map. Must be provided explicitly
for JIT compatibility since output shape must be known at compile time.
probs : Optional[jnp.ndarray]
Optional routing probabilities of shape [batch, sequence, num_experts] or
[num_tokens, num_experts]. If provided, permuted_probs will be returned.
align_size : Optional[int]
Optional alignment size for padding. If provided, outputs will be padded to
align each expert's tokens to a multiple of this size. The output buffer is
allocated with worst-case size, rounded down to align_size:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This enables full JIT compatibility.
Returns
-------
output : jnp.ndarray
Permuted output tensor of shape [num_out_tokens, hidden_size].
Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
or [worst_case_padded_size, hidden_size] when using padding fusion.
With padding, the actual used portion may be smaller than the buffer; check
actual_num_out_tokens (sum of target_tokens_per_expert) for the actual size.
permuted_probs : Optional[jnp.ndarray]
Permuted probabilities of shape [num_out_tokens], or None if probs was not provided.
Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
or None if probs was not provided.
row_id_map : jnp.ndarray
Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] when using padding,
None otherwise. Pass this to token_combine when unpadding is needed.
target_tokens_per_expert : Optional[jnp.ndarray]
Aligned token counts per expert of shape [num_experts] when using padding,
None otherwise.
Note
----
**JIT Compatibility:**
This function is fully JIT-compatible. When using padding (align_size provided),
the output buffer is allocated with a fixed worst-case size that depends only on
compile-time constants (num_out_tokens, num_experts, align_size). The actual
padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
are computed internally from the routing_map and can be traced values.
The worst-case output size is:
((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
This accounts for the maximum possible padding when each expert needs (align_size - 1)
extra tokens to align, rounded down to align_size for buffer alignment.
"""
return _token_dispatch(inp, routing_map, probs, num_out_tokens)
use_padding = align_size is not None
num_experts = routing_map.shape[-1]
if use_padding:
# Compute worst-case output size (compile-time constant)
# This is the maximum possible size when each expert needs max padding
worst_case_out_tokens = (
(num_out_tokens + num_experts * (align_size - 1)) // align_size
) * align_size
else:
worst_case_out_tokens = num_out_tokens
return _token_dispatch(
inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
)
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
def _token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
worst_case_out_tokens: int,
align_size: Optional[int],
use_padding: bool,
) -> Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
]:
"""Internal token_dispatch with custom VJP."""
(output, permuted_probs, row_id_map), _ = _token_dispatch_fwd_rule(
inp, routing_map, probs, num_out_tokens
(output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
_token_dispatch_fwd_rule(
inp,
routing_map,
probs,
num_out_tokens,
worst_case_out_tokens,
align_size,
use_padding,
)
)
return output, permuted_probs, row_id_map
return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
def _token_dispatch_fwd_rule(
......@@ -96,9 +173,18 @@ def _token_dispatch_fwd_rule(
routing_map: jnp.ndarray,
probs: Optional[jnp.ndarray],
num_out_tokens: int,
worst_case_out_tokens: int,
align_size: Optional[int],
use_padding: bool,
) -> Tuple[
Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
Tuple[jnp.ndarray, int, int, int, bool],
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
"""Forward pass rule for token_dispatch."""
# Validate input dimensions
......@@ -126,42 +212,102 @@ def _token_dispatch_fwd_rule(
with_probs = probs is not None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if use_padding:
# Compute tokens_per_expert internally from routing_map
# This can be a traced value since output shape uses worst_case_out_tokens
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
# Calculate aligned token counts per expert
target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
jnp.int32
)
# Compute pad_offsets: cumulative padding for each expert
# pad_offsets[i] = sum of (target - actual) for experts 0..i-1
pad_lengths = target_tokens_per_expert - tokens_per_expert
cum_pad = jnp.cumsum(pad_lengths)
pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])
# Use worst_case_out_tokens as the output buffer size (compile-time constant)
# The actual used size is sum(target_tokens_per_expert), which may be smaller.
# Unused positions will be zero-initialized by the kernel.
output, permuted_probs = permute_with_mask_map_and_pad(
inp,
row_id_map,
probs,
pad_offsets,
num_tokens,
num_experts,
worst_case_out_tokens,
hidden_size,
)
else:
# No padding
pad_offsets = None
target_tokens_per_expert = None
output, permuted_probs = permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# Return (primals, residuals)
# Include with_probs flag to know how to handle backward pass
residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs)
return (output, permuted_probs, row_id_map), residuals
residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
return (
output,
permuted_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
), residuals
def _token_dispatch_bwd_rule(
_routing_map: jnp.ndarray,
_num_out_tokens: int,
residuals: Tuple[jnp.ndarray, int, int, int, bool],
g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
_worst_case_out_tokens: int,
_align_size: Optional[int],
_use_padding: bool,
residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
g: Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch."""
row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _ = g # Ignore row_id_map gradient
row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads
# Backward: unpermute gradients (gather from experts back to tokens)
inp_grad, probs_grad = unpermute_with_mask_map(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
num_tokens,
num_experts,
hidden_size,
)
if pad_offsets is not None:
inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
else:
inp_grad, probs_grad = unpermute_with_mask_map(
output_grad,
row_id_map,
None, # No merging probs
permuted_probs_grad if with_probs else None,
num_tokens,
num_experts,
hidden_size,
)
return inp_grad, probs_grad if with_probs else None
......@@ -178,6 +324,7 @@ def token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray] = None,
pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""
Combine tokens from experts back to original token positions.
......@@ -185,33 +332,42 @@ def token_combine(
This is the forward pass of MoE unpermutation. Tokens are gathered from
experts and merged (optionally weighted by merging_probs).
Optionally supports fused unpadding when `pad_offsets` is provided (from
token_dispatch with padding enabled).
Parameters
----------
inp : jnp.ndarray
Input tensor from experts of shape [num_out_tokens, hidden_size].
Input tensor from experts of shape [num_out_tokens, hidden_size]
(or [num_out_tokens_padded, hidden_size] when using unpadding).
row_id_map : jnp.ndarray
Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
merging_probs : Optional[jnp.ndarray]
Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
If provided, tokens from different experts are weighted-summed.
If None, tokens are summed directly.
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch.
If provided, fused unpadding will be performed. This should be the pad_offsets
returned by token_dispatch when using padding.
Returns
-------
output : jnp.ndarray
Combined output tensor of shape [num_tokens, hidden_size].
"""
return _token_combine(inp, row_id_map, merging_probs)
return _token_combine(inp, row_id_map, merging_probs, pad_offsets)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
@jax.custom_vjp
def _token_combine(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
pad_offsets: Optional[jnp.ndarray],
) -> jnp.ndarray:
"""Internal token_combine with custom VJP."""
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs)
output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
return output
......@@ -219,7 +375,20 @@ def _token_combine_fwd_rule(
inp: jnp.ndarray,
row_id_map: jnp.ndarray,
merging_probs: Optional[jnp.ndarray],
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]:
pad_offsets: Optional[jnp.ndarray],
) -> Tuple[
jnp.ndarray,
Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
]:
"""Forward pass rule for token_combine."""
# Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
num_tokens = row_id_map.shape[0]
......@@ -227,21 +396,34 @@ def _token_combine_fwd_rule(
hidden_size = inp.shape[-1]
num_out_tokens = inp.shape[0]
# Call triton extension
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Call triton extension with or without unpadding
if pad_offsets is not None:
output, _ = unpermute_with_mask_map_and_unpad(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
else:
output, _ = unpermute_with_mask_map(
inp,
row_id_map,
merging_probs,
None, # No permuted probs to unpermute
num_tokens,
num_experts,
hidden_size,
)
# Return (primal, residuals)
# Include inp in residuals for backward with merging_probs
residuals = (
row_id_map,
pad_offsets,
inp,
merging_probs,
num_tokens,
......@@ -253,13 +435,26 @@ def _token_combine_fwd_rule(
def _token_combine_bwd_rule(
row_id_map: jnp.ndarray,
residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int],
residuals: Tuple[
jnp.ndarray,
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
int,
int,
int,
int,
],
g: jnp.ndarray,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_combine."""
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
"""Backward pass rule for token_combine.
Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets)
row_id_map and pad_offsets are integer arrays, so their gradients are None.
"""
(
row_id_map,
pad_offsets,
fwd_input,
merging_probs,
num_tokens,
......@@ -273,30 +468,63 @@ def _token_combine_bwd_rule(
if with_merging_probs:
# Use specialized backward kernel that properly scales by merging_probs
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if pad_offsets is not None:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
output_grad,
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# The backward kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
output_grad,
row_id_map,
fwd_input,
merging_probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
else:
# Simple case: just permute gradients back
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if pad_offsets is not None:
inp_grad, _ = permute_with_mask_map_and_pad(
output_grad,
row_id_map,
None,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
# The permute kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
else:
inp_grad, _ = permute_with_mask_map(
output_grad,
row_id_map,
None,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
merging_probs_grad = None
return inp_grad, merging_probs_grad
# Return gradients for: inp, row_id_map, merging_probs, pad_offsets
# row_id_map and pad_offsets are integer arrays, so their gradients are None
return inp_grad, None, merging_probs_grad, None
_token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
......
......@@ -142,17 +142,31 @@ def compile_triton(
)
# Create kernel object for JAX
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1, # cluster_dims
)
# From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str)
num_warps, # arg1: num_warps (int)
num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
compute_capability, # arg6: compute_capability (int)
)
else:
kernel = gpu_triton.TritonKernel(
compiled.name,
num_warps,
compiled.metadata.shared,
compiled.asm["ptx"],
"", # ttir
compute_capability,
1,
1,
1,
)
_TRITON_KERNEL_CACHE[cache_key] = kernel
return kernel
......
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_permute_with_probs,
moe_permute_and_pad_with_probs,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""MoE Permutaion API"""
"""MoE Permutation API"""
import warnings
from typing import Optional, Tuple
import torch
......@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
routing_map: torch.Tensor,
num_out_tokens: int,
probs: torch.Tensor,
pad_offsets: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function):
assert routing_map.is_cuda, "TransformerEngine needs CUDA."
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert inp.size(0) == routing_map.size(0), "Permute not possible"
num_tokens, hidden_size = inp.size()
......@@ -250,6 +253,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
probs,
fp8_scale,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
......@@ -292,7 +296,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
......@@ -307,12 +311,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, ctx.probs
return permuted_act_grad, None, None, ctx.probs, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
......@@ -321,13 +325,14 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
None,
permuted_probs_grad,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
return act_grad, None, None, probs_grad, None
class _moe_unpermute_mask_map(torch.autograd.Function):
......@@ -340,6 +345,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map: torch.Tensor,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
pad_offsets: Optional[torch.Tensor],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -358,6 +364,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert not isinstance(
inp, QuantizedTensor
......@@ -367,15 +375,16 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
merging_probs,
None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs)
ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
else:
ctx.save_for_backward(row_id_map)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.num_permuted_tokens = inp.size(0)
......@@ -387,15 +396,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def backward(ctx, unpermuted_act_grad):
# pylint: disable=missing-function-docstring
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.merging_probs, None
return unpermuted_act_grad, None, ctx.merging_probs, None, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
if ctx.with_probs:
fwd_input, row_id_map, merging_probs = ctx.saved_tensors
fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
else:
(row_id_map,) = ctx.saved_tensors
row_id_map, pad_offsets = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
......@@ -441,6 +450,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -453,6 +463,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
None,
fp8_scale,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -497,7 +508,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if not ctx.needs_input_grad[2]:
probs_grad = None
return act_grad, None, probs_grad, None
return act_grad, None, probs_grad, None, None
def moe_permute(
......@@ -537,7 +548,9 @@ def moe_permute(
if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
output, row_id_map, _ = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, None, None
)
return output, row_id_map
raise ValueError("map_type should be one of 'mask' or 'index'")
......@@ -570,11 +583,67 @@ def moe_permute_with_probs(
By default, set to '-1', meaning no tokens are dropped.
"""
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, probs
inp, routing_map, num_out_tokens, probs, None
)
return output, permuted_probs, row_id_map
def moe_permute_and_pad_with_probs(
    inp: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Permute the tokens and probs based on the routing_map, padding each expert's
    token count up to a multiple of `align_size` in the same fused kernel.

    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    probs: torch.Tensor
        The tensor of probabilities corresponding to the permuted tokens and is
        of shape [num_tokens, num_experts]. It will be permuted with the tokens
        according to the routing_map.
    routing_map: torch.Tensor
        The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
    tokens_per_expert : torch.Tensor
        Tensor of shape `[num_experts]` containing actual token counts per expert.
    align_size : int
        The alignment size for the input tensor.

    Returns
    -------
    Tuple of (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert).
    `pad_offsets` is None when every expert's count is already aligned; otherwise it
    holds the per-expert cumulative padding offsets needed to unpad in `moe_unpermute`.
    """
    assert (
        tokens_per_expert is not None
    ), "tokens_per_expert must be provided to the fused permute padding function."
    assert align_size > 0, f"align_size must be positive, got {align_size}"

    # Move to the input's device (avoids per-element transfers later) and normalize
    # to int64 so the dtype-sensitive torch.equal check below cannot spuriously
    # report a mismatch against the int64 aligned counts.
    tokens_per_expert = tokens_per_expert.to(device=inp.device, dtype=torch.int64)

    # Round each expert's count up to a multiple of align_size with exact integer
    # arithmetic (float ceil can lose precision for very large token counts).
    target_tokens_per_expert = ((tokens_per_expert + align_size - 1) // align_size) * align_size

    if torch.equal(tokens_per_expert, target_tokens_per_expert):
        # Already aligned: skip the fused-pad path entirely.
        pad_offsets = None
    else:
        # pad_offsets[i] = total padding inserted before expert i's tokens,
        # i.e. the cumulative sum of (target - actual) over experts 0..i-1.
        pad_lengths = target_tokens_per_expert - tokens_per_expert
        cum_pad = torch.cumsum(pad_lengths, dim=0)
        pad_offsets = torch.cat(
            [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
        )

    # NOTE: .item() synchronizes with the device to size the padded output buffer.
    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
        inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
    )
    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -582,6 +651,7 @@ def moe_unpermute(
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: Optional[torch.Tensor] = None,
pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -605,6 +675,10 @@ def moe_unpermute(
Options are: 'mask', 'index'.
probs : torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.
"""
# Backward-compatibility shim: the old "probs" argument was renamed to
# "merging_probs"; passing both at once is rejected below.
if probs is not None:
if merging_probs is not None:
......@@ -616,7 +690,9 @@ def moe_unpermute(
# Dispatch on the map representation produced at permute time.
# NOTE(review): pad_offsets is ignored on the 'index' path — presumably fused
# unpadding is only supported with map_type='mask'; confirm with callers.
if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
return _moe_unpermute_mask_map.apply(
inp, row_id_map, merging_probs, restore_shape, pad_offsets
)
raise ValueError("map_type should be one of 'mask' or 'index'")
......
......@@ -123,6 +123,7 @@ def permute_with_mask_map(
row_id_map: torch.Tensor,
probs: torch.Tensor,
scale: torch.Tensor,
pad_offsets: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -142,6 +143,9 @@ def permute_with_mask_map(
The probabilities of the input tensor. If it is not None, it will be permuted.
scale : torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
If it is not None, it will be allocated output buffers with aligned sizes.
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
......@@ -153,18 +157,18 @@ def permute_with_mask_map(
scale_hidden_dim : int
Hidden size of the scale tensor.
"""
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed,
# since the kernel doesn't write to padding positions.
alloc = torch.zeros if pad_offsets is not None else torch.empty
output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
permuted_probs = (
alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
)
# NOTE(review): permuted_scale stays torch.empty even when pad_offsets is set, so
# the scales of padding rows are uninitialized — confirm no consumer dequantizes
# padding rows, otherwise this should use `alloc` like output/permuted_probs.
permuted_scale = (
torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
if scale is not None
else None
)
# pylint: disable=unnecessary-lambda-assignment
# One Triton program per (token, hidden-size tile).
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid](
......@@ -173,6 +177,7 @@
probs,
scale,
permuted_scale,
pad_offsets,
scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
......@@ -193,6 +198,7 @@
hidden_size,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
FUSION_PAD=pad_offsets is not None,
)
return output, permuted_scale, permuted_probs
......@@ -202,6 +208,7 @@ def unpermute_with_mask_map(
row_id_map: torch.Tensor,
merging_probs: Union[torch.Tensor, None],
permuted_probs: Union[torch.Tensor, None],
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
hidden_size: int,
......@@ -220,6 +227,9 @@
to reduce the unpermuted tokens.
permuted_probs : torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
If it is not None, it will remove the previously fused padding.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
......@@ -241,6 +251,7 @@
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
......@@ -259,6 +270,7 @@
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
# Compile-time flag: when set, the kernel skips the padding rows that
# the fused permute+pad path inserted (driven purely by pad_offsets).
FUSION_UNPAD=pad_offsets is not None,
)
return output, unpermuted_probs
......@@ -268,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
row_id_map: torch.Tensor,
fwd_input: torch.Tensor,
merging_probs: torch.Tensor,
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -286,6 +299,9 @@
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs : torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
If it is not None, it will be allocated output buffers with aligned sizes.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
......@@ -295,9 +311,11 @@
hidden_size : int
Hidden size of the output tensor.
"""
# Use zeros when pad_offsets is used because padding slots won't be written to
# by the kernel. This matches the behavior of Fp8Unpadding.backward which zeros
# out the padding slots.
alloc = torch.zeros if pad_offsets is not None else torch.empty
act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda")
# NOTE(review): merging_probs_grad is torch.empty-allocated — presumably the kernel
# writes every (token, expert) entry regardless of padding; confirm, since any
# unwritten entry would contain garbage.
merging_probs_grad = torch.empty(
(num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
)
......@@ -307,6 +325,7 @@
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0),
......@@ -324,6 +343,7 @@
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
FUSION_UNPAD=pad_offsets is not None,
)
return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment