Unverified commit 08dc786c, authored by Teddy Do, committed by GitHub
Browse files

Fix 50% comparison mismatch in sort_chunks_by_index (Cont.) (#2575)



* force initialization to int32
Signed-off-by: tdophung <tdophung@nvidia.com>

* address greptile comment
Signed-off-by: tdophung <tdophung@nvidia.com>

* del useless comments, add more restriction to int32
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
parent de51c96b
......@@ -19,11 +19,6 @@ from transformer_engine.jax.permutation import (
from utils import assert_allclose, pytest_parametrize_wrapper
-# =============================================================================
-# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
-# =============================================================================
-# All dispatch/combine test cases
ALL_DISPATCH_COMBINE_CASES = [
(128, 5, 128, 3),
(1024, 8, 128, 8),
......@@ -35,7 +30,6 @@ DISPATCH_COMBINE_CASES = {
"L2": ALL_DISPATCH_COMBINE_CASES,
}
-# All sort chunks test cases
ALL_SORT_CHUNKS_CASES = [
(8, 4096, 1280),
(64, 4096, 4096),
......@@ -46,7 +40,6 @@ SORT_CHUNKS_CASES = {
"L2": ALL_SORT_CHUNKS_CASES,
}
-# All dispatch/combine with padding test cases
ALL_DISPATCH_COMBINE_PADDING_CASES = [
(128, 5, 128, 3, 8),
(1024, 8, 128, 8, 16),
......@@ -58,14 +51,12 @@ DISPATCH_COMBINE_PADDING_CASES = {
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
}
-# Dtypes for testing
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
DTYPES = {
"L0": ALL_DTYPES,
"L2": ALL_DTYPES,
}
-# With probs options
ALL_WITH_PROBS = [True, False]
WITH_PROBS = {
"L0": [True],
......@@ -389,7 +380,7 @@ def reference_make_chunk_sort_map(
# For each source chunk, compute its destination offset
# inverse_indices[i] = position of chunk i in sorted order
-inverse_indices = jnp.argsort(sorted_indices)
+inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32)
dest_offsets = dest_cumsum[inverse_indices]
# Create row_id_map: for each token position, compute its destination
......@@ -397,7 +388,7 @@ def reference_make_chunk_sort_map(
position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
# chunk_ids[i] = which chunk position i belongs to
-chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
+chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32)
# within_chunk_offset[i] = position i's offset within its chunk
within_chunk_offset = position_indices - src_cumsum[chunk_ids]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment