Unverified Commit e05f87e1, authored by Teddy Do and committed by GitHub
Browse files

[PyTorch] Change order of args in another permutation triton kernel (#2488)



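Change the argument order: move the output pointers and the constexpr meta-parameters after the stride arguments, in both the kernel signature and its call site.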
Signed-off-by: tdophung <tdophung@nvidia.com>
parent 8ef3a33d
......@@ -402,16 +402,11 @@ except RuntimeError:
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
# pointers
# input pointers
fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr,
merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -425,7 +420,12 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_expert,
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# output pointers
fwd_input_grad_ptr,
merging_probs_grad_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
......
......@@ -304,13 +304,9 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
grid = (num_tokens,)
_unpermute_bwd_with_merging_probs_kernel[grid](
fwd_output_grad,
act_grad,
fwd_input,
merging_probs,
merging_probs_grad,
row_id_map,
num_experts,
hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0),
......@@ -323,6 +319,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1),
merging_probs_grad.stride(0),
merging_probs_grad.stride(1),
act_grad,
merging_probs_grad,
num_experts,
hidden_size,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
)
return act_grad, merging_probs_grad
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment