Commit 146eb9d3 authored by zhuwenwen's avatar zhuwenwen
Browse files

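Update the MLA decode attention path to take its kernel launch configuration (BLOCK_N, num_warps, num_stages) from a preloaded config instead of Triton autotuning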

parent c1370857
......@@ -39,6 +39,17 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata)
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Return the entry stored under a (batch-size key, mean-KV-length key) pair.

    ``config`` is treated as a two-level mapping whose keys are strings:
    ``config[str(bs_key)][str(mean_kv_seqlen_key)]``. Both lookup arguments
    are converted with ``str()`` first, so integer and string keys are
    interchangeable.

    Args:
        bs_key: Outer key; converted with ``str()`` before lookup.
        mean_kv_seqlen_key: Inner key; converted with ``str()`` before lookup.
        config: Nested mapping of the form ``{outer_key: {inner_key: entry}}``.

    Returns:
        The entry found at ``config[str(bs_key)][str(mean_kv_seqlen_key)]``.

    Raises:
        ValueError: If either the outer or the inner key is missing.
    """
    # Convert the arguments to strings so they match the mapping's keys.
    bs_key_str = str(bs_key)
    mean_kv_seqlen_key_str = str(mean_kv_seqlen_key)
    try:
        # Single lookup path: a missing key at either level raises KeyError.
        return config[bs_key_str][mean_kv_seqlen_key_str]
    except KeyError:
        # Report the caller-supplied keys; suppress the internal KeyError so
        # callers see the same bare ValueError as before.
        raise ValueError(
            f"No matching configuration found for bs key: {bs_key} "
            f"and mean kv seq key: {mean_kv_seqlen_key} "
            "when init decode attention db"
        ) from None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
......@@ -736,6 +747,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for "
"TritonMLAImpl")
self.attn_configs = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill(
self,
q: torch.Tensor,
......@@ -789,13 +802,16 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
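# TODO: select the entry that matches the actual batch size and mean KV sequence length; the loop below visits every (bs, mean_seq_len) pair and keeps only the last one in best_config.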
for bs in self.attn_configs.keys():
for mean_seq_len in self.attn_configs[bs].keys():
best_config = get_config(bs, mean_seq_len, self.attn_configs)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, # config,
attn_metadata.num_kv_splits, self.scale, best_config,
PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
......@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped(
# opt
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# ],
# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
# )
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
Q,
......@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc(
mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
)
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
# triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
# ],
# key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
# )
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
logits,
......@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc(
page_size,
num_kv_splits,
logit_cap,
best_config,
):
Lk = k_buffer.shape[-1]
......@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_N = best_config['BLOCK_N']
SPLIT_K = num_kv_splits # best_config['SPLIT_K'] ?
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: (
batch,
......@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc(
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK_N,
BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=num_stages,
Lk=Lk,
kpack=2,
)
return _decode_v1_kernel_stage1_use_tc.best_config
# return _decode_v1_kernel_stage1_use_tc.best_config
def _decode_v1_stage2_use_tc(
......@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc(
b_start_loc,
b_seq_len,
page_size,
best_config,
):
batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
......@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc(
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size,
Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
)
return _decode_v1_kernel_stage2_use_tc.best_config
# return _decode_v1_kernel_stage2_use_tc.best_config
def decode_attention_v1(
......@@ -1003,11 +1019,36 @@ def decode_attention_v1(
attn_logits,
num_kv_splits,
sm_scale,
best_config,
page_size,
logit_cap=0.0,
):
# GQA/MQA/MLA
_decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
# _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
# q,
# k_buffer,
# attn_logits,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# sm_scale,
# page_size,
# num_kv_splits,
# logit_cap,
# )
# _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
# attn_logits,
# v_buffer,
# o,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# page_size,
# )
# return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
_decode_v1_stage1_use_tc(
q,
k_buffer,
attn_logits,
......@@ -1019,8 +1060,9 @@ def decode_attention_v1(
page_size,
num_kv_splits,
logit_cap,
best_config['stage1'],
)
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
_decode_v1_stage2_use_tc(
attn_logits,
v_buffer,
o,
......@@ -1029,31 +1071,31 @@ def decode_attention_v1(
b_start_loc,
b_seq_len,
page_size,
best_config['stage2'],
)
return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
# ],
# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
# )
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
Q,
......@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc(
B_Seqlen,
num_kv_splits,
sm_scale,
best_config,
page_size,
logit_cap,
):
BLOCK = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
......@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc(
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
kpack=2,
)
return _decode_v2_kernel_stage1_use_tc.best_config
@triton.autotune(
configs=[
triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1),
triton.Config({}, num_warps=8, num_stages=1),
],
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
)
# return _decode_v2_kernel_stage1_use_tc.best_config
# @triton.autotune(
# configs=[
# triton.Config({}, num_warps=1, num_stages=1),
# triton.Config({}, num_warps=2, num_stages=1),
# triton.Config({}, num_warps=4, num_stages=1),
# triton.Config({}, num_warps=8, num_stages=1),
# ],
# key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
# )
@triton.jit
def _decode_v2_kernel_stage2(
Mid_O,
......@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc(
v_buffer,
b_seq_len,
num_kv_splits,
best_config,
):
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
......@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
)
return _decode_v2_kernel_stage2.best_config
# return _decode_v2_kernel_stage2.best_config
def decode_attention_v2(
......@@ -1401,10 +1456,26 @@ def decode_attention_v2(
attn_logits,
num_kv_splits,
sm_scale,
best_config,
page_size,
logit_cap=0.0,
):
_decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
# _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
# q,
# k_buffer,
# v_buffer,
# attn_logits,
# req_to_token,
# # b_req_idx,
# b_seq_len,
# num_kv_splits,
# sm_scale,
# page_size,
# logit_cap,
# )
# _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
# return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
_decode_v2_stage1_use_tc(
q,
k_buffer,
v_buffer,
......@@ -1416,9 +1487,9 @@ def decode_attention_v2(
sm_scale,
page_size,
logit_cap,
best_config['stage1'],
)
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
_decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config['stage2'])
def decode_attention_fwd(
......@@ -1431,7 +1502,7 @@ def decode_attention_fwd(
attn_logits,
num_kv_splits,
sm_scale,
# config,
best_config,
page_size=1,
logit_cap=0.0,
):
......@@ -1456,6 +1527,7 @@ def decode_attention_fwd(
else:
# GQA/MQA/MLA
if envs.VLLM_USE_TRITON_OPT_MLA:
'''
decode_attention_v2(
q,
k_buffer,
......@@ -1469,62 +1541,62 @@ def decode_attention_fwd(
page_size,
logit_cap,
)
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
# if best_config['kernel_kind'] == 'v1_2stages_tc':
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# elif best_config['kernel_kind'] == 'v2_tc':
# decode_attention_v2(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_seq_len,
# attn_logits,
# num_kv_splits,
# sm_scale,
# config,
# page_size,
# logit_cap,
# )
# else:
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16,
device="cuda")
decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits, # sub
sm_scale,
page_size,
logit_cap,
)'''
if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16,
device="cuda")
decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits,
sm_scale,
best_config=best_config['best_config'],
page_size=page_size,
logit_cap=logit_cap,
)
elif best_config['kernel_kind'] == 'v2_tc':
decode_attention_v2(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
best_config=best_config['best_config'],
page_size=page_size,
logit_cap=logit_cap,
)
else:
print("Unknown mla kernel kind: ", best_config['kernel_kind'])
else:
decode_attention_fwd_grouped(
q,
......
......@@ -80,7 +80,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration',
support_nn_architectures = ['LlamaForCausalLM', 'Qwen2ForCausalLM', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration',
'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM',
'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment