Unverified commit 087ab832 authored by HAI, committed by GitHub

[Performance, Triton] Optimize over mask compute to tl.load in fused_moe_kernel (#1980)

parent 8169c6f4
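In brief: when the reduction dimension K is an exact multiple of BLOCK_SIZE_K, every tile along K is full, so the bounds check offs_k < K - k * BLOCK_SIZE_K is always true and the corresponding mask can be dropped from tl.load. The launcher computes this once and passes it as the tl.constexpr flag even_Ks, so Triton specializes the kernel and the fast path carries no per-iteration mask cost. A minimal host-side sketch of the decision (needs_k_mask is a hypothetical helper, not part of the patch; the real check lives in invoke_fused_moe_kernel below):

# Illustrative sketch only; needs_k_mask is not part of this commit.
def needs_k_mask(K: int, block_size_k: int) -> bool:
    # A full last tile means the K-dimension bounds mask is always all-True,
    # so tl.load can skip computing and applying it (even_Ks = True).
    return K % block_size_k != 0

assert not needs_k_mask(4096, 128)  # evenly divisible: mask can be dropped
assert needs_k_mask(4000, 128)      # partial last tile: mask still required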
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
    num_warps = 4
    extra_kargs = {}
    if is_hip():
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    _fwd_grouped_kernel_stage1[grid](
        q,
        k_buffer,
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
        num_warps=num_warps,
        num_stages=1,
        Lk=Lk,
        **extra_kargs,
    )
...
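The hunk above is an orthogonal tweak to the grouped decode attention kernel: on ROCm/HIP it forwards AMD-specific tuning options (waves_per_eu, matrix_instr_nonkdim, kpack; see the linked ROCm guide and Triton AMD backend) to the kernel launch, and adds nothing on other platforms. A hedged sketch of the same pattern with placeholder names (some_kernel, grid, args):

# Pattern sketch, not the patched function itself: build backend-specific
# launch kwargs only when running on HIP and splat them into the launch.
extra_kargs = {}
if is_hip():  # same helper as used in the hunk above
    extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

some_kernel[grid](  # placeholder; stands in for _fwd_grouped_kernel_stage1
    *args,
    num_warps=4,
    num_stages=1,
    **extra_kargs,  # empty dict on CUDA, AMD tuning hints on HIP
)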
@@ -54,6 +54,7 @@ def fused_moe_kernel(
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        if even_Ks:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None],
                other=0.0,
            )
            b = tl.load(b_ptrs)
        else:
            a = tl.load(
                a_ptrs,
                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0,
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
            accumulator = tl.dot(a, b, acc=accumulator)
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    K = B.shape[2] - padding_size
    if K % config["BLOCK_SIZE_K"] == 0:
        even_ks = True
    else:
        even_ks = False
    fused_moe_kernel[grid](
        A,
        B,
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        even_Ks=even_ks,
        **config,
    )
...
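Why the even_Ks fast path is safe, shown with a small NumPy stand-in (hypothetical sizes, not Triton): when K % BLOCK_SIZE_K == 0, the K-dimension mask is all-True in every loop iteration, so removing it from the load cannot change which elements are read.

import numpy as np

K, BLOCK_SIZE_K = 256, 64          # hypothetical sizes with K % BLOCK_SIZE_K == 0
offs_k = np.arange(BLOCK_SIZE_K)

for k in range(K // BLOCK_SIZE_K):
    mask = offs_k < K - k * BLOCK_SIZE_K
    assert mask.all()              # the bounds mask never filters anything out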