Unverified commit e05f87e1, authored by Teddy Do, committed by GitHub
Browse files

[PyTorch] Change order of args in another permutation triton kernel (#2488)



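Change the argument order of `_unpermute_bwd_with_merging_probs_kernel` and its call site: input pointers first, then strides, then output pointers (`fwd_input_grad_ptr`, `merging_probs_grad_ptr`), then the constexpr metas (`num_experts`, `hidden_size`, `PROBS_LOAD_WIDTH`, `BLOCK_SIZE`).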
Signed-off-by: tdophung <tdophung@nvidia.com>
parent 8ef3a33d
...@@ -402,16 +402,11 @@ except RuntimeError: ...@@ -402,16 +402,11 @@ except RuntimeError:
@triton.jit @triton.jit
def _unpermute_bwd_with_merging_probs_kernel( def _unpermute_bwd_with_merging_probs_kernel(
# pointers # input pointers
fwd_output_grad_ptr, fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr, fwd_input_ptr,
merging_probs_ptr, merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr, row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
...@@ -425,7 +420,12 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -425,7 +420,12 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_expert, stride_merging_probs_expert,
stride_merging_probs_grad_token, stride_merging_probs_grad_token,
stride_merging_probs_grad_expert, stride_merging_probs_grad_expert,
# output pointers
fwd_input_grad_ptr,
merging_probs_grad_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
......
...@@ -304,13 +304,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -304,13 +304,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
grid = (num_tokens,) grid = (num_tokens,)
_unpermute_bwd_with_merging_probs_kernel[grid]( _unpermute_bwd_with_merging_probs_kernel[grid](
fwd_output_grad, fwd_output_grad,
act_grad,
fwd_input, fwd_input,
merging_probs, merging_probs,
merging_probs_grad,
row_id_map, row_id_map,
num_experts,
hidden_size,
row_id_map.stride(0), row_id_map.stride(0),
row_id_map.stride(1), row_id_map.stride(1),
fwd_output_grad.stride(0), fwd_output_grad.stride(0),
...@@ -323,6 +319,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -323,6 +319,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1), merging_probs.stride(1),
merging_probs_grad.stride(0), merging_probs_grad.stride(0),
merging_probs_grad.stride(1), merging_probs_grad.stride(1),
act_grad,
merging_probs_grad,
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
) )
return act_grad, merging_probs_grad return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment