Unverified Commit 702fc5ee authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Fix 50% comparison mismatch in sort_chunks_by_index (#2566)



* force initialization to int32
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* address greptile comment
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
parent 404a3ee0
......@@ -97,7 +97,9 @@ def reference_make_row_id_map(
# Compute total tokens per expert and expert offsets
tokens_per_expert = jnp.sum(routing_map, axis=0)
expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
expert_offsets = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
)
# Compute destination rows for all (token, expert) pairs
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
......@@ -115,7 +117,9 @@ def reference_make_row_id_map(
# Gather the sorted destination rows and expert indices using advanced indexing
# Create indices for gathering
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
token_idx = jnp.broadcast_to(
jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
)
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
......@@ -373,11 +377,15 @@ def reference_make_chunk_sort_map(
Row ID map for chunk sorting of shape [num_tokens,].
"""
# Compute source chunk boundaries (cumulative sum of original split_sizes)
src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
src_cumsum = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
)
# Compute destination chunk boundaries based on sorted order
sorted_sizes = split_sizes[sorted_indices]
dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)])
dest_cumsum = jnp.concatenate(
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
)
# For each source chunk, compute its destination offset
# inverse_indices[i] = position of chunk i in sorted order
......@@ -386,7 +394,7 @@ def reference_make_chunk_sort_map(
# Create row_id_map: for each token position, compute its destination
# First, figure out which chunk each position belongs to
position_indices = jnp.arange(num_tokens)
position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
# chunk_ids[i] = which chunk position i belongs to
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment