Unverified Commit 3d46bf61 authored by Teddy Do, committed by GitHub

Permutation to always return group_size/tokens_per_expert (#2613)



return tokens_per_expert always
Signed-off-by: tdophung <tdophung@nvidia.com>
parent 8bf37f0e
@@ -52,7 +52,7 @@ def token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """
     Dispatch tokens to experts based on routing map.
@@ -101,9 +101,11 @@ def token_dispatch(
     pad_offsets : Optional[jnp.ndarray]
         Per-expert cumulative padding offsets of shape [num_experts] when using padding,
        None otherwise. Pass this to token_combine when unpadding is needed.
-    target_tokens_per_expert : Optional[jnp.ndarray]
-        Aligned token counts per expert of shape [num_experts] when using padding,
-        None otherwise.
+    tokens_per_expert : jnp.ndarray
+        Token counts per expert of shape [num_experts]:
+        - Without padding: actual token counts (sum of routing_map columns)
+        - With padding: aligned token counts (ceil(actual / align_size) * align_size)
+        This gives the effective number of tokens per expert in the output buffer.

     Note
     ----
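As a quick illustration of the tokens_per_expert semantics documented above (a sketch only; the [num_tokens, num_experts] routing_map layout matches the docstring, but the concrete values and align_size are made up for this example):

    import jax.numpy as jnp

    # 4 tokens routed across 3 experts; a 1 marks that a token is sent to that expert
    routing_map = jnp.array([
        [1, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 1],
    ])
    align_size = 4

    # Without padding: actual counts, i.e. the column sums of routing_map
    actual = jnp.sum(routing_map, axis=0).astype(jnp.int32)  # [3, 1, 2]

    # With padding: each count rounded up to the nearest multiple of align_size
    aligned = (jnp.ceil(actual / align_size) * align_size).astype(jnp.int32)  # [4, 4, 4]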
@@ -151,10 +153,10 @@ def _token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+    (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
         _token_dispatch_fwd_rule(
             inp,
             routing_map,
@@ -165,7 +167,7 @@ def _token_dispatch(
             use_padding,
         )
     )
-    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+    return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert


 def _token_dispatch_fwd_rule(
@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
         Optional[jnp.ndarray],
         jnp.ndarray,
         Optional[jnp.ndarray],
-        Optional[jnp.ndarray],
+        jnp.ndarray,
     ],
     Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
     with_probs = probs is not None

-    if use_padding:
-        # Compute tokens_per_expert internally from routing_map
-        # This can be a traced value since output shape uses worst_case_out_tokens
-        tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+    # Compute tokens_per_expert from routing_map (actual counts)
+    # This is well-optimized by XLA as a simple column-wise reduction
+    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+
+    if use_padding:
         # Calculate aligned token counts per expert
         target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
             jnp.int32
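One detail worth noting about the aligned-count formula in this hunk: for int32 counts, the float ceil round-trip is equivalent to pure integer arithmetic, which can be easier to audit. A minimal check, illustrative only and not part of the change:

    import jax.numpy as jnp

    tokens_per_expert = jnp.array([3, 1, 2, 8], dtype=jnp.int32)
    align_size = 4

    # The formula used in the diff: int -> float division -> ceil -> back to int
    via_ceil = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(jnp.int32)
    # Equivalent pure-integer rounding up to a multiple of align_size
    via_int = ((tokens_per_expert + align_size - 1) // align_size) * align_size

    assert bool((via_ceil == via_int).all())  # both give [4, 4, 4, 8]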
@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
             hidden_size,
             align_size=align_size,
         )
+        # Return aligned counts when using padding
+        out_tokens_per_expert = target_tokens_per_expert
     else:
         # No padding
         pad_offsets = None
         target_tokens_per_expert = None
         output, permuted_probs = permute_with_mask_map(
             inp,
@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
             hidden_size,
         )
+        # Return actual counts when not using padding
+        out_tokens_per_expert = tokens_per_expert

     # Return (primals, residuals)
+    # out_tokens_per_expert is:
+    #   - target_tokens_per_expert (aligned) when using padding
+    #   - tokens_per_expert (actual) when not using padding
     residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
     return (
         output,
         permuted_probs,
         row_id_map,
         pad_offsets,
-        target_tokens_per_expert,
+        out_tokens_per_expert,
     ), residuals
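The net effect at call sites is that the last return value no longer needs a None check. A hedged usage sketch (argument names are inferred from the hunks, run_experts is hypothetical, and the token_combine pairing follows the docstring above rather than the full file):

    output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert = token_dispatch(
        inp, routing_map, probs=probs, use_padding=True, align_size=16
    )

    # tokens_per_expert is now always a jnp.ndarray: aligned counts here
    # (use_padding=True), actual counts without padding -- no None branch needed
    expert_out = run_experts(output, tokens_per_expert)  # hypothetical expert computation

    # pad_offsets is still Optional; pass it to token_combine when unpadding is needed
    combined = token_combine(expert_out, row_id_map, pad_offsets=pad_offsets)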