Unverified commit caef45b5 authored by Tong WU, committed by GitHub

[Enhancement] Enhance the robustness and generality of MLA examples (#709)

* Enhance the format script to automatically install missing lint tools

* Add a guard for small `h_q` in the MLA decode examples to improve robustness

* Allow the softmax scale to be passed as a parameter in the MLA decode examples for better generality

* Fix a typo
parent 6664d170
...
@@ -8,8 +8,9 @@ import argparse
 @tilelang.jit(out_idx=[6])
-def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
-    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
+def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
+              softmax_scale):
+    scale = float(softmax_scale * 1.44269504)  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num
...
@@ -288,10 +289,12 @@ def main(
     pv_flops = 2 * batch * heads * kv_ctx * dim
     total_flops = qk_flops + pv_flops
     BLOCK_N = 64
-    BLOCK_H = 64
+    BLOCK_H = min(64, heads // kv_heads)
     num_split = 1
+    softmax_scale = (dim + pe_dim)**-0.5
-    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
+    kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split,
+                       softmax_scale)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
     latency = profiler.do_bench(warmup=500)
...
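The log2(e) factor kept in both versions indicates the kernel evaluates the softmax exponentials with exp2 rather than exp; the refactor only changes where the scale comes from (a caller-supplied `softmax_scale` instead of the hard-coded `(dim + pe_dim)**-0.5`). A minimal sketch of the identity being relied on, in plain NumPy rather than tilelang:

    import numpy as np

    # exp(x) == 2**(x * log2(e)), so a kernel built on exp2 can absorb
    # log2(e) into the softmax scale once, ahead of time.
    LOG2_E = 1.44269504
    softmax_scale = (512 + 64) ** -0.5   # e.g. dim=512, pe_dim=64
    scores = np.random.randn(8).astype(np.float32)

    ref = np.exp(scores * softmax_scale)
    via_exp2 = np.exp2(scores * softmax_scale * LOG2_E)
    assert np.allclose(ref, via_exp2)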
...
@@ -9,8 +9,8 @@ import math
 @tilelang.jit(out_idx=[8])
 def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
-                        block_size):
-    scale = (1.0 / (dv + dpe))**0.5 * 1.44269504  # log2(e)
+                        block_size, softmax_scale):
+    scale = float(softmax_scale * 1.44269504)  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = h_q // h_kv
...
@@ -318,12 +318,13 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
     dpe = d - dv
     num_kv_splits = 1
     BLOCK_N = 64
-    BLOCK_H = 64
+    BLOCK_H = min(64, h_q // h_kv)
+    softmax_scale = (d + dv)**-0.5
     out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
     glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
     kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
-                                 num_kv_splits, block_size)
+                                 num_kv_splits, block_size, softmax_scale)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     def flash_mla_tilelang():
...
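The `BLOCK_H = min(64, h_q // h_kv)` change is the "small `h_q`" guard from the commit message: in the grouped layout, one KV head serves `h_q // h_kv` query heads, so a fixed head-block of 64 rows would run past the group whenever few query heads are present. A small illustrative sketch of the arithmetic (the helper name is hypothetical, not from the example):

    # Illustrative only: the head-block must not exceed the number of
    # query heads that share one KV head (the group size).
    def head_block(h_q: int, h_kv: int, max_block: int = 64) -> int:
        group = h_q // h_kv
        return min(max_block, group)

    for h_q, h_kv in [(128, 1), (64, 1), (16, 1), (32, 4)]:
        print(f"h_q={h_q:3d} h_kv={h_kv} -> BLOCK_H={head_block(h_q, h_kv)}")
    # Before the fix, h_q=16 with h_kv=1 used BLOCK_H=64 > 16 and could
    # index past the real query heads; the min() keeps the block in range.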
...
@@ -44,7 +44,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
     @T.macro
     def MMA1(
             V: T.Tensor(shape_kv, dtype),
-            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            V_shared: T.SharedBuffer([block_N, dim], dtype),
             acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
             acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
             k: T.int32,
...
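The one-character change above fixes a declared shape: each inner step stages `block_N` rows of V, so the shared buffer must be `[block_N, dim]`; only the probability tile `acc_s_cast` carries `block_M`. A quick NumPy shape check of the matmul these buffers feed (tile sizes chosen to make the mismatch visible; with `block_M == block_N` the old declaration happened to line up anyway):

    import numpy as np

    block_M, block_N, dim = 64, 32, 128        # illustrative tile sizes
    acc_s_cast = np.zeros((block_M, block_N))  # softmax(Q @ K^T) tile
    V_shared = np.zeros((block_N, dim))        # staged V tile: block_N rows

    # acc_o accumulates [block_M, block_N] @ [block_N, dim] -> [block_M, dim];
    # a [block_M, dim] V_shared only type-checks when block_M == block_N.
    acc_o = acc_s_cast @ V_shared
    assert acc_o.shape == (block_M, dim)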
...
@@ -18,6 +18,11 @@ builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
 ROOT="$(git rev-parse --show-toplevel)"
 builtin cd "$ROOT" || exit 1
+# If yapf/ruff/codespell is not installed, install according to the requirements
+if ! (yapf --version &>/dev/null && ruff --version &>/dev/null && codespell --version &>/dev/null); then
+    pip install -r requirements-lint.txt
+fi
+
 YAPF_VERSION=$(yapf --version | awk '{print $2}')
 RUFF_VERSION=$(ruff --version | awk '{print $2}')
 CODESPELL_VERSION=$(codespell --version)
...
@@ -26,7 +31,7 @@ CODESPELL_VERSION=$(codespell --version)
 tool_version_check() {
     if [[ $2 != $3 ]]; then
         echo "Wrong $1 version installed: $3 is required, not $2."
-        exit 1
+        pip install -r requirements-lint.txt
     fi
 }
...
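Together the two script changes make `format.sh` self-bootstrapping: missing linters are installed up front, and a version mismatch now reinstalls from `requirements-lint.txt` instead of aborting. A rough Python rendering of the availability check, purely to illustrate the control flow (not part of the repository):

    import shutil
    import subprocess

    TOOLS = ("yapf", "ruff", "codespell")

    def ensure_lint_tools(requirements: str = "requirements-lint.txt") -> None:
        # If any tool is missing, install them all from the pinned file,
        # mirroring the shell check added above.
        if not all(shutil.which(tool) for tool in TOOLS):
            subprocess.run(["pip", "install", "-r", requirements], check=True)

    ensure_lint_tools()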