Unverified Commit 0056b981 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

[PyTorch] Change arguments order in triton kernels to make jax-triton work (#2416)



* Change order of arguments to make jax-triton work
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

* make num_experts a tl.constexpr again
Signed-off-by: default avatartdophung <tdophung@nvidia.com>

---------
Signed-off-by: default avatartdophung <tdophung@nvidia.com>
parent f8cb598c
...@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr): ...@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr):
@triton.jit @triton.jit
def _row_id_map_pass_1_kernel( def _row_id_map_pass_1_kernel(
# pointers # input pointers
routing_map_ptr, routing_map_ptr,
row_id_map_ptr,
workspace_ptr,
# sizes # sizes
num_tokens, num_tokens,
# strides # strides
...@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel( ...@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel(
stride_routing_map_expert, stride_routing_map_expert,
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
# output pointers
row_id_map_ptr,
workspace_ptr,
# metas # metas
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
...@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel( ...@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
def _row_id_map_pass_3_kernel( def _row_id_map_pass_3_kernel(
# pointers # pointers
row_id_map_ptr, row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
# metas # metas
num_experts: tl.constexpr,
LOAD_SIZE: tl.constexpr, LOAD_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -194,17 +194,13 @@ def _row_id_map_pass_3_kernel( ...@@ -194,17 +194,13 @@ def _row_id_map_pass_3_kernel(
@triton.jit @triton.jit
def _permute_kernel( def _permute_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
scale_ptr, scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr, permuted_scale_ptr,
# sizes # sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim, scale_hidden_dim,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
...@@ -220,7 +216,12 @@ def _permute_kernel( ...@@ -220,7 +216,12 @@ def _permute_kernel(
stride_permuted_probs_token, stride_permuted_probs_token,
stride_permuted_scale_token, stride_permuted_scale_token,
stride_permuted_scale_hidden, stride_permuted_scale_hidden,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr, PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
...@@ -291,16 +292,11 @@ except RuntimeError: ...@@ -291,16 +292,11 @@ except RuntimeError:
@triton.jit @triton.jit
def _unpermute_kernel( def _unpermute_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
merging_probs_ptr, merging_probs_ptr,
permuted_probs_ptr, permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
...@@ -313,7 +309,12 @@ def _unpermute_kernel( ...@@ -313,7 +309,12 @@ def _unpermute_kernel(
stride_permuted_probs_token, stride_permuted_probs_token,
stride_unpermuted_probs_token, stride_unpermuted_probs_token,
stride_unpermuted_probs_expert, stride_unpermuted_probs_expert,
# output pointers
output_ptr,
unpermuted_probs_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
...@@ -546,14 +547,10 @@ def _make_chunk_sort_map_kernel( ...@@ -546,14 +547,10 @@ def _make_chunk_sort_map_kernel(
@triton.jit @triton.jit
def _sort_chunks_by_map_kernel( def _sort_chunks_by_map_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size: tl.constexpr,
# strides # strides
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
...@@ -561,7 +558,11 @@ def _sort_chunks_by_map_kernel( ...@@ -561,7 +558,11 @@ def _sort_chunks_by_map_kernel(
stride_output_hidden, stride_output_hidden,
stride_probs_token, stride_probs_token,
stride_permuted_probs_token, stride_permuted_probs_token,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas # metas
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr, FORWARD: tl.constexpr,
......
...@@ -72,13 +72,13 @@ def make_row_id_map( ...@@ -72,13 +72,13 @@ def make_row_id_map(
# [0, 0, 0, r, r, r, r]] # [0, 0, 0, r, r, r, r]]
_row_id_map_pass_1_kernel[grid]( _row_id_map_pass_1_kernel[grid](
routing_map, routing_map,
row_id_map,
workspace_tensor,
num_tokens, num_tokens,
routing_map.stride(0), routing_map.stride(0),
routing_map.stride(1), routing_map.stride(1),
row_id_map.stride(0), row_id_map.stride(0),
row_id_map.stride(1), row_id_map.stride(1),
row_id_map,
workspace_tensor,
block_size, block_size,
) )
...@@ -110,9 +110,9 @@ def make_row_id_map( ...@@ -110,9 +110,9 @@ def make_row_id_map(
grid = (num_tokens,) grid = (num_tokens,)
_row_id_map_pass_3_kernel[grid]( _row_id_map_pass_3_kernel[grid](
row_id_map, row_id_map,
num_experts,
row_id_map.stride(0), row_id_map.stride(0),
row_id_map.stride(1), row_id_map.stride(1),
num_experts,
triton.next_power_of_2(num_experts), triton.next_power_of_2(num_experts),
) )
return row_id_map return row_id_map
...@@ -169,14 +169,10 @@ def permute_with_mask_map( ...@@ -169,14 +169,10 @@ def permute_with_mask_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid]( _permute_kernel[grid](
inp, inp,
output,
row_id_map, row_id_map,
probs, probs,
scale, scale,
permuted_probs,
permuted_scale, permuted_scale,
num_experts,
hidden_size,
scale_hidden_dim, scale_hidden_dim,
row_id_map.stride(0), row_id_map.stride(0),
row_id_map.stride(1), row_id_map.stride(1),
...@@ -191,6 +187,10 @@ def permute_with_mask_map( ...@@ -191,6 +187,10 @@ def permute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None, permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None, permuted_scale.stride(1) if permuted_scale is not None else None,
output,
permuted_probs,
num_experts,
hidden_size,
PERMUTE_PROBS=probs is not None, PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None, PERMUTE_SCALE=scale is not None,
) )
...@@ -238,13 +238,9 @@ def unpermute_with_mask_map( ...@@ -238,13 +238,9 @@ def unpermute_with_mask_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_unpermute_kernel[grid]( _unpermute_kernel[grid](
inp, inp,
output,
row_id_map, row_id_map,
merging_probs, merging_probs,
permuted_probs, permuted_probs,
unpermuted_probs,
num_experts,
hidden_size,
row_id_map.stride(0), row_id_map.stride(0),
row_id_map.stride(1), row_id_map.stride(1),
inp.stride(0), inp.stride(0),
...@@ -256,6 +252,10 @@ def unpermute_with_mask_map( ...@@ -256,6 +252,10 @@ def unpermute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
unpermuted_probs.stride(0) if unpermuted_probs is not None else None, unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
unpermuted_probs.stride(1) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
output,
unpermuted_probs,
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None, WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None, PERMUTE_PROBS=permuted_probs is not None,
...@@ -395,17 +395,17 @@ def sort_chunks_by_map( ...@@ -395,17 +395,17 @@ def sort_chunks_by_map(
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"])) grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_sort_chunks_by_map_kernel[grid]( _sort_chunks_by_map_kernel[grid](
inp, inp,
output,
row_id_map, row_id_map,
probs, probs,
permuted_probs,
hidden_size,
inp.stride(0), inp.stride(0),
inp.stride(1), inp.stride(1),
output.stride(0), output.stride(0),
output.stride(1), output.stride(1),
probs.stride(0) if probs is not None else None, probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
output,
permuted_probs,
hidden_size,
PERMUTE_PROBS=probs is not None, PERMUTE_PROBS=probs is not None,
FORWARD=is_forward, FORWARD=is_forward,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment