Unverified commit 3d46bf61, authored by Teddy Do, committed by GitHub

Permutation to always return group_size/tokens_per_expert (#2613)



Return tokens_per_expert always.
Signed-off-by: tdophung <tdophung@nvidia.com>
parent 8bf37f0e
@@ -52,7 +52,7 @@ def token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """
     Dispatch tokens to experts based on routing map.
@@ -101,9 +101,11 @@ def token_dispatch(
     pad_offsets : Optional[jnp.ndarray]
         Per-expert cumulative padding offsets of shape [num_experts] when using padding,
         None otherwise. Pass this to token_combine when unpadding is needed.
-    target_tokens_per_expert : Optional[jnp.ndarray]
-        Aligned token counts per expert of shape [num_experts] when using padding,
-        None otherwise.
+    tokens_per_expert : jnp.ndarray
+        Token counts per expert of shape [num_experts]:
+        - Without padding: actual token counts (sum of routing_map columns)
+        - With padding: aligned token counts (ceil(actual / align_size) * align_size)
+        This gives the effective number of tokens per expert in the output buffer.

     Note
     ----
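For readers skimming the docstring change, here is a minimal self-contained sketch of the two return modes it describes. The routing map values and align_size below are illustrative assumptions, not taken from this diff:

```python
import jax.numpy as jnp

# Hypothetical routing map: 4 tokens x 3 experts (1 = token routed to expert)
routing_map = jnp.array(
    [[1, 0, 1],
     [1, 0, 0],
     [0, 1, 0],
     [1, 0, 0]],
    dtype=jnp.int32,
)

# Without padding: actual counts, a column-wise sum over the token axis
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
print(tokens_per_expert)  # [3 1 1]

# With padding: each count rounded up to a multiple of align_size
align_size = 4  # illustrative value
aligned = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(jnp.int32)
print(aligned)  # [4 4 4]
```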
@@ -151,10 +153,10 @@ def _token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+    (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
         _token_dispatch_fwd_rule(
             inp,
             routing_map,
@@ -165,7 +167,7 @@ def _token_dispatch(
             use_padding,
         )
     )
-    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+    return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert


 def _token_dispatch_fwd_rule(
@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
         Optional[jnp.ndarray],
         jnp.ndarray,
         Optional[jnp.ndarray],
-        Optional[jnp.ndarray],
+        jnp.ndarray,
     ],
     Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
     with_probs = probs is not None

-    if use_padding:
-        # Compute tokens_per_expert internally from routing_map
-        # This can be a traced value since output shape uses worst_case_out_tokens
-        tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+    # Compute tokens_per_expert from routing_map (actual counts)
+    # This is well-optimized by XLA as a simple column-wise reduction
+    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)

+    if use_padding:
         # Calculate aligned token counts per expert
         target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
             jnp.int32
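The hoisted reduction above is branch-independent, so it can run once and feed both paths. A minimal sketch of that control flow as a standalone helper (the function name and the align_size default are assumptions for illustration):

```python
import jax.numpy as jnp

def effective_tokens_per_expert(routing_map, use_padding, align_size=16):
    # Counts are computed unconditionally, mirroring the hoisted code above
    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
    if use_padding:
        # Round each expert's count up to the next multiple of align_size
        return (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(jnp.int32)
    return tokens_per_expert
```

Because use_padding is a plain Python bool rather than a traced value, the branch is resolved at trace time, so both modes stay jit-compatible.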
@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
             hidden_size,
             align_size=align_size,
         )
+        # Return aligned counts when using padding
+        out_tokens_per_expert = target_tokens_per_expert
     else:
         # No padding
         pad_offsets = None
-        target_tokens_per_expert = None
         output, permuted_probs = permute_with_mask_map(
             inp,
@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
             hidden_size,
         )
+        # Return actual counts when not using padding
+        out_tokens_per_expert = tokens_per_expert

     # Return (primals, residuals)
+    # out_tokens_per_expert is:
+    #   - target_tokens_per_expert (aligned) when using padding
+    #   - tokens_per_expert (actual) when not using padding
     residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
     return (
         output,
         permuted_probs,
         row_id_map,
         pad_offsets,
-        target_tokens_per_expert,
+        out_tokens_per_expert,
     ), residuals
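Caller-side, the net effect is that the fifth output no longer needs a None check. A hedged pseudo-usage sketch (keyword names beyond those visible in the diff are assumptions, not the library's documented API):

```python
# token_dispatch now always returns a [num_experts] int32 count array:
# actual counts when use_padding=False, align_size-rounded counts otherwise.
output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert = token_dispatch(
    inp,          # [num_tokens, hidden_size] activations
    routing_map,  # [num_tokens, num_experts] routing mask
    probs=None,
    use_padding=True,
)
# Previously callers had to branch on target_tokens_per_expert being None;
# tokens_per_expert can now be consumed directly, e.g. to size per-expert work.
```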