Unverified commit 32de54ed, authored by Wen-Heng (Jack) Chung, committed by GitHub
Browse files

[ROCm] Fix fp8 unrolledx4 matmul kernel. (#3325)


Co-authored-by: HAI <hixiao@gmail.com>
parent 2d9c3195
......@@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# manually unroll to 4 iterations
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
UNROLL_FACTOR = 4
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
# 1st iteration
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(
a_ptrs,
mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
other=0.0,
)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
other=0.0,
)
k_start = k * BLOCK_SIZE_K
k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
......@@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs += BLOCK_SIZE_K * stride_bk
# 2nd iteration
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(
a_ptrs,
mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
other=0.0,
)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
other=0.0,
)
k_start = k_start + BLOCK_SIZE_K
offs_ks = k_start // group_k
......@@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs += BLOCK_SIZE_K * stride_bk
# 3rd iteration
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(
a_ptrs,
mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
other=0.0,
)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
other=0.0,
)
k_start = k_start + BLOCK_SIZE_K
offs_ks = k_start // group_k
......@@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs += BLOCK_SIZE_K * stride_bk
# 4th iteration
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
a = tl.load(
a_ptrs,
mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
other=0.0,
)
b = tl.load(
b_ptrs,
mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
other=0.0,
)
k_start = k_start + BLOCK_SIZE_K
offs_ks = k_start // group_k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment