"testing/vscode:/vscode.git/clone" did not exist on "a58bf9b6c63e945928153d151f2ae927cbc20dc4"
Unverified commit caef45b5, authored by Tong WU, committed by GitHub
Browse files

[Enhancement] Enhance the robustness and generality of MLA examples (#709)

* Enhance format script to automatically install tools in need

* Add a check for small `h_q` in MLA decode examples to improve robustness

* Allow scale as a param in MLA decode examples for better generality

* Fix typo
parent 6664d170
......@@ -8,8 +8,9 @@ import argparse
@tilelang.jit(out_idx=[6])
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
softmax_scale):
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
......@@ -288,10 +289,12 @@ def main(
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 64
BLOCK_H = 64
BLOCK_H = min(64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim)**-0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split,
softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
......
......@@ -9,8 +9,8 @@ import math
@tilelang.jit(out_idx=[8])
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
block_size):
scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e)
block_size, softmax_scale):
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = h_q // h_kv
......@@ -318,12 +318,13 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
dpe = d - dv
num_kv_splits = 1
BLOCK_N = 64
BLOCK_H = 64
BLOCK_H = min(64, h_q // h_kv)
softmax_scale = (d + dv)**-0.5
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
num_kv_splits, block_size, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang():
......
......@@ -44,7 +44,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
@T.macro
def MMA1(
V: T.Tensor(shape_kv, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
......
......@@ -18,6 +18,11 @@ builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
# Resolve the repository root so the script works from any invocation directory.
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
# If yapf/ruff/codespell is not installed, install according to the requirements
# (each `--version` probe is silenced; any one failing triggers a pinned install).
if ! (yapf --version &>/dev/null && ruff --version &>/dev/null && codespell --version &>/dev/null); then
pip install -r requirements-lint.txt
fi
# Capture installed tool versions for the version check below.
# NOTE(review): `awk '{print $2}'` assumes "name X.Y.Z" output — confirm for ruff.
YAPF_VERSION=$(yapf --version | awk '{print $2}')
RUFF_VERSION=$(ruff --version | awk '{print $2}')
CODESPELL_VERSION=$(codespell --version)
......@@ -26,7 +31,7 @@ CODESPELL_VERSION=$(codespell --version)
tool_version_check() {
    # Verify that tool $1 (installed version $2) matches the required
    # version $3; if not, report the mismatch and reinstall the pinned
    # lint requirements instead of aborting.
    # Quotes matter: an unquoted RHS inside [[ ]] is treated as a glob
    # pattern, so version strings must be compared as literal strings.
    if [[ "$2" != "$3" ]]; then
        echo "Wrong $1 version installed: $3 is required, not $2."
        pip install -r requirements-lint.txt
    fi
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment