Unverified Commit 0056b981 authored by Teddy Do, committed by GitHub

[PyTorch] Change argument order in triton kernels to make jax-triton work (#2416)



* Change order of arguments to make jax work
Signed-off-by: tdophung <tdophung@nvidia.com>

* make num_experts a tl.constexpr again
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
parent f8cb598c
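
Background for the reorder: jax-triton's jax_triton.triton_call passes the arrays and scalars it is given to the kernel positionally, in order, then appends the output buffers it allocates from out_shape after all of them, and forwards tl.constexpr parameters separately as keyword metaparameters. A kernel must therefore declare every input (including scalar sizes and strides) first, the output pointers after them, and the constexpr metas last. A minimal sketch of that convention (the kernel and wrapper below are illustrative examples, not code from this commit):

import jax
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def _double_kernel(
    # input pointers
    x_ptr,
    # sizes
    n_elements,
    # output pointers
    y_ptr,  # declared after all inputs: triton_call appends it last
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x * 2.0, mask=mask)


def double(x):
    # triton_call invokes the kernel as (x, n_elements, y) plus keyword metas,
    # so the signature above must put y_ptr after every input argument.
    return jt.triton_call(
        x,
        x.size,
        kernel=_double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(triton.cdiv(x.size, 128),),
        BLOCK_SIZE=128,
    )

With the old signatures, where output pointers sat between the inputs and the strides, the buffers appended by triton_call would have been bound to the wrong parameters.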
@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr):
 
 @triton.jit
 def _row_id_map_pass_1_kernel(
-    # pointers
+    # input pointers
     routing_map_ptr,
-    row_id_map_ptr,
-    workspace_ptr,
     # sizes
     num_tokens,
     # strides
@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel(
     stride_routing_map_expert,
     stride_row_id_map_token,
     stride_row_id_map_expert,
+    # output pointers
+    row_id_map_ptr,
+    workspace_ptr,
     # metas
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
 def _row_id_map_pass_3_kernel(
     # pointers
     row_id_map_ptr,
-    # sizes
-    num_experts: tl.constexpr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
     # metas
+    num_experts: tl.constexpr,
     LOAD_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(0)
@@ -194,17 +194,13 @@ def _row_id_map_pass_3_kernel(
 
 @triton.jit
 def _permute_kernel(
-    # pointers
+    # input pointers
     input_ptr,
-    output_ptr,
     row_id_map_ptr,
     probs_ptr,
     scale_ptr,
-    permuted_probs_ptr,
     permuted_scale_ptr,
     # sizes
-    num_experts: tl.constexpr,
-    hidden_size: tl.constexpr,
     scale_hidden_dim,
     # strides
     stride_row_id_map_token,
@@ -220,7 +216,12 @@ def _permute_kernel(
     stride_permuted_probs_token,
     stride_permuted_scale_token,
     stride_permuted_scale_hidden,
+    # output pointers
+    output_ptr,
+    permuted_probs_ptr,
     # metas
+    num_experts: tl.constexpr,
+    hidden_size: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     PERMUTE_SCALE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
@@ -291,16 +292,11 @@ except RuntimeError:
 
 @triton.jit
 def _unpermute_kernel(
-    # pointers
+    # input pointers
     input_ptr,
-    output_ptr,
     row_id_map_ptr,
     merging_probs_ptr,
     permuted_probs_ptr,
-    unpermuted_probs_ptr,
-    # sizes
-    num_experts: tl.constexpr,
-    hidden_size: tl.constexpr,
     # strides
     stride_row_id_map_token,
     stride_row_id_map_expert,
@@ -313,7 +309,12 @@ def _unpermute_kernel(
     stride_permuted_probs_token,
     stride_unpermuted_probs_token,
     stride_unpermuted_probs_expert,
+    # output pointers
+    output_ptr,
+    unpermuted_probs_ptr,
     # metas
+    num_experts: tl.constexpr,
+    hidden_size: tl.constexpr,
     PROBS_LOAD_WIDTH: tl.constexpr,
     WITH_MERGING_PROBS: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
@@ -546,14 +547,10 @@ def _make_chunk_sort_map_kernel(
 
 @triton.jit
 def _sort_chunks_by_map_kernel(
-    # pointers
+    # input pointers
     input_ptr,
-    output_ptr,
     row_id_map_ptr,
     probs_ptr,
-    permuted_probs_ptr,
-    # sizes
-    hidden_size: tl.constexpr,
     # strides
     stride_input_token,
     stride_input_hidden,
@@ -561,7 +558,11 @@ def _sort_chunks_by_map_kernel(
     stride_output_hidden,
     stride_probs_token,
     stride_permuted_probs_token,
+    # output pointers
+    output_ptr,
+    permuted_probs_ptr,
     # metas
+    hidden_size: tl.constexpr,
     PERMUTE_PROBS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     FORWARD: tl.constexpr,
@@ -72,13 +72,13 @@ def make_row_id_map(
     # [0, 0, 0, r, r, r, r]]
     _row_id_map_pass_1_kernel[grid](
         routing_map,
-        row_id_map,
-        workspace_tensor,
         num_tokens,
         routing_map.stride(0),
         routing_map.stride(1),
         row_id_map.stride(0),
         row_id_map.stride(1),
+        row_id_map,
+        workspace_tensor,
         block_size,
     )
@@ -110,9 +110,9 @@ def make_row_id_map(
     grid = (num_tokens,)
     _row_id_map_pass_3_kernel[grid](
         row_id_map,
-        num_experts,
         row_id_map.stride(0),
         row_id_map.stride(1),
+        num_experts,
         triton.next_power_of_2(num_experts),
     )
     return row_id_map
@@ -169,14 +169,10 @@ def permute_with_mask_map(
     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
     _permute_kernel[grid](
         inp,
-        output,
         row_id_map,
         probs,
         scale,
-        permuted_probs,
         permuted_scale,
-        num_experts,
-        hidden_size,
         scale_hidden_dim,
         row_id_map.stride(0),
         row_id_map.stride(1),
@@ -191,6 +187,10 @@ def permute_with_mask_map(
         permuted_probs.stride(0) if permuted_probs is not None else None,
         permuted_scale.stride(0) if permuted_scale is not None else None,
         permuted_scale.stride(1) if permuted_scale is not None else None,
+        output,
+        permuted_probs,
+        num_experts,
+        hidden_size,
         PERMUTE_PROBS=probs is not None,
         PERMUTE_SCALE=scale is not None,
     )
@@ -238,13 +238,9 @@ def unpermute_with_mask_map(
     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
     _unpermute_kernel[grid](
         inp,
-        output,
         row_id_map,
         merging_probs,
         permuted_probs,
-        unpermuted_probs,
-        num_experts,
-        hidden_size,
         row_id_map.stride(0),
         row_id_map.stride(1),
         inp.stride(0),
@@ -256,6 +252,10 @@ def unpermute_with_mask_map(
         permuted_probs.stride(0) if permuted_probs is not None else None,
         unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
         unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
+        output,
+        unpermuted_probs,
+        num_experts,
+        hidden_size,
         PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
         WITH_MERGING_PROBS=merging_probs is not None,
         PERMUTE_PROBS=permuted_probs is not None,
@@ -395,17 +395,17 @@ def sort_chunks_by_map(
     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
     _sort_chunks_by_map_kernel[grid](
         inp,
-        output,
         row_id_map,
         probs,
-        permuted_probs,
-        hidden_size,
         inp.stride(0),
         inp.stride(1),
         output.stride(0),
         output.stride(1),
         probs.stride(0) if probs is not None else None,
         permuted_probs.stride(0) if permuted_probs is not None else None,
+        output,
+        permuted_probs,
+        hidden_size,
         PERMUTE_PROBS=probs is not None,
         FORWARD=is_forward,
     )
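
With the reordered signatures, the positional arguments a JAX caller supplies (input pointers, then sizes and strides) line up with the kernel, the out_shape buffers land on the trailing output pointers, and the tl.constexpr metas travel as keywords. A hedged sketch of driving the reordered _sort_chunks_by_map_kernel from JAX under stated assumptions (contiguous row-major tensors, a 1-D probs, a fixed block size; the wrapper is illustrative, not part of this commit):

import jax
import jax_triton as jt

# assumes: _sort_chunks_by_map_kernel is importable from the module patched above
BLOCK = 128


def sort_chunks_jax(inp, row_id_map, probs, is_forward):
    num_tokens, hidden_size = inp.shape
    out_shapes = [
        jax.ShapeDtypeStruct(inp.shape, inp.dtype),      # becomes output_ptr
        jax.ShapeDtypeStruct(probs.shape, probs.dtype),  # becomes permuted_probs_ptr
    ]
    return jt.triton_call(
        # input pointers, in signature order
        inp,
        row_id_map,
        probs,
        # strides, assuming contiguous row-major layout
        hidden_size, 1,  # stride_input_token, stride_input_hidden
        hidden_size, 1,  # stride_output_token, stride_output_hidden
        1,               # stride_probs_token (probs taken as 1-D here)
        1,               # stride_permuted_probs_token
        kernel=_sort_chunks_by_map_kernel,
        out_shape=out_shapes,  # appended after the inputs: output_ptr, permuted_probs_ptr
        grid=(num_tokens, (hidden_size + BLOCK - 1) // BLOCK),
        # tl.constexpr parameters go by keyword
        hidden_size=hidden_size,
        PERMUTE_PROBS=True,
        BLOCK_SIZE=BLOCK,
        FORWARD=is_forward,
    )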