Unverified Commit d7982daf authored by Tyler Michael Smith, committed by GitHub

[Bugfix] Fix fused MoE IMA (sans chunking) by using int64 for strides (#34279)


Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent 9b17c574
@@ -95,19 +95,19 @@ def fused_moe_kernel_gptq_awq(
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
-    stride_am,
-    stride_ak,
-    stride_be,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    stride_bse,
-    stride_bsk,
-    stride_bsn,
-    stride_bze,
-    stride_bzk,
-    stride_bzn,
+    stride_am: tl.int64,
+    stride_ak: tl.int64,
+    stride_be: tl.int64,
+    stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
+    stride_bse: tl.int64,
+    stride_bsk: tl.int64,
+    stride_bsn: tl.int64,
+    stride_bze: tl.int64,
+    stride_bzk: tl.int64,
+    stride_bzn: tl.int64,
     block_k_diviable: tl.constexpr,
     group_size: tl.constexpr,
     # Meta-parameters
@@ -329,20 +329,20 @@ def fused_moe_kernel(
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
-    stride_am,
-    stride_ak,
-    stride_be,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    stride_asm,
-    stride_ask,
-    stride_bse,
-    stride_bsk,
-    stride_bsn,
-    stride_bbe,  # bias expert stride
-    stride_bbn,  # bias N stride
+    stride_am: tl.int64,
+    stride_ak: tl.int64,
+    stride_be: tl.int64,
+    stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
+    stride_asm: tl.int64,
+    stride_ask: tl.int64,
+    stride_bse: tl.int64,
+    stride_bsk: tl.int64,
+    stride_bsn: tl.int64,
+    stride_bbe: tl.int64,  # bias expert stride
+    stride_bbn: tl.int64,  # bias N stride
     # Block size for block-wise quantization
     group_n: tl.constexpr,
     group_k: tl.constexpr,
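
Note on why the `tl.int64` annotations matter: the commit does not spell out the arithmetic, so the following is only a sketch of the likely failure mode, not the kernel's actual code. It assumes that an un-annotated integer stride small enough to fit in 32 bits is handled as a 32-bit value, so that a product such as `row * stride_cm` can wrap once the row count grows (presumably the "sans chunking" case). The sizes `num_rows` and `stride_cm = 4096` and the helpers `offset_i32` / `offset_i64` are hypothetical, and the wraparound is emulated on the CPU with `ctypes` rather than run in Triton.

```python
import ctypes

INT32_MAX = 2**31 - 1


def offset_i32(index: int, stride: int) -> int:
    """Emulate `index * stride` evaluated in signed 32-bit arithmetic."""
    # ctypes integer types do no overflow checking, so out-of-range values wrap.
    return ctypes.c_int32(index * stride).value


def offset_i64(index: int, stride: int) -> int:
    """Emulate the same product evaluated in signed 64-bit arithmetic."""
    return ctypes.c_int64(index * stride).value


# Hypothetical sizes, chosen only so that the product exceeds INT32_MAX.
num_rows = 600_000  # e.g. rows of C when M is not split into chunks
stride_cm = 4096    # elements between consecutive rows of C
last_row = num_rows - 1

exact = last_row * stride_cm  # Python ints are arbitrary precision

print("stride fits in int32:", stride_cm <= INT32_MAX)
print("offset fits in int32:", exact <= INT32_MAX)
print("offset computed as int32:", offset_i32(last_row, stride_cm))
print("offset computed as int64:", offset_i64(last_row, stride_cm))
```

In this sketch the stride itself is small, yet the 32-bit product for the last row wraps to a negative offset, which is the kind of pointer arithmetic that would surface as an illegal memory access. Evaluating the same product in 64 bits gives the exact offset, which is what annotating the strides as `tl.int64` is meant to guarantee.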