Commit 18c2244b authored by zhuwenwen
Browse files

add mla tuning script and configs

parent ac811e51
'''
QH=num_attention_heads KVH=1 QKD=kv_lora_rank + qk_rope_head_dim VD=kv_lora_rank
deepseek v2: QH=16, QKD=512+64=576, VD=512
deepseek v3: QH=128, QKD=512+64=576, VD=512
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse
def parse_args():
    """Parse the command-line options that size the MLA paged-attention table.

    SEQSTART and SEQEND are optional integers (defaults 100 and 8193); QH,
    KVH, QKD, VD and TP are required integers.

    Returns:
        argparse.Namespace: one integer attribute per option listed above.
    """
    cli = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")

    # Optional bounds of the sequence-length sweep, with their defaults.
    for flag, fallback in (("SEQSTART", 100), ("SEQEND", 8193)):
        cli.add_argument(f"--{flag}", type=int, required=False, default=fallback, help=f"Value for {flag}")

    # Attention shape and parallelism settings; all mandatory.
    for flag in ("QH", "KVH", "QKD", "VD", "TP"):
        cli.add_argument(f"--{flag}", type=int, required=True, help=f"Value for {flag}")

    return cli.parse_args()
def generate_B():
    """Return the batch sizes to tune: 1, 4, then every multiple of 8 up to 128."""
    # A wider sweep (144..256 in steps of 16, then 288..512 in steps of 32)
    # is currently disabled; only batch sizes up to 128 are produced.
    return [1, 4, *range(8, 129, 8)]
def generate_SEQ_bs1_32(seq_start, seq_end):
    """Return the sequence lengths to tune for batch sizes up to 32.

    The grid always starts with 1, then steps by 300 from ``seq_start`` up to
    (but excluding) 5000, then steps by 500 on a grid anchored at 5000 up to
    (but excluding) ``seq_end``.

    Args:
        seq_start: First value of the 300-step range; also the lower bound
            for points taken from the 500-step grid.
        seq_end: Exclusive upper bound for every generated value except the
            leading 1.

    Returns:
        list[int]: The sequence lengths, with the leading 1 never repeated.
    """
    boundary = 5000
    # Clamp the fine range to seq_end so a SEQEND below 5000 is honoured
    # instead of silently sweeping all the way to the boundary.
    fine = range(seq_start, min(boundary, seq_end), 300)
    # Keep the coarse grid anchored at 5000, but drop points below seq_start.
    coarse = [s for s in range(boundary, seq_end, 500) if s >= seq_start]
    # 1 is always emitted first; do not emit it again when seq_start == 1.
    return [1] + [s for s in fine if s != 1] + coarse
def generate_SEQ_bs40_96(seq_start, seq_end):
    """Return the sequence lengths to tune for batch sizes above 32 and up to 96.

    The grid always starts with 1, then steps by 200 from ``seq_start`` up to
    (but excluding) 5000, then steps by 300 on a grid anchored at 5000 up to
    (but excluding) ``seq_end``.

    Args:
        seq_start: First value of the 200-step range; also the lower bound
            for points taken from the 300-step grid.
        seq_end: Exclusive upper bound for every generated value except the
            leading 1.

    Returns:
        list[int]: The sequence lengths, with the leading 1 never repeated.
    """
    boundary = 5000
    # Clamp the fine range to seq_end so a SEQEND below 5000 is honoured
    # instead of silently sweeping all the way to the boundary.
    fine = range(seq_start, min(boundary, seq_end), 200)
    # Keep the coarse grid anchored at 5000, but drop points below seq_start.
    coarse = [s for s in range(boundary, seq_end, 300) if s >= seq_start]
    # 1 is always emitted first; do not emit it again when seq_start == 1.
    return [1] + [s for s in fine if s != 1] + coarse
def generate_SEQ_other(seq_start, seq_end):
    """Return the sequence lengths to tune for batch sizes above 96.

    The grid always starts with 1, then steps by 300 from ``seq_start`` up to
    (but excluding) 5000, then steps by 500 on a grid anchored at 5000 up to
    (but excluding) ``seq_end``.

    Args:
        seq_start: First value of the 300-step range; also the lower bound
            for points taken from the 500-step grid.
        seq_end: Exclusive upper bound for every generated value except the
            leading 1.

    Returns:
        list[int]: The sequence lengths, with the leading 1 never repeated.
    """
    boundary = 5000
    # Clamp the fine range to seq_end so a SEQEND below 5000 is honoured
    # instead of silently sweeping all the way to the boundary.
    fine = range(seq_start, min(boundary, seq_end), 300)
    # Keep the coarse grid anchored at 5000, but drop points below seq_start.
    coarse = [s for s in range(boundary, seq_end, 500) if s >= seq_start]
    # 1 is always emitted first; do not emit it again when seq_start == 1.
    return [1] + [s for s in fine if s != 1] + coarse
def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
    """Build the list of (B, SEQ, QH, KVH, QKD, VD) shapes to tune.

    Every batch size from ``generate_B`` is paired with a sequence-length grid
    chosen by batch size: ``generate_SEQ_bs1_32`` for B <= 32,
    ``generate_SEQ_bs40_96`` for 32 < B <= 96 and ``generate_SEQ_other`` for
    anything larger. QH, KVH, QKD and VD are copied unchanged into each tuple.

    Returns:
        list[tuple]: Tuples ordered by batch size, then by sequence length in
        the order the selected grid yields them.
    """
    batch_sizes = generate_B()
    small_batch_seqs = generate_SEQ_bs1_32(seq_start, seq_end)
    mid_batch_seqs = generate_SEQ_bs40_96(seq_start, seq_end)
    large_batch_seqs = generate_SEQ_other(seq_start, seq_end)

    def seqs_for(batch):
        # Guard clauses: the first matching upper bound selects the grid.
        if batch <= 32:
            return small_batch_seqs
        if batch <= 96:
            return mid_batch_seqs
        return large_batch_seqs

    return [
        (batch, seq, QH, KVH, QKD, VD)
        for batch in batch_sizes
        for seq in seqs_for(batch)
    ]
if __name__ == "__main__":
    # Script entry point: print one "(B, SEQ, QH, KVH, QKD, VD)," line per
    # shape so the output can be pasted into a table literal.
    args = parse_args()
    # Validate with explicit exceptions: an `assert` disappears under
    # `python -O`, and TP == 0 would otherwise surface as ZeroDivisionError.
    if args.TP <= 0:
        raise ValueError(f"error TP, TP must be a positive integer, got {args.TP}")
    if args.QH % args.TP != 0:
        raise ValueError("error QH, QH % TP must be 0")
    # Head counts are divided by TP (presumably the tensor-parallel degree);
    # the KV head count is floored at 1 so KVH < TP still yields a valid shape.
    tuple_sizes = generate_tuples(
        args.SEQSTART,
        args.SEQEND,
        args.QH // args.TP,
        max(args.KVH // args.TP, 1),
        args.QKD,
        args.VD,
    )
    for t in tuple_sizes:
        print(f"{t},")
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
import os
import json
import pytest
import torch
import triton
from triton_decode_attention import decode_attentionv1_fwd, decode_attentionv2_fwd
def cdiv(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for positive ``b``."""
    quotient, _ = divmod(a + b - 1, b)
    return quotient
def _flatten_best_config(best_config, include_ldmatrixes=False):
    """Flatten a tuned kernel config into a JSON-serialisable dict.

    Copies ``best_config.kwargs`` and then adds ``num_stages`` and
    ``num_warps`` (plus ``num_ldmatrixes`` when requested), in that order.
    """
    flat = dict(best_config.kwargs)
    flat["num_stages"] = best_config.num_stages
    flat["num_warps"] = best_config.num_warps
    if include_ldmatrixes:
        flat["num_ldmatrixes"] = best_config.num_ldmatrixes
    return flat


@pytest.mark.parametrize("B", [1])
# Sequence lengths: 1, then 100..4900 in steps of 300, then 5000..32500 in
# steps of 500. Shrink the ranges for a quicker sweep.
@pytest.mark.parametrize("L", [1, *range(100, 5000, 300), *range(5000, 32501, 500)])
@pytest.mark.parametrize("H_Q", [4, 8, 16])
@pytest.mark.parametrize("H_KV", [1])
@pytest.mark.parametrize("D_QK", [576])
@pytest.mark.parametrize("D_V", [512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """Tune the v1/v2 MLA decode-attention kernels for one shape and save the winner.

    For the given shape this test:
      1. computes a reference output with ``decode_attentionv2_fwd`` on a
         flat (page size 1) KV cache;
      2. runs the paged v2 and v1 kernels, checks both against the reference
         and benchmarks them with ``triton.testing.do_bench``;
      3. stores the faster kernel's config in a per-device JSON file, keyed
         by the mean sequence length.

    Raises:
        ValueError: if the CUDA device name matches neither "K100_AI" nor "BW".
    """
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 4

    # Rounded up, e.g. (1027 + 16 - 1) // 16 == 65.
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # Tensor of shape (B, num_pages_per_batch, 1) holding random page indices
    # in [0, CACHE_SIZE // PAGE_SIZE).
    req_to_page = torch.randint(0,
                                CACHE_SIZE // PAGE_SIZE,
                                (B, num_pages_per_batch, 1),
                                device="cuda")
    # Expand every page index into the PAGE_SIZE token slots it covers.
    req_to_token = req_to_page * PAGE_SIZE
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
        1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o will have the same shape as q
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B, ), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation (token-level indexing, no PAGE_SIZE);
    # its output `o` is the reference for both paged kernels.
    decode_attentionv2_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    configs = {
        "v2_tc": {"stage1": {}, "stage2": {}},
        "v1_2stages_tc": {"stage1": {}, "stage2": {}},
    }
    # Key order matters: on a timing tie min() keeps the first entry.
    ms = {
        "v1_2stages_tc": 10000.0,
        "v2_tc": 10000.0,
    }
    final_best_config = {
        "kernel_kind": "",
        "best_config": {},
        "best_us": 0.0,
    }
    quantiles = [0.5, 0.2, 0.8]
    shape = [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE]

    # ---- v2 kernel: validate, benchmark, record its best configs ----
    o1 = torch.zeros_like(o)
    v2_tc_stage1_best_config, v2_tc_stage2_best_config = decode_attentionv2_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2)

    v2_tc_ms, _, _ = triton.testing.do_bench(
        lambda: decode_attentionv2_fwd(
            q,
            k_buffer,
            v_buffer,
            o1,
            req_to_page,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            PAGE_SIZE,
        ),
        quantiles=quantiles)

    configs["v2_tc"]["stage1"] = _flatten_best_config(v2_tc_stage1_best_config)
    configs["v2_tc"]["stage2"] = _flatten_best_config(v2_tc_stage2_best_config)
    ms["v2_tc"] = v2_tc_ms
    print(f"v2_tc best configs is {configs['v2_tc']}")
    print("print mla decode attention v2 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :", shape, v2_tc_ms)

    # ---- v1 kernel: validate, benchmark, record its best configs ----
    o2 = torch.zeros_like(o)
    v1_tc_stage1_best_config, v1_tc_stage2_best_config = decode_attentionv1_fwd(
        q,
        k_buffer,
        v_buffer,
        o2,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )
    assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)

    # Benchmark into o2 (the v1 output buffer) so the v2 result in o1 is not
    # overwritten and the final check below validates what was benchmarked.
    v1_tc_ms, _, _ = triton.testing.do_bench(
        lambda: decode_attentionv1_fwd(
            q,
            k_buffer,
            v_buffer,
            o2,
            req_to_page,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            PAGE_SIZE,
        ),
        quantiles=quantiles)

    # Each stage records its own num_ldmatrixes.
    configs["v1_2stages_tc"]["stage1"] = _flatten_best_config(
        v1_tc_stage1_best_config, include_ldmatrixes=True)
    configs["v1_2stages_tc"]["stage2"] = _flatten_best_config(
        v1_tc_stage2_best_config, include_ldmatrixes=True)
    ms["v1_2stages_tc"] = v1_tc_ms

    # ---- pick the faster kernel ----
    min_key, min_ms = min(ms.items(), key=lambda x: x[1])
    final_best_config["kernel_kind"] = min_key
    final_best_config["best_config"] = configs[min_key]
    final_best_config["best_us"] = min_ms * 1000  # do_bench reports milliseconds

    print(f"v1_2stages_tc best configs is {configs['v1_2stages_tc']}")
    print("print mla decode attention v1 kernel [B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE] min cost :", shape, v1_tc_ms)
    print(f"Tuned_decode_attention choose {min_key} kernel, min cost {min_ms} ms, best config of {min_key} kernel is {configs[min_key]}")
    assert torch.allclose(o, o2, atol=1e-2, rtol=1e-2)

    # ************** save config **************
    mean_seq_len = int((b_seq_len.sum() / max(1, B)).item())
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    if "K100_AI" in device_name:
        file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_K100AI.json"
    elif "BW" in device_name:
        file_name = f"QH={H_Q}_KVH={H_KV}_QKD={D_QK}_VD={D_V}_fp16_BW.json"
    else:
        raise ValueError(f"Unsurpport device name: {device_name}")

    # Merge into any existing results so earlier shapes are kept.
    if os.path.exists(file_name):
        with open(file_name, 'r', encoding="utf-8") as file:
            config_info = json.load(file)
    else:
        config_info = {}

    # Results are keyed by mean sequence length only (not nested per batch).
    config_info[f"{mean_seq_len}"] = final_best_config

    # Save the best config.
    with open(file_name, 'w', encoding="utf-8") as file:
        json.dump(config_info, file, indent=1)
    # ************** save config **************
This diff is collapsed.
......@@ -9,14 +9,11 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
class KERNLE_KINDS(Enum):
v1 = 0
v1_tc = 1
v1_tc_2 = 2
v1_2stages = 3
v1_2stages_tc = 4
v2 = 5
v2_tc = 6
TOTAL_KIND = 7
v1_2stages = 0
v1_2stages_tc = 1
v2 = 2
v2_tc = 3
TOTAL_KIND = 4
class BestConfig():
def __init__(self):
......@@ -24,8 +21,9 @@ class BestConfig():
self.seq_len = 0
self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
self.BLOCK_N = 0
self.BLOCK_SEQ = 0
self.SPLIT_K = 0
self.BLOCK_DIM = 0
# self.BLOCK_SEQ = 0
# self.SPLIT_K = 0
self.num_stages = 0
self.num_warps = 0
self.NUM_KV_SPLITS = 0
......@@ -49,19 +47,6 @@ def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_config_file_name(QH: int, KVH: int, D: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_D={D}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
......@@ -73,7 +58,7 @@ def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_d
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
......@@ -107,29 +92,11 @@ def get_config_map(attention_configs):
int_seq_len = int(seq_len)
kind_config = seq_configs[seq_len]
configs = BestConfig()
configs.batch_size = int_bs
configs.seq_len = int_seq_len
# configs.batch_size = int_bs
# configs.seq_len = int_seq_len
configs.best_us = kind_config['best_us']
seq_map[int_seq_len] = configs
if kind_config['kernel_kind'] == 'v1':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc_2':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc_2
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages':
if kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
......@@ -158,10 +125,10 @@ def get_config_map(attention_configs):
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
# if 'BLOCK_SEQ' in stage1:
# configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
# else:
# configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
......@@ -172,11 +139,12 @@ def get_config_map(attention_configs):
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2_tc
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
# if 'BLOCK_SEQ' in stage1:
# configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
# else:
# configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.BLOCK_DIM = stage1['BLOCK_DIM']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
......
# SPDX-License-Identifier: Apache-2.0
import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -7,7 +8,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
......@@ -689,7 +690,7 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for "
"TritonMLAImpl")
self.attn_configs = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill(
self,
......@@ -745,9 +746,13 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO
for bs in self.attn_configs.keys():
for mean_seq_len in self.attn_configs[bs].keys():
best_config = get_config(bs, mean_seq_len, self.attn_configs)
max_seq_len = torch.max(decode_meta.seq_lens_tensor).item()
if os.environ.get('PA_MATCH_USE_MEAN_SEQ') == '1':
match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item())
else:
match_seq_len = max_seq_len
best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))]
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment