"vscode:/vscode.git/clone" did not exist on "343cea3dbcc5d3ee44654fa6289262c2553486c5"
Unverified commit 5ba01faa, authored by xiaoxi-wangfj, committed by GitHub

[PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization (#1921)



* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization

1. Fused `moe_permute_with_probs` + `Fp8Padding` and `moe_unpermute` + `Fp8Unpadding`,
   removing the explicit padding/unpadding of MoE experts, improving performance and
   reducing peak GPU memory usage.
2. Added tests for the fused permute/pad and unpermute/unpad ops.
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
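
The fused pad path described above reserves a worst-case output buffer so shapes stay static under JIT: each expert can require at most align_size - 1 rows of padding. A minimal sketch of that sizing rule (helper name is illustrative, the formula matches the one documented later in this diff):

```python
# Hypothetical helper mirroring the worst-case buffer formula used by the
# fused permute+pad path: every expert may need up to (align_size - 1) extra
# rows, and the total is rounded down to a multiple of align_size.
def worst_case_out_tokens(num_out_tokens: int, num_experts: int, align_size: int) -> int:
    return ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size

# Example: 10 routed tokens, 2 experts, aligning to 8 gives a 24-row buffer.
```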

* [PyTorch/Common] Fuse permute+pad and unpermute+unpad: support for with_merging_probs
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* [PyTorch] Format code
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* [Common] Perf: load expert_idx only once
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* fix: pad_offsets can be None
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* Add padding + merging probs bwd support (not yet tested)
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix garbage-initialized act grad
Signed-off-by: tdophung <tdophung@nvidia.com>

* All tests passing for JAX permutation + pad
Signed-off-by: tdophung <tdophung@nvidia.com>

* Change tokens_per_expert APIs to num_out_tokens, with conservative worst-case padding allocation for the output buffer
Signed-off-by: tdophung <tdophung@nvidia.com>

* Change permutation tests to reduce test time
Signed-off-by: tdophung <tdophung@nvidia.com>

* Trigger PR refresh
Signed-off-by: tdophung <tdophung@nvidia.com>

* Format code
Signed-off-by: tdophung <tdophung@nvidia.com>

* Remove some test cases from the PyTorch side. Add a separate token_dispatch test for sanity, in case combine accidentally undoes an error on dispatch in the roundtrip test. Add distinction between L0 and L2 test cases in JAX
Signed-off-by: tdophung <tdophung@nvidia.com>

* Format code
Signed-off-by: tdophung <tdophung@nvidia.com>

* Remove a chance for inefficiency when moving between CPU and GPU, remove a redundant primitive by using a new static bool for padding, add an assert for align size
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix lint in JAX
Signed-off-by: tdophung <tdophung@nvidia.com>

* Account for JAX versions both newer and older than 0.8.2; adjust the GPU Triton binding accordingly
Signed-off-by: tdophung <tdophung@nvidia.com>

* Format code
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix typo
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: tdophung <tdophung@nvidia.com>
parent 97a09c29
@@ -200,6 +200,7 @@ def _permute_kernel(
     probs_ptr,
     scale_ptr,
     permuted_scale_ptr,
+    pad_offsets_ptr,
     # sizes
     scale_hidden_dim,
     # strides
@@ -224,8 +225,11 @@ def _permute_kernel(
     hidden_size: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     PERMUTE_SCALE: tl.constexpr,
+    FUSION_PAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
+    expert_idx = 0
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
     cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -246,6 +250,15 @@ def _permute_kernel(
         dst_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
+        if FUSION_PAD or PERMUTE_PROBS:
+            expert_idx = tl.load(
+                row_id_map_ptr
+                + pid_t * stride_row_id_map_token
+                + (num_experts + idx) * stride_row_id_map_expert
+            )
+        if FUSION_PAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
         if PERMUTE_SCALE:
             permuted_scale_off = (
@@ -253,11 +266,6 @@ def _permute_kernel(
             )
             tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
         if PERMUTE_PROBS:
-            expert_idx = tl.load(
-                row_id_map_ptr
-                + pid_t * stride_row_id_map_token
-                + (num_experts + idx) * stride_row_id_map_expert
-            )
             prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
             prob = tl.load(probs_ptr + prob_off)
             if pid_h == 0:
@@ -297,6 +305,7 @@ def _unpermute_kernel(
     row_id_map_ptr,
     merging_probs_ptr,
     permuted_probs_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -318,10 +327,12 @@ def _unpermute_kernel(
     PROBS_LOAD_WIDTH: tl.constexpr,
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = input_ptr.dtype.element_ty
     compute_type = tl.float32
+    expert_idx = 0
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
@@ -348,15 +359,19 @@ def _unpermute_kernel(
         src_row = tl.load(
             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
         ).to(tl.int64)
-        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
-        inp = tl.load(input_ptr + input_off, mask=mask)
-        inp = inp.to(compute_type)
-        if WITH_MERGING_PROBS:
+        if FUSION_UNPAD or WITH_MERGING_PROBS:
             expert_idx = tl.load(
                 row_id_map_ptr
                 + pid_t * stride_row_id_map_token
                 + (num_experts + idx) * stride_row_id_map_expert
             )
+        if FUSION_UNPAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            src_row = src_row + pad_off
+        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
+        inp = tl.load(input_ptr + input_off, mask=mask)
+        inp = inp.to(compute_type)
+        if WITH_MERGING_PROBS:
             merging_prob_off = (
                 pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
             )
@@ -407,6 +422,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
     fwd_input_ptr,
     merging_probs_ptr,
     row_id_map_ptr,
+    pad_offsets_ptr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -427,6 +443,7 @@ def _unpermute_bwd_with_merging_probs_kernel(
     num_experts: tl.constexpr,
     hidden_size: tl.constexpr,
     PROBS_LOAD_WIDTH: tl.constexpr,
+    FUSION_UNPAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     data_type = fwd_output_grad_ptr.dtype.element_ty
@@ -450,6 +467,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
             + pid * stride_row_id_map_token
             + (num_experts + idx) * stride_row_id_map_expert
         )
+        if FUSION_UNPAD:
+            pad_off = tl.load(pad_offsets_ptr + expert_idx)
+            dst_row = dst_row + pad_off
         prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
         current_start = 0
         while current_start < hidden_size:
......
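
The kernels above shift each destination/source row by a per-expert `pad_off`. As a NumPy sketch of how those cumulative offsets can be derived from per-expert token counts (a stand-in for the jnp computation that appears later in this PR; the helper name is illustrative):

```python
import numpy as np

def compute_pad_offsets(tokens_per_expert, align_size):
    """Per-expert aligned targets and cumulative padding offsets (illustrative)."""
    tokens = np.asarray(tokens_per_expert)
    # Round each expert's token count up to a multiple of align_size.
    target = ((tokens + align_size - 1) // align_size) * align_size
    pad_lengths = target - tokens
    # pad_offsets[i] = total padding inserted before expert i's rows.
    cum_pad = np.cumsum(pad_lengths)
    pad_offsets = np.concatenate([[0], cum_pad[:-1]])
    return target, pad_offsets
```

For example, experts with [3, 5, 8] tokens aligned to 4 get targets [4, 8, 8] and offsets [0, 1, 4].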
@@ -25,8 +25,11 @@ import jax.numpy as jnp
 from transformer_engine.jax.triton_extensions.permutation import (
     make_row_id_map,
     permute_with_mask_map,
+    permute_with_mask_map_and_pad,
     unpermute_with_mask_map,
+    unpermute_with_mask_map_and_unpad,
     unpermute_bwd_with_merging_probs,
+    unpermute_bwd_with_merging_probs_and_unpad,
     make_chunk_sort_map,
     sort_chunks_by_map,
 )
@@ -43,7 +46,14 @@ def token_dispatch(
     routing_map: jnp.ndarray,
     num_out_tokens: int,
     probs: Optional[jnp.ndarray] = None,
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
+    align_size: Optional[int] = None,
+) -> Tuple[
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    Optional[jnp.ndarray],
+]:
     """
     Dispatch tokens to experts based on routing map.
@@ -51,6 +61,10 @@ def token_dispatch(
     to their designated experts according to the routing map. The row_id_map
     is computed internally from the routing_map.

+    Optionally supports fused padding for alignment when `align_size` is provided.
+    This is useful for efficient matrix multiplications that require aligned tensor
+    dimensions. The padding is computed internally from the routing_map.
+
     Parameters
     ----------
     inp : jnp.ndarray
@@ -59,36 +73,99 @@ def token_dispatch(
         Routing mask of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
         Values: 1 = routed, 0 = not routed.
     num_out_tokens : int
-        The number of output tokens after permutation. This should equal the sum of
-        routing_map and must be provided explicitly for JIT compatibility.
+        The number of output tokens after permutation (before padding). For the dropless
+        case, this should be equal to the sum of routing_map. Must be provided explicitly
+        for JIT compatibility since output shape must be known at compile time.
     probs : Optional[jnp.ndarray]
         Optional routing probabilities of shape [batch, sequence, num_experts] or
        [num_tokens, num_experts]. If provided, permuted_probs will be returned.
+    align_size : Optional[int]
+        Optional alignment size for padding. If provided, outputs will be padded to
+        align each expert's tokens to a multiple of this size. The output buffer is
+        allocated with worst-case size, rounded down to align_size:
+        ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
+        This enables full JIT compatibility.

     Returns
     -------
     output : jnp.ndarray
-        Permuted output tensor of shape [num_out_tokens, hidden_size].
+        Permuted output tensor of shape [num_out_tokens, hidden_size] without padding,
+        or [worst_case_padded_size, hidden_size] when using padding fusion.
+        With padding, the actual used portion may be smaller than the buffer; check
+        actual_num_out_tokens (sum of target_tokens_per_expert) for the actual size.
     permuted_probs : Optional[jnp.ndarray]
-        Permuted probabilities of shape [num_out_tokens], or None if probs was not provided.
+        Permuted probabilities of shape [num_out_tokens] or [worst_case_padded_size],
+        or None if probs was not provided.
     row_id_map : jnp.ndarray
         Row ID map for use in token_combine (shape [num_tokens, num_experts * 2 + 1]).
+    pad_offsets : Optional[jnp.ndarray]
+        Per-expert cumulative padding offsets of shape [num_experts] when using padding,
+        None otherwise. Pass this to token_combine when unpadding is needed.
+    target_tokens_per_expert : Optional[jnp.ndarray]
+        Aligned token counts per expert of shape [num_experts] when using padding,
+        None otherwise.
+
+    Note
+    ----
+    **JIT Compatibility:**
+    This function is fully JIT-compatible. When using padding (align_size provided),
+    the output buffer is allocated with a fixed worst-case size that depends only on
+    compile-time constants (num_out_tokens, num_experts, align_size). The actual
+    padding offsets (pad_offsets) and aligned token counts (target_tokens_per_expert)
+    are computed internally from the routing_map and can be traced values.
+
+    The worst-case output size is:
+        ((num_out_tokens + num_experts * (align_size - 1)) // align_size) * align_size
+    This accounts for the maximum possible padding when each expert needs (align_size - 1)
+    extra tokens to align, rounded down to align_size for buffer alignment.
     """
-    return _token_dispatch(inp, routing_map, probs, num_out_tokens)
+    use_padding = align_size is not None
+    num_experts = routing_map.shape[-1]
+    if use_padding:
+        # Compute worst-case output size (compile-time constant)
+        # This is the maximum possible size when each expert needs max padding
+        worst_case_out_tokens = (
+            (num_out_tokens + num_experts * (align_size - 1)) // align_size
+        ) * align_size
+    else:
+        worst_case_out_tokens = num_out_tokens
+    return _token_dispatch(
+        inp, routing_map, probs, num_out_tokens, worst_case_out_tokens, align_size, use_padding
+    )


-@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
+@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
 def _token_dispatch(
     inp: jnp.ndarray,
     routing_map: jnp.ndarray,
     probs: Optional[jnp.ndarray],
     num_out_tokens: int,
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
+    worst_case_out_tokens: int,
+    align_size: Optional[int],
+    use_padding: bool,
+) -> Tuple[
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    jnp.ndarray,
+    Optional[jnp.ndarray],
+    Optional[jnp.ndarray],
+]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map), _ = _token_dispatch_fwd_rule(
-        inp, routing_map, probs, num_out_tokens
-    )
-    return output, permuted_probs, row_id_map
+    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+        _token_dispatch_fwd_rule(
+            inp,
+            routing_map,
+            probs,
+            num_out_tokens,
+            worst_case_out_tokens,
+            align_size,
+            use_padding,
+        )
+    )
+    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert


 def _token_dispatch_fwd_rule(
@@ -96,9 +173,18 @@ def _token_dispatch_fwd_rule(
     routing_map: jnp.ndarray,
     probs: Optional[jnp.ndarray],
     num_out_tokens: int,
+    worst_case_out_tokens: int,
+    align_size: Optional[int],
+    use_padding: bool,
 ) -> Tuple[
-    Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
-    Tuple[jnp.ndarray, int, int, int, bool],
+    Tuple[
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        Optional[jnp.ndarray],
+    ],
+    Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
     """Forward pass rule for token_dispatch."""
     # Validate input dimensions
@@ -126,42 +212,102 @@ def _token_dispatch_fwd_rule(
     with_probs = probs is not None

-    output, permuted_probs = permute_with_mask_map(
-        inp,
-        row_id_map,
-        probs,
-        num_tokens,
-        num_experts,
-        num_out_tokens,
-        hidden_size,
-    )
+    if use_padding:
+        # Compute tokens_per_expert internally from routing_map
+        # This can be a traced value since output shape uses worst_case_out_tokens
+        tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+
+        # Calculate aligned token counts per expert
+        target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
+            jnp.int32
+        )
+
+        # Compute pad_offsets: cumulative padding for each expert
+        # pad_offsets[i] = sum of (target - actual) for experts 0..i-1
+        pad_lengths = target_tokens_per_expert - tokens_per_expert
+        cum_pad = jnp.cumsum(pad_lengths)
+        pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]])
+
+        # Use worst_case_out_tokens as the output buffer size (compile-time constant)
+        # The actual used size is sum(target_tokens_per_expert), which may be smaller.
+        # Unused positions will be zero-initialized by the kernel.
+        output, permuted_probs = permute_with_mask_map_and_pad(
+            inp,
+            row_id_map,
+            probs,
+            pad_offsets,
+            num_tokens,
+            num_experts,
+            worst_case_out_tokens,
+            hidden_size,
+        )
+    else:
+        # No padding
+        pad_offsets = None
+        target_tokens_per_expert = None
+        output, permuted_probs = permute_with_mask_map(
+            inp,
+            row_id_map,
+            probs,
+            num_tokens,
+            num_experts,
+            num_out_tokens,
+            hidden_size,
+        )

     # Return (primals, residuals)
-    # Include with_probs flag to know how to handle backward pass
-    residuals = (row_id_map, num_tokens, num_experts, hidden_size, with_probs)
-    return (output, permuted_probs, row_id_map), residuals
+    residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
+    return (
+        output,
+        permuted_probs,
+        row_id_map,
+        pad_offsets,
+        target_tokens_per_expert,
+    ), residuals


 def _token_dispatch_bwd_rule(
     _routing_map: jnp.ndarray,
     _num_out_tokens: int,
-    residuals: Tuple[jnp.ndarray, int, int, int, bool],
-    g: Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
+    _worst_case_out_tokens: int,
+    _align_size: Optional[int],
+    _use_padding: bool,
+    residuals: Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
+    g: Tuple[
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        Optional[jnp.ndarray],
+    ],
 ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
     """Backward pass rule for token_dispatch."""
-    row_id_map, num_tokens, num_experts, hidden_size, with_probs = residuals
-    output_grad, permuted_probs_grad, _ = g  # Ignore row_id_map gradient
+    row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
+    output_grad, permuted_probs_grad, _, _, _ = g  # Ignore row_id_map, pad_offsets, target grads

     # Backward: unpermute gradients (gather from experts back to tokens)
-    inp_grad, probs_grad = unpermute_with_mask_map(
-        output_grad,
-        row_id_map,
-        None,  # No merging probs
-        permuted_probs_grad if with_probs else None,
-        num_tokens,
-        num_experts,
-        hidden_size,
-    )
+    if pad_offsets is not None:
+        inp_grad, probs_grad = unpermute_with_mask_map_and_unpad(
+            output_grad,
+            row_id_map,
+            None,  # No merging probs
+            permuted_probs_grad if with_probs else None,
+            pad_offsets,
+            num_tokens,
+            num_experts,
+            hidden_size,
+        )
+    else:
+        inp_grad, probs_grad = unpermute_with_mask_map(
+            output_grad,
+            row_id_map,
+            None,  # No merging probs
+            permuted_probs_grad if with_probs else None,
+            num_tokens,
+            num_experts,
+            hidden_size,
+        )

     return inp_grad, probs_grad if with_probs else None
@@ -178,6 +324,7 @@ def token_combine(
     inp: jnp.ndarray,
     row_id_map: jnp.ndarray,
     merging_probs: Optional[jnp.ndarray] = None,
+    pad_offsets: Optional[jnp.ndarray] = None,
 ) -> jnp.ndarray:
     """
     Combine tokens from experts back to original token positions.
@@ -185,33 +332,42 @@ def token_combine(
     This is the forward pass of MoE unpermutation. Tokens are gathered from
     experts and merged (optionally weighted by merging_probs).

+    Optionally supports fused unpadding when `pad_offsets` is provided (from
+    token_dispatch with padding enabled).
+
     Parameters
     ----------
     inp : jnp.ndarray
-        Input tensor from experts of shape [num_out_tokens, hidden_size].
+        Input tensor from experts of shape [num_out_tokens, hidden_size]
+        (or [num_out_tokens_padded, hidden_size] when using unpadding).
     row_id_map : jnp.ndarray
         Row ID map from token_dispatch of shape [num_tokens, num_experts * 2 + 1].
     merging_probs : Optional[jnp.ndarray]
         Merging weights of shape [batch, sequence, num_experts] or [num_tokens, num_experts].
         If provided, tokens from different experts are weighted-summed.
         If None, tokens are summed directly.
+    pad_offsets : Optional[jnp.ndarray]
+        Per-expert cumulative padding offsets of shape [num_experts] from token_dispatch.
+        If provided, fused unpadding will be performed. This should be the pad_offsets
+        returned by token_dispatch when using padding.

     Returns
     -------
     output : jnp.ndarray
         Combined output tensor of shape [num_tokens, hidden_size].
     """
-    return _token_combine(inp, row_id_map, merging_probs)
+    return _token_combine(inp, row_id_map, merging_probs, pad_offsets)


-@partial(jax.custom_vjp, nondiff_argnums=(1,))
+@jax.custom_vjp
 def _token_combine(
     inp: jnp.ndarray,
     row_id_map: jnp.ndarray,
     merging_probs: Optional[jnp.ndarray],
+    pad_offsets: Optional[jnp.ndarray],
 ) -> jnp.ndarray:
     """Internal token_combine with custom VJP."""
-    output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs)
+    output, _ = _token_combine_fwd_rule(inp, row_id_map, merging_probs, pad_offsets)
     return output
@@ -219,7 +375,20 @@ def _token_combine_fwd_rule(
     inp: jnp.ndarray,
     row_id_map: jnp.ndarray,
     merging_probs: Optional[jnp.ndarray],
-) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int]]:
+    pad_offsets: Optional[jnp.ndarray],
+) -> Tuple[
+    jnp.ndarray,
+    Tuple[
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        int,
+        int,
+        int,
+        int,
+    ],
+]:
     """Forward pass rule for token_combine."""
     # Infer dimensions from row_id_map shape: [num_tokens, num_experts * 2 + 1]
     num_tokens = row_id_map.shape[0]
@@ -227,21 +396,34 @@ def _token_combine_fwd_rule(
     hidden_size = inp.shape[-1]
     num_out_tokens = inp.shape[0]

-    # Call triton extension
-    output, _ = unpermute_with_mask_map(
-        inp,
-        row_id_map,
-        merging_probs,
-        None,  # No permuted probs to unpermute
-        num_tokens,
-        num_experts,
-        hidden_size,
-    )
+    # Call triton extension with or without unpadding
+    if pad_offsets is not None:
+        output, _ = unpermute_with_mask_map_and_unpad(
+            inp,
+            row_id_map,
+            merging_probs,
+            None,  # No permuted probs to unpermute
+            pad_offsets,
+            num_tokens,
+            num_experts,
+            hidden_size,
+        )
+    else:
+        output, _ = unpermute_with_mask_map(
+            inp,
+            row_id_map,
+            merging_probs,
+            None,  # No permuted probs to unpermute
+            num_tokens,
+            num_experts,
+            hidden_size,
+        )

     # Return (primal, residuals)
     # Include inp in residuals for backward with merging_probs
     residuals = (
         row_id_map,
+        pad_offsets,
         inp,
         merging_probs,
         num_tokens,
@@ -253,13 +435,26 @@ def _token_combine_fwd_rule(

 def _token_combine_bwd_rule(
-    row_id_map: jnp.ndarray,
-    residuals: Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray], int, int, int, int],
+    residuals: Tuple[
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        jnp.ndarray,
+        Optional[jnp.ndarray],
+        int,
+        int,
+        int,
+        int,
+    ],
     g: jnp.ndarray,
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
-    """Backward pass rule for token_combine."""
+) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray], None]:
+    """Backward pass rule for token_combine.
+
+    Returns gradients for: (inp, row_id_map, merging_probs, pad_offsets)
+    row_id_map and pad_offsets are integer arrays, so their gradients are None.
+    """
     (
         row_id_map,
+        pad_offsets,
         fwd_input,
         merging_probs,
         num_tokens,
@@ -273,30 +468,63 @@ def _token_combine_bwd_rule(
     if with_merging_probs:
         # Use specialized backward kernel that properly scales by merging_probs
-        inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
-            output_grad,
-            row_id_map,
-            fwd_input,
-            merging_probs,
-            num_tokens,
-            num_experts,
-            num_out_tokens,
-            hidden_size,
-        )
+        if pad_offsets is not None:
+            inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs_and_unpad(
+                output_grad,
+                row_id_map,
+                fwd_input,
+                merging_probs,
+                pad_offsets,
+                num_tokens,
+                num_experts,
+                num_out_tokens,
+                hidden_size,
+            )
+            # The backward kernel only writes to positions that tokens map to.
+            # Padded positions may contain uninitialized (NaN) values - replace with zeros.
+            inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
+        else:
+            inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs(
+                output_grad,
+                row_id_map,
+                fwd_input,
+                merging_probs,
+                num_tokens,
+                num_experts,
+                num_out_tokens,
+                hidden_size,
+            )
     else:
         # Simple case: just permute gradients back
-        inp_grad, _ = permute_with_mask_map(
-            output_grad,
-            row_id_map,
-            None,
-            num_tokens,
-            num_experts,
-            num_out_tokens,
-            hidden_size,
-        )
+        if pad_offsets is not None:
+            inp_grad, _ = permute_with_mask_map_and_pad(
+                output_grad,
+                row_id_map,
+                None,
+                pad_offsets,
+                num_tokens,
+                num_experts,
+                num_out_tokens,
+                hidden_size,
+            )
+            # The permute kernel only writes to positions that tokens map to.
+            # Padded positions may contain uninitialized (NaN) values - replace with zeros.
+            inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad)
+        else:
+            inp_grad, _ = permute_with_mask_map(
+                output_grad,
+                row_id_map,
+                None,
+                num_tokens,
+                num_experts,
+                num_out_tokens,
+                hidden_size,
+            )
         merging_probs_grad = None

-    return inp_grad, merging_probs_grad
+    # Return gradients for: inp, row_id_map, merging_probs, pad_offsets
+    # row_id_map and pad_offsets are integer arrays, so their gradients are None
+    return inp_grad, None, merging_probs_grad, None


 _token_combine.defvjp(_token_combine_fwd_rule, _token_combine_bwd_rule)
......
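
The backward rules above scrub positions the kernels never wrote (padded rows) by replacing NaNs with zeros before returning the gradient. A dependency-free stand-in for the `jnp.where(jnp.isnan(...), 0.0, ...)` pattern (helper name illustrative):

```python
import math

# Padded gradient rows may hold uninitialized values (modeled here as NaN);
# the backward rule replaces them with zeros, mirroring the jnp.where scrub.
def scrub_nans(row):
    return [0.0 if math.isnan(x) else x for x in row]
```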
@@ -142,17 +142,31 @@ def compile_triton(
     )

     # Create kernel object for JAX
-    kernel = gpu_triton.TritonKernel(
-        compiled.name,
-        num_warps,
-        compiled.metadata.shared,
-        compiled.asm["ptx"],
-        "",  # ttir
-        compute_capability,
-        1,
-        1,
-        1,  # cluster_dims
-    )
+    # From jax/jaxlib/gpu/triton_kernels.cc:
+    from packaging import version
+
+    if version.parse(jax.__version__) >= version.parse("0.8.2"):
+        kernel = gpu_triton.TritonKernel(
+            compiled.name,  # arg0: kernel_name (str)
+            num_warps,  # arg1: num_warps (int)
+            num_ctas,  # arg2: num_ctas (int)
+            compiled.metadata.shared,  # arg3: shared_mem_bytes (int)
+            compiled.asm["ptx"],  # arg4: ptx (str)
+            "",  # arg5: ttir (str) - empty
+            compute_capability,  # arg6: compute_capability (int)
+        )
+    else:
+        kernel = gpu_triton.TritonKernel(
+            compiled.name,
+            num_warps,
+            compiled.metadata.shared,
+            compiled.asm["ptx"],
+            "",  # ttir
+            compute_capability,
+            1,
+            1,
+            1,  # cluster_dims
+        )
     _TRITON_KERNEL_CACHE[cache_key] = kernel
     return kernel
......
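
The diff above gates the `TritonKernel` constructor arguments on the installed JAX version. A dependency-free sketch of that comparison (a stand-in for `packaging.version.parse`, which the real code uses; it handles plain "X.Y.Z" strings only):

```python
def version_tuple(v: str):
    # Minimal stand-in for packaging.version.parse: split on dots and
    # compare numerically, so "0.10.0" correctly sorts after "0.8.2".
    return tuple(int(part) for part in v.split("."))

def use_new_triton_kernel_signature(jax_version: str) -> bool:
    # JAX >= 0.8.2 expects (name, num_warps, num_ctas, shared, ptx, ttir, cc);
    # older versions take extra cluster-dims integers instead.
    return version_tuple(jax_version) >= version_tuple("0.8.2")
```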
@@ -34,6 +34,7 @@ from transformer_engine.pytorch.transformer import TransformerLayer
 from transformer_engine.pytorch.permutation import (
     moe_permute,
     moe_permute_with_probs,
+    moe_permute_and_pad_with_probs,
     moe_unpermute,
     moe_sort_chunks_by_index,
     moe_sort_chunks_by_index_with_probs,
......
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
-"""MoE Permutaion API"""
+"""MoE Permutation API"""
 import warnings
 from typing import Optional, Tuple
 import torch
...@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function): ...@@ -191,6 +191,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
routing_map: torch.Tensor, routing_map: torch.Tensor,
num_out_tokens: int, num_out_tokens: int,
probs: torch.Tensor, probs: torch.Tensor,
pad_offsets: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if not inp.numel(): if not inp.numel():
...@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function): ...@@ -201,6 +202,8 @@ class _moe_permute_mask_map(torch.autograd.Function):
assert routing_map.is_cuda, "TransformerEngine needs CUDA." assert routing_map.is_cuda, "TransformerEngine needs CUDA."
if probs is not None: if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA." assert probs.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert inp.size(0) == routing_map.size(0), "Permute not possible" assert inp.size(0) == routing_map.size(0), "Permute not possible"
num_tokens, hidden_size = inp.size() num_tokens, hidden_size = inp.size()
...@@ -250,6 +253,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
probs,
fp8_scale,
pad_offsets,
num_tokens,
num_experts,
num_out_tokens,
...@@ -292,7 +296,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
...@@ -307,12 +311,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, ctx.probs, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
row_id_map, pad_offsets = ctx.saved_tensors
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
...@@ -321,13 +325,14 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map,
None,
permuted_probs_grad,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad, None
class _moe_unpermute_mask_map(torch.autograd.Function):
...@@ -340,6 +345,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map: torch.Tensor,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
pad_offsets: Optional[torch.Tensor],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
...@@ -358,6 +364,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
if pad_offsets is not None:
assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
assert not isinstance(
inp, QuantizedTensor
...@@ -367,15 +375,16 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
merging_probs,
None,
pad_offsets,
num_tokens,
num_experts,
hidden_size,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
else:
ctx.save_for_backward(row_id_map, pad_offsets)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.num_permuted_tokens = inp.size(0)
...@@ -387,15 +396,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def backward(ctx, unpermuted_act_grad):
# pylint: disable=missing-function-docstring
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.merging_probs, None, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
if ctx.with_probs:
fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
else:
row_id_map, pad_offsets = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
...@@ -441,6 +450,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
fwd_input,
merging_probs,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
...@@ -453,6 +463,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
row_id_map,
None,
fp8_scale,
pad_offsets,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
...@@ -497,7 +508,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if not ctx.needs_input_grad[2]:
probs_grad = None
return act_grad, None, probs_grad, None, None
def moe_permute(
...@@ -537,7 +548,9 @@ def moe_permute(
if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
output, row_id_map, _ = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, None, None
)
return output, row_id_map
raise ValueError("map_type should be one of 'mask' or 'index'")
...@@ -570,11 +583,67 @@ def moe_permute_with_probs(
By default, set to '-1', meaning no tokens are dropped.
"""
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, probs, None
)
return output, permuted_probs, row_id_map
def moe_permute_and_pad_with_probs(
inp: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
tokens_per_expert: torch.Tensor,
align_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
"""
Permute the tokens and probs based on the routing_map, padding each
expert's tokens to a multiple of align_size. Tokens routed to the same
expert are grouped together.
The routing_map indicates which experts were selected by each token.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
Probabilities tensor of shape `[num_tokens, num_experts]`. It will be
permuted along with the tokens according to the routing_map.
routing_map: torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
A value of 1 means the token is routed to this expert; 0 means it is not.
tokens_per_expert : torch.Tensor
Tensor of shape `[num_experts]` containing actual token counts per expert.
align_size : int
The alignment size; each expert's token count is padded up to a
multiple of this value.

Returns
-------
A tuple of (permuted tokens, permuted probs, row_id_map, pad_offsets,
padded tokens per expert); pad_offsets is None when no padding is needed.
"""
assert (
tokens_per_expert is not None
), "tokens_per_expert must be provided to the fused permute padding function."
assert align_size > 0, f"align_size must be positive, got {align_size}"
# Ensure tokens_per_expert is on the same device as input to avoid device transfers
if tokens_per_expert.device != inp.device:
tokens_per_expert = tokens_per_expert.to(inp.device)
# Calculate aligned token counts per expert
target_tokens_per_expert = (torch.ceil(tokens_per_expert / align_size) * align_size).long()
if torch.equal(tokens_per_expert, target_tokens_per_expert):
pad_offsets = None
else:
pad_lengths = target_tokens_per_expert - tokens_per_expert
cum_pad = torch.cumsum(pad_lengths, dim=0)
pad_offsets = torch.cat(
[torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
)
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
)
return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
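The alignment arithmetic above (round each expert's token count up to a multiple of `align_size`, then take an exclusive prefix sum of the per-expert padding lengths) can be sketched in plain Python. This is a minimal emulation of the tensor ops on lists, not the actual CUDA path; `compute_pad_offsets` is a hypothetical helper name:

```python
import math
from itertools import accumulate

def compute_pad_offsets(tokens_per_expert, align_size):
    # Round each expert's token count up to a multiple of align_size.
    target = [math.ceil(n / align_size) * align_size for n in tokens_per_expert]
    if target == list(tokens_per_expert):
        return None, target  # already aligned: no padding, pad_offsets is None
    pad_lengths = [t - n for t, n in zip(target, tokens_per_expert)]
    # Exclusive prefix sum: expert e's rows shift down by the padding
    # inserted for experts 0..e-1.
    cum_pad = list(accumulate(pad_lengths))
    return [0] + cum_pad[:-1], target

# Expert 0 holds 5 tokens and needs 3 pad rows to reach a multiple of 4;
# experts 1 and 2 are already aligned.
pad_offsets, target = compute_pad_offsets([5, 8, 4], align_size=4)
# pad_offsets -> [0, 3, 3], target -> [8, 8, 4]
```

The aligned total `sum(target)` is what the real function passes as `num_out_tokens` to the permute autograd function.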
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
...@@ -582,6 +651,7 @@ def moe_unpermute(
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: Optional[torch.Tensor] = None,
pad_offsets: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
...@@ -605,6 +675,10 @@ def moe_unpermute(
Options are: 'mask', 'index'.
probs : torch.Tensor, default = None
Renamed to merging_probs. Kept for backward compatibility.
pad_offsets : torch.Tensor, default = None
Tensor of per-expert cumulative padding offsets used to remove padding added
during permutation. This is the fourth output of `moe_permute_and_pad_with_probs`
and is required when unpermuting padded outputs.
""" """
if probs is not None: if probs is not None:
if merging_probs is not None: if merging_probs is not None:
...@@ -616,7 +690,9 @@ def moe_unpermute( ...@@ -616,7 +690,9 @@ def moe_unpermute(
if map_type == "index": if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
if map_type == "mask": if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) return _moe_unpermute_mask_map.apply(
inp, row_id_map, merging_probs, restore_shape, pad_offsets
)
raise ValueError("map_type should be one of 'mask' or 'index'") raise ValueError("map_type should be one of 'mask' or 'index'")
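The fused padded round trip (permute + pad, then unpermute + unpad) can be emulated on plain Python lists to show what the row layout looks like; `permute_and_pad` and `unpermute` here are hypothetical list-based stand-ins for the CUDA kernels, with scalar "tokens" instead of hidden-size vectors:

```python
def permute_and_pad(tokens, routing_map, align_size):
    # Gather each expert's token copies, padding every expert group with
    # zero rows so its length is a multiple of align_size.
    rows, inverse = [], {}
    num_experts = len(routing_map[0])
    for expert in range(num_experts):
        for token, hot in enumerate(routing_map):
            if hot[expert]:
                inverse[(token, expert)] = len(rows)
                rows.append(tokens[token])
        while len(rows) % align_size:
            rows.append(0.0)  # padding row: allocated but never read back
    return rows, inverse

def unpermute(rows, inverse, num_tokens):
    # Sum each token's expert copies; padding rows are skipped entirely.
    out = [0.0] * num_tokens
    for (token, _expert), row in inverse.items():
        out[token] += rows[row]
    return out

tokens = [1.0, 2.0, 3.0]
routing_map = [[1, 0], [1, 1], [0, 1]]  # token -> selected experts
rows, inverse = permute_and_pad(tokens, routing_map, align_size=4)
# rows -> [1.0, 2.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0]
restored = unpermute(rows, inverse, num_tokens=3)  # token 1 summed twice
```

Because the unpermute side consults only the inverse map (the analogue of `row_id_map`), the padding rows drop out automatically, which is why no separate `Fp8Unpadding` pass is needed.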
...
...@@ -123,6 +123,7 @@ def permute_with_mask_map(
row_id_map: torch.Tensor,
probs: torch.Tensor,
scale: torch.Tensor,
pad_offsets: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
...@@ -142,6 +143,9 @@ def permute_with_mask_map(
The probabilities of the input tensor. If it is not None, it will be permuted.
scale : torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
If it is not None, output buffers are allocated with aligned (padded) sizes.
num_tokens : int
Number of tokens in the input tensor.
num_experts : int
...@@ -153,18 +157,18 @@ def permute_with_mask_map(
scale_hidden_dim : int
Hidden size of the scale tensor.
"""
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed,
# since the kernel doesn't write to padding positions.
alloc = torch.zeros if pad_offsets is not None else torch.empty
output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
permuted_probs = (
alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
)
permuted_scale = (
torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
if scale is not None
else None
)
# pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid](
...@@ -173,6 +177,7 @@ def permute_with_mask_map(
probs,
scale,
permuted_scale,
pad_offsets,
scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
...@@ -193,6 +198,7 @@ def permute_with_mask_map(
hidden_size,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
FUSION_PAD=pad_offsets is not None,
)
return output, permuted_scale, permuted_probs
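The zero-initialization choice above can be illustrated with a small scatter sketch: the kernel writes only valid rows, so padding rows keep whatever the allocator left behind unless the buffer starts zeroed. This is a toy model (uninitialized memory is stood in for by NaN sentinels), not the Triton kernel itself:

```python
def alloc_and_scatter(num_out_tokens, valid_rows, zero_init):
    # Emulate the output buffer: zero-filled when padding is fused,
    # otherwise uninitialized (modeled here as NaN sentinels).
    buf = [0.0 if zero_init else float("nan")] * num_out_tokens
    for row, value in valid_rows.items():  # the kernel writes valid rows only
        buf[row] = value
    return buf

valid = {0: 1.0, 1: 2.0}  # rows 2 and 3 are padding and are never written
padded = alloc_and_scatter(4, valid, zero_init=True)   # padding rows are 0.0
unsafe = alloc_and_scatter(4, valid, zero_init=False)  # padding rows are garbage
```

Zeroed padding matters here because the padded rows are fed straight into the FP8 expert GEMM, where garbage values would corrupt scaling factors and outputs.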
...@@ -202,6 +208,7 @@ def unpermute_with_mask_map(
row_id_map: torch.Tensor,
merging_probs: Union[torch.Tensor, None],
permuted_probs: Union[torch.Tensor, None],
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
hidden_size: int,
...@@ -220,6 +227,9 @@ def unpermute_with_mask_map(
to reduce the unpermuted tokens.
permuted_probs : torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused unpadding.
If it is not None, the padding added during permutation is removed.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
...@@ -241,6 +251,7 @@ def unpermute_with_mask_map(
row_id_map,
merging_probs,
permuted_probs,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
...@@ -259,6 +270,7 @@ def unpermute_with_mask_map(
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
FUSION_UNPAD=pad_offsets is not None,
)
return output, unpermuted_probs
...@@ -268,6 +280,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
row_id_map: torch.Tensor,
fwd_input: torch.Tensor,
merging_probs: torch.Tensor,
pad_offsets: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
num_out_tokens: int,
...@@ -286,6 +299,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs : torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
pad_offsets : torch.Tensor
Per-expert padding offsets of shape `[num_experts]` for FP8 fused padding.
If it is not None, the activation-gradient buffer is allocated with aligned (padded) sizes.
num_tokens : int
Number of tokens in the permuted tensor.
num_experts : int
...@@ -295,9 +311,11 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
hidden_size : int
Hidden size of the output tensor.
"""
# Use zeros when pad_offsets is used because padding slots won't be written to
# by the kernel. This matches the behavior of Fp8Unpadding.backward, which zeros
# out the padding slots.
alloc = torch.zeros if pad_offsets is not None else torch.empty
act_grad = alloc((num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda")
merging_probs_grad = torch.empty(
(num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
)
...@@ -307,6 +325,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_input,
merging_probs,
row_id_map,
pad_offsets,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0),
...@@ -324,6 +343,7 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
FUSION_UNPAD=pad_offsets is not None,
)
return act_grad, merging_probs_grad
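The backward computed by this kernel follows from the chain rule for `out[t] = sum_e p[t, e] * x[row(t, e)]`: the activation gradient routes each token's output gradient to its expert rows scaled by the merging prob, and the prob gradient is the dot product of the output gradient with the forward-input row. A scalar-hidden-size sketch (so the dot product degenerates to a product), using a hypothetical `inverse` map from (token, expert) to permuted row in place of `row_id_map`:

```python
def unpermute_bwd_with_merging_probs(grad_out, fwd_input, merging_probs,
                                     inverse, num_out_tokens, num_experts):
    # hidden_size == 1 here, so rows are scalars.
    act_grad = [0.0] * num_out_tokens  # zero-init: padding rows never touched
    probs_grad = [[0.0] * num_experts for _ in grad_out]
    for (token, expert), row in inverse.items():
        act_grad[row] = grad_out[token] * merging_probs[token][expert]   # d out / d x[row]
        probs_grad[token][expert] = grad_out[token] * fwd_input[row]     # d out / d p[t, e]
    return act_grad, probs_grad

# Token 0 was routed to experts 0 and 1 (rows 0 and 2); row 1 is padding.
act_grad, probs_grad = unpermute_bwd_with_merging_probs(
    grad_out=[2.0],
    fwd_input=[3.0, 0.0, 5.0],
    merging_probs=[[0.25, 0.75]],
    inverse={(0, 0): 0, (0, 1): 2},
    num_out_tokens=3,
    num_experts=2,
)
```

The padding row of `act_grad` stays zero precisely because of the zero-initialized allocation above.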
...