Unverified Commit a652730f authored by Teddy Do, committed by GitHub

[JAX] Custom partitioning for Permutation primitives (#2591)



* initial impl, not tested
Signed-off-by: tdophung <tdophung@nvidia.com>

* consolidate the different unpermute primitives using with_pad and with_merging_probs booleans; implement partitioning for all permutation primitives
Signed-off-by: tdophung <tdophung@nvidia.com>
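
As context for the consolidated primitive, a minimal pure-JAX reference of what unpermute with merging probs computes (a sketch under an assumed row_id_map convention; the real primitives are Triton kernels):

    import jax.numpy as jnp

    def unpermute_reference(permuted, row_id_map, merging_probs=None):
        # Assumed convention: row_id_map[t, e] is the permuted row for token t
        # under expert e, or -1 if token t is not routed to expert e.
        valid = row_id_map >= 0                  # [tokens, experts]
        rows = jnp.where(valid, row_id_map, 0)   # safe gather indices
        gathered = permuted[rows]                # [tokens, experts, hidden]
        if merging_probs is not None:            # with_merging_probs=True path
            weights = jnp.where(valid, merging_probs, 0.0)
        else:                                    # plain unpermute: sum valid rows
            weights = valid.astype(permuted.dtype)
        return jnp.einsum("teh,te->th", gathered, weights)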

* Add distributed test for non-padding permutation
Signed-off-by: tdophung <tdophung@nvidia.com>

* fix issues in the distributed test for padding permutation; make the common kernel zero-initialize the output permuted scales, permuted probs, and output tokens
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert zeroing in the Triton common kernel, as it is a race condition; instead, add an extra input buffer (aliased with the output) to the inner permutation primitive on the JAX side, so zero-initialized buffers created with jnp.zeros can be passed in
Signed-off-by: tdophung <tdophung@nvidia.com>
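
A minimal sketch of that buffer pattern in pure JAX (the scatter stands in for the Triton kernel; the actual aliasing with the output happens via input_output_aliases in the primitive's lowering):

    import jax.numpy as jnp

    def permute_into_prezeroed(inp, dest_rows, num_out_tokens):
        # dest_rows[t] = destination row of token t, or -1 if dropped (assumed layout).
        out_buf = jnp.zeros((num_out_tokens, inp.shape[1]), inp.dtype)
        safe = jnp.where(dest_rows >= 0, dest_rows, num_out_tokens)  # out of bounds if dropped
        # mode="drop" skips out-of-bounds writes, so untouched (padding) rows keep
        # their zeros, mirroring the kernel writing only to mapped positions.
        return out_buf.at[safe].set(inp, mode="drop")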

* fix utils to handle input/output aliasing in autotuned kernels
Signed-off-by: tdophung <tdophung@nvidia.com>

* Clean up comments and add more comments explaining input/output aliasing in utils
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and address a Greptile review comment
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix issues introduced by the lint fixes
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6a34b657
@@ -201,8 +201,15 @@ def _permute_kernel(
     scale_ptr,
     permuted_scale_ptr,
     pad_offsets_ptr,
+    # Pre-allocated output buffers for JAX input_output_aliases.
+    # These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
+    # In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
+    output_buf_ptr,  # pylint: disable=unused-argument
+    permuted_probs_buf_ptr,  # pylint: disable=unused-argument
     # sizes
     scale_hidden_dim,
+    num_tokens,  # pylint: disable=unused-argument
+    num_out_tokens,  # pylint: disable=unused-argument
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -228,12 +235,17 @@ def _permute_kernel(
     FUSION_PAD: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
+    # Note: when FUSION_PAD=True, output buffers should be pre-zeroed by the caller
+    # to ensure padding positions contain zeros.
+    # PyTorch: use torch.zeros() for output buffer allocation.
+    # JAX: pre-zeroed buffers should be passed in (when input_output_aliases works).
     expert_idx = 0
     pid_t = tl.program_id(0)
     pid_h = tl.program_id(1)
     cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = cur_off < hidden_size
     src_row = pid_t.to(tl.int64)
     input_off = src_row * stride_input_token + cur_off * stride_input_hidden
     inp = tl.load(input_ptr + input_off, mask=mask)
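
A quick illustration of why that pre-zeroing matters (a sketch in pure JAX; jnp.full with NaN stands in for an uninitialized torch.empty buffer):

    import jax.numpy as jnp

    hidden = 4
    buf_uninitialized = jnp.full((6, hidden), jnp.nan)  # stand-in for torch.empty garbage
    buf_zeroed = jnp.zeros((6, hidden))
    valid_rows = jnp.array([0, 2, 3])  # rows the kernel actually writes
    vals = jnp.ones((3, hidden))

    print(buf_uninitialized.at[valid_rows].set(vals))  # NaNs remain in padded rows 1, 4, 5
    print(buf_zeroed.at[valid_rows].set(vals))         # padded rows stay zero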
@@ -306,6 +318,10 @@ def _unpermute_kernel(
     merging_probs_ptr,
     permuted_probs_ptr,
     pad_offsets_ptr,
+    # Dummy parameters for JAX input_output_aliases compatibility (matches the _permute_kernel signature pattern).
+    # These are unused in the unpermute kernel but maintain consistency with the permute kernel.
+    output_buf_ptr,  # pylint: disable=unused-argument
+    unpermuted_probs_buf_ptr,  # pylint: disable=unused-argument
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
......
@@ -137,7 +137,7 @@ def token_dispatch(
     )


-@partial(jax.custom_vjp, nondiff_argnums=(1, 3, 4, 5, 6))
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
 def _token_dispatch(
     inp: jnp.ndarray,
     routing_map: jnp.ndarray,
@@ -240,6 +240,7 @@ def _token_dispatch_fwd_rule(
             num_experts,
             worst_case_out_tokens,
             hidden_size,
+            align_size=align_size,
         )
     else:
         # No padding
@@ -268,7 +269,6 @@


 def _token_dispatch_bwd_rule(
-    _routing_map: jnp.ndarray,
     _num_out_tokens: int,
     _worst_case_out_tokens: int,
     _align_size: Optional[int],
@@ -281,8 +281,12 @@ def _token_dispatch_bwd_rule(
         Optional[jnp.ndarray],
         Optional[jnp.ndarray],
     ],
-) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
-    """Backward pass rule for token_dispatch."""
+) -> Tuple[jnp.ndarray, None, Optional[jnp.ndarray]]:
+    """Backward pass rule for token_dispatch.
+
+    Returns gradients for (inp, routing_map, probs).
+    routing_map gradient is None since it's a discrete routing decision.
+    """
     row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs = residuals
     output_grad, permuted_probs_grad, _, _, _ = g  # Ignore row_id_map, pad_offsets, target grads
@@ -309,7 +313,9 @@ def _token_dispatch_bwd_rule(
         hidden_size,
     )

-    return inp_grad, probs_grad if with_probs else None
+    # Return gradients for (inp, routing_map, probs)
+    # routing_map is non-differentiable (discrete routing), so return None
+    return inp_grad, None, probs_grad if with_probs else None


 _token_dispatch.defvjp(_token_dispatch_fwd_rule, _token_dispatch_bwd_rule)
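
The structural effect of this change, as a runnable sketch (signatures simplified and assumed; the essentials are routing_map leaving nondiff_argnums and the bwd rule returning None for it):

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.custom_vjp, nondiff_argnums=(2,))
    def dispatch(inp, routing_map, scale):
        del routing_map  # discrete decision, not differentiated
        return inp * scale

    def dispatch_fwd(inp, routing_map, scale):
        return dispatch(inp, routing_map, scale), None

    def dispatch_bwd(scale, _residuals, g):
        # One cotangent per differentiable positional arg: (inp, routing_map).
        # routing_map is integer-valued routing, so its cotangent is None.
        return g * scale, None

    dispatch.defvjp(dispatch_fwd, dispatch_bwd)

    x = jnp.ones(3)
    rmap = jnp.array([0, 1, 0])
    print(jax.grad(lambda a: dispatch(a, rmap, 2.0).sum())(x))  # [2. 2. 2.]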
@@ -497,6 +503,8 @@ def _token_combine_bwd_rule(
     else:
         # Simple case: just permute gradients back
         if pad_offsets is not None:
+            # Note: align_size uses the default (128) since buffer sizes are already
+            # determined from the forward pass (stored in residuals as num_out_tokens)
             inp_grad, _ = permute_with_mask_map_and_pad(
                 output_grad,
                 row_id_map,
@@ -506,6 +514,7 @@ def _token_combine_bwd_rule(
                 num_experts,
                 num_out_tokens,
                 hidden_size,
+                align_size=128,  # Default, sizes already computed in forward
             )
             # The permute kernel only writes to positions that tokens map to.
             # Padded positions may contain uninitialized (NaN) values - replace with zeros.
......
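
For context on align_size: padded layouts round each expert's token count up to a multiple of align_size. A sketch of that arithmetic (the repo's exact formula is not shown in this hunk):

    def round_up(n: int, align: int) -> int:
        # Round n up to the next multiple of align.
        return (n + align - 1) // align * align

    # e.g. 1000 routed tokens with align_size=128 need a 1024-row buffer
    assert round_up(1000, 128) == 1024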
@@ -409,7 +409,8 @@ def triton_call_lowering(
     kernel_constexprs = constexprs if constexprs is not None else {}

     # Handle autotuned kernels - compile all configs
-    if isinstance(kernel_fn, autotuner.Autotuner):
+    is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
+    if is_autotuned:
         # Compile all configs for runtime selection
         kernel_calls = []
         actual_kernel_fn = kernel_fn.fn
@@ -450,24 +451,23 @@ def triton_call_lowering(
             kernel_calls.append((config_call, str(config)))

-        # Create autotuned kernel call
-        # Convert input_output_aliases to format with sizes
-        if input_output_aliases is None:
-            input_output_aliases = {}
-
-        input_output_aliases_with_sizes = tuple(
-            (
-                input_idx,
-                output_idx,
-                ctx.avals_in[input_idx].size * ctx.avals_in[input_idx].dtype.itemsize,
-            )
-            for input_idx, output_idx in input_output_aliases.items()
-        )
-
+        # IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
+        #
+        # Background:
+        # 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
+        #    output can reuse an input's buffer. XLA may or may not honor this.
+        # 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
+        #    save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
+        #
+        # The problem: the save phase (triton_kernels.cc:632) only saves if
+        # buffers[input_idx] == buffers[output_idx], but the restore phase
+        # (triton_kernels.cc:697-700) unconditionally iterates over all aliases and
+        # tries to access input_copies[input_idx]. If XLA didn't actually alias the
+        # buffers, input_copies[input_idx] doesn't exist, creating an empty vector
+        # whose .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the
+        # restore memcpy.
+        #
+        # WAR: Don't pass aliases to TritonAutotunedKernelCall.
         kernel_call = gpu_triton.TritonAutotunedKernelCall(
             f"{actual_kernel_fn.__name__}_autotuned",
             kernel_calls,
-            input_output_aliases_with_sizes,
+            (),  # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
         )

     else:
@@ -498,15 +498,17 @@ def triton_call_lowering(
         serialized_metadata = b""

     call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)

-    if input_output_aliases is None:
-        input_output_aliases = {}
+    if input_output_aliases:
+        ffi_operand_output_aliases = input_output_aliases
+    else:
+        ffi_operand_output_aliases = None

     # Use JAX FFI lowering with compressed protobuf
     rule = jax.ffi.ffi_lowering(
         "triton_kernel_call",  # Custom call target registered in gpu_triton.py
         api_version=2,
         backend_config=zlib.compress(call_proto),
-        operand_output_aliases=input_output_aliases,
+        operand_output_aliases=ffi_operand_output_aliases,
     )
     return rule(ctx, *array_args)
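
For reference, the FFI-level alias hint mentioned above amounts to the following (constructing the lowering rule is side-effect-free; the operand/result indices are illustrative, not the repo's):

    import jax

    # operand_output_aliases={i: j} hints that operand i may share a buffer with
    # result j. XLA may or may not honor it, which is exactly why the autotuner
    # save/restore path cannot assume the aliasing actually happened.
    rule = jax.ffi.ffi_lowering(
        "triton_kernel_call",           # custom call target (registered elsewhere)
        operand_output_aliases={5: 0},  # hypothetical: operand 5 aliases result 0
    )
    # JAX invokes `rule` with a lowering context and operands, as in
    # `return rule(ctx, *array_args)` above.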
@@ -157,8 +157,8 @@ def permute_with_mask_map(
     scale_hidden_dim : int
         Hidden size of the scale tensor.
     """
-    # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed,
-    # since the kernel doesn't write to padding positions.
+    # Use torch.zeros when pad_offsets is provided to ensure padding regions are zeroed.
+    # The kernel writes only to valid positions, leaving padding positions at zero.
     alloc = torch.zeros if pad_offsets is not None else torch.empty
     output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
     permuted_probs = (
@@ -178,7 +178,13 @@ def permute_with_mask_map(
         scale,
         permuted_scale,
         pad_offsets,
+        # Pass output buffers as input parameters (for JAX input_output_aliases compatibility).
+        # In PyTorch, these point to the same memory as the output pointers below.
+        output,
+        permuted_probs,
         scale_hidden_dim,
+        num_tokens,
+        num_out_tokens,
         row_id_map.stride(0),
         row_id_map.stride(1),
         inp.stride(0),
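
The allocation policy above, isolated as a sketch (a hypothetical helper mirroring the hunk; requires a CUDA device to run):

    import torch

    def alloc_output(num_out_tokens: int, hidden_size: int, dtype, padded: bool):
        # Zero-fill only when padding is in play; zeroing costs an extra kernel launch.
        alloc = torch.zeros if padded else torch.empty
        return alloc((num_out_tokens, hidden_size), dtype=dtype, device="cuda")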
@@ -252,6 +258,10 @@ def unpermute_with_mask_map(
         merging_probs,
         permuted_probs,
         pad_offsets,
+        # Dummy buffer parameters for kernel signature consistency with _permute_kernel.
+        # These are unused in unpermute but maintain a consistent interface.
+        output,  # output_buf_ptr (unused, passed for signature consistency)
+        unpermuted_probs,  # unpermuted_probs_buf_ptr (unused, passed for signature consistency)
         row_id_map.stride(0),
         row_id_map.stride(1),
         inp.stride(0),
......