Unverified Commit a652730f authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

[JAX] Custom partitioning for Permutation primitives (#2591)



* initial impl, not tested
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* consolidate different unpermute primitives with with_pad and with_merging_probs booleans. Implement partitioning for all permutation primitives
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Add distributed test for non-padding permutation
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix issues in distributed test for padding permutation. Make common kernel zero-initialize output permuted scales, permuted probs and output tokens
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert zeroing in triton common kernel as it is a race condition. Instead, add extra input (aliased with output) buffer to inner primitive of permutation on jax side to pass in zero-initialized buffers created with jnp.zeros
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* fix utils to handle input output aliasing in autotuned kernels
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* Clean up comments, and add more comments explaining input output alias in utils
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and greptile comment
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix issues that lint fixing introduced
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6a34b657
This diff is collapsed.
......@@ -201,8 +201,15 @@ def _permute_kernel(
scale_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# Pre-allocated output buffers for JAX input_output_aliases.
# These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
# In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
output_buf_ptr, # pylint: disable=unused-argument
permuted_probs_buf_ptr, # pylint: disable=unused-argument
# sizes
scale_hidden_dim,
num_tokens, # pylint: disable=unused-argument
num_out_tokens, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -228,12 +235,17 @@ def _permute_kernel(
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller
# to ensure padding positions contain zeros.
# PyTorch: Use torch.zeros() for output buffer allocation
# JAX: Pre-zeroed buffers should be passed (when input_output_aliases works)
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
......@@ -306,6 +318,10 @@ def _unpermute_kernel(
merging_probs_ptr,
permuted_probs_ptr,
pad_offsets_ptr,
# Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern)
# These are unused in the unpermute kernel but maintain consistency with the permute kernel.
output_buf_ptr, # pylint: disable=unused-argument
unpermuted_probs_buf_ptr, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......
......@@ -137,7 +137,7 @@ def token_dispatch(
)
@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _token_dispatch(
inp: jnp.ndarray,
routing_map: jnp.ndarray,
......@@ -240,6 +240,7 @@ def _token_dispatch_fwd_rule(
num_experts,
worst_case_out_tokens,
hidden_size,
align_size=align_size,
)
else:
# No padding
......@@ -268,7 +269,6 @@ def _token_dispatch_fwd_rule(
def _token_dispatch_bwd_rule(
_routing_map: jnp.ndarray,
_num_out_tokens: int,
_worst_case_out_tokens: int,
_align_size: Optional[int],
......@@ -281,8 +281,12 @@ def _token_dispatch_bwd_rule(
Optional[jnp.ndarray],
Optional[jnp.ndarray],
],
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch."""
) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
"""Backward pass rule for token_dispatch.
Returns gradients for (inp, routing_map, probs).
routing_map gradient is None since it's a discrete routing decision.
"""
row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
output_grad, permuted_probs_grad, _, _, _ = g # Ignore row_id_map, pad_offsets, target grads
......@@ -309,7 +313,9 @@ def _token_dispatch_bwd_rule(
hidden_size,
)
return inp_grad, probs_grad if with_probs else None
# Return gradients for (inp, routing_map, probs)
# routing_map is non-differentiable (discrete routing), so return None
return inp_grad, None, probs_grad if with_probs else None
_token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
......@@ -497,6 +503,8 @@ def _token_combine_bwd_rule(
else:
# Simple case: just permute gradients back
if pad_offsets is not None:
# Note: align_size uses default (128) since buffer sizes are already
# determined from forward pass (stored in residuals as num_out_tokens)
inp_grad, _ = permute_with_mask_map_and_pad(
output_grad,
row_id_map,
......@@ -506,6 +514,7 @@ def _token_combine_bwd_rule(
num_experts,
num_out_tokens,
hidden_size,
align_size=128, # Default, sizes already computed in forward
)
# The permute kernel only writes to positions that tokens map to.
# Padded positions may contain uninitialized (NaN) values - replace with zeros.
......
......@@ -409,7 +409,8 @@ def triton_call_lowering(
kernel_constexprs = constexprs if constexprs is not None else {}
# Handle autotuned kernels - compile all configs
if isinstance(kernel_fn, autotuner.Autotuner):
is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
if is_autotuned:
# Compile all configs for runtime selection
kernel_calls = []
actual_kernel_fn = kernel_fn.fn
......@@ -450,24 +451,23 @@ def triton_call_lowering(
kernel_calls.append((config_call, str(config)))
# Create autotuned kernel call
# Convert input_output_aliases to format with sizes
if input_output_aliases is None:
input_output_aliases = {}
input_output_aliases_with_sizes = tuple(
(
input_idx,
output_idx,
ctx.avals_in[input_idx].size * ctx.avals_in[input_idx].dtype.itemsize,
)
for input_idx, output_idx in input_output_aliases.items()
)
# IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
#
# Background:
# 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
# output can reuse an input's buffer. XLA may or may not honor this.
# 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
# save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
#
# The problem: The save phase (triton_kernels.cc:632) only saves if buffers[input_idx] == buffers[output_idx],
# but the restore phase (triton_kernels.cc:697-700) unconditionally iterates over all aliases and tries
# to access input_copies[input_idx]. If XLA didn't actually alias the buffers, input_copies[input_idx] doesn't exist, creating an empty vector whose .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the restore memcpy.
#
# WAR: Don't pass aliases to TritonAutotunedKernelCall.
kernel_call = gpu_triton.TritonAutotunedKernelCall(
f"{actual_kernel_fn.__name__}_autotuned",
kernel_calls,
input_output_aliases_with_sizes,
(), # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
)
else:
......@@ -498,15 +498,17 @@ def triton_call_lowering(
serialized_metadata = b""
call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)
if input_output_aliases is None:
input_output_aliases = {}
if input_output_aliases:
ffi_operand_output_aliases = input_output_aliases
else:
ffi_operand_output_aliases = None
# Use JAX FFI lowering with compressed protobuf
rule = jax.ffi.ffi_lowering(
"triton_kernel_call", # Custom call target registered in gpu_triton.py
api_version=2,
backend_config=zlib.compress(call_proto),
operand_output_aliases=input_output_aliases,
operand_output_aliases=ffi_operand_output_aliases,
)
return rule(ctx, *array_args)
......@@ -157,8 +157,8 @@ def permute_with_mask_map(
scale_hidden_dim : int
Hidden size of the scale tensor.
"""
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed,
# since the kernel doesn't write to padding positions.
# Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed.
# The kernel writes only to valid positions, leaving padding positions at zero.
alloc = torch.zeros if pad_offsets is not None else torch.empty
output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
permuted_probs = (
......@@ -178,7 +178,13 @@ def permute_with_mask_map(
scale,
permuted_scale,
pad_offsets,
# Pass output buffers as input parameters (for JAX input_output_aliases compatibility).
# In PyTorch, these point to the same memory as the output pointers below.
output,
permuted_probs,
scale_hidden_dim,
num_tokens,
num_out_tokens,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
......@@ -252,6 +258,10 @@ def unpermute_with_mask_map(
merging_probs,
permuted_probs,
pad_offsets,
# Dummy buffer parameters for kernel signature consistency with _permute_kernel.
# These are unused in unpermute but maintain consistent interface.
output, # output_buf_ptr (unused, passed for signature consistency)
unpermuted_probs, # unpermuted_probs_buf_ptr (unused, passed for signature consistency)
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment