Unverified Commit f414f906 authored by JartX's avatar JartX Committed by GitHub
Browse files

[Bugfix][Kernel][ROCm] Fix triton_w4a16 scales mismatch when BLOCK_K > group_size (#39705)


Signed-off-by: default avatarJartX <sagformas@epdcenter.es>
parent 8625ec26
...@@ -235,6 +235,14 @@ def triton_w4a16_gemm( ...@@ -235,6 +235,14 @@ def triton_w4a16_gemm(
else: else:
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32 BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
# The kernel loads scales/zeros for a single group per BLOCK_K tile
# (one g_idx per iteration). If BLOCK_K > group_size, rows at the tail
# of the tile dequantize with the wrong group's scales, silently
# corrupting the output. Clamp BLOCK_K to group_size to keep one
# scale group per tile.
if group_size < BLOCK_K:
BLOCK_K = group_size
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
triton_w4a16_gemm_kernel[grid]( triton_w4a16_gemm_kernel[grid](
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment