Commit 5d112e3c (unverified) · authored by Teddy Do · committed by GitHub
Browse files

[JAX] TE Permutation integration to Maxtext (#2672)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more pieces that were missing from the cherry-pick of Jeremy's PR for inspecting
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix some tracing issues when integrating into MaxText
Signed-off-by: tdophung <tdophung@nvidia.com>

* Have sort_chunks_by_index handle the case where the input buffer is larger than the number of tokens
Signed-off-by: tdophung <tdophung@nvidia.com>

* remove unnecessary assert and comments
Signed-off-by: JAX Toolbox <jax@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove the changes from Jeremy's inspect-FFI PR
Signed-off-by: JAX Toolbox <jax@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* leave the amax file untouched; also change the comment on TE
Signed-off-by: JAX Toolbox <jax@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: JAX Toolbox <jax@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
parent f8449052
...@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel( ...@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32) ).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes) input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
# Compute total valid tokens and skip phantom/padding tokens.
# When the input buffer is larger than sum(split_sizes), tokens beyond
# the valid range should map to themselves (identity mapping) to avoid
# corrupting valid output positions.
total_valid_tokens = tl.sum(input_split_sizes)
input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask) input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
...@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel( ...@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
).to(tl.int32) ).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
# For tokens beyond the valid range (pid >= total_valid_tokens),
# use identity mapping to avoid corrupting valid data
dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)
tl.store(dst_rows_ptr + pid, dst_row) tl.store(dst_rows_ptr + pid, dst_row)
......
...@@ -581,7 +581,7 @@ def sort_chunks_by_index( ...@@ -581,7 +581,7 @@ def sort_chunks_by_index(
return _sort_chunks_by_index(inp, split_sizes, sorted_indices) return _sort_chunks_by_index(inp, split_sizes, sorted_indices)
@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) @jax.custom_vjp
def _sort_chunks_by_index( def _sort_chunks_by_index(
inp: jnp.ndarray, inp: jnp.ndarray,
split_sizes: jnp.ndarray, split_sizes: jnp.ndarray,
...@@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule( ...@@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray, inp: jnp.ndarray,
split_sizes: jnp.ndarray, split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray, sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]: ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index.""" """Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions # Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D" assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
...@@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule( ...@@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
) )
# Return (primals, residuals) # Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size) # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
return (output, row_id_map), residuals return (output, row_id_map), residuals
def _sort_chunks_by_index_bwd_rule( def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray, residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray], g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]: ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index.""" """Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
output_grad, _ = g output_grad, _ = g
# Backward: reverse the sort # Backward: reverse the sort
...@@ -642,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule( ...@@ -642,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule(
is_forward=False, is_forward=False,
) )
return (inp_grad,) # Return gradients for all inputs: (inp, split_sizes, sorted_indices)
# split_sizes and sorted_indices are integer arrays, so their gradients are zeros
split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)
return (inp_grad, split_sizes_grad, sorted_indices_grad)
_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment