Commit eee6148a authored by zhuwenwen
Browse files

Update MLA to obtain the optimal configuration from the config

parent abac3adc
...@@ -40,6 +40,18 @@ from vllm.logger import init_logger ...@@ -40,6 +40,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Return the entry stored at ``config[str(bs_key)][str(mean_kv_seqlen_key)]``.

    Both keys are converted with ``str()`` before the lookup, so ``config``
    is treated as a two-level mapping whose keys are strings.
    NOTE(review): the names suggest the first level is keyed by batch size
    and the second by mean KV sequence length -- confirm against the code
    that builds ``config``.

    Args:
        bs_key: First-level key; stringified before the lookup.
        mean_kv_seqlen_key: Second-level key; stringified before the lookup.
        config: Nested mapping to search.

    Returns:
        The value stored under the two stringified keys.

    Raises:
        ValueError: If either key is absent from ``config``.
    """
    # Stringify both keys so they match the mapping's string keys.
    outer_key = str(bs_key)
    inner_key = str(mean_kv_seqlen_key)

    # Guard clause: fail with a descriptive error when no entry exists.
    if outer_key not in config or inner_key not in config[outer_key]:
        raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")

    return config[outer_key][inner_key]
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str: def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default": if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json" return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
...@@ -737,6 +749,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -737,6 +749,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for " "are not implemented for "
"TritonMLAImpl") "TritonMLAImpl")
self.attn_configs = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill( def _forward_prefill(
self, self,
q: torch.Tensor, q: torch.Tensor,
...@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO # TODO
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16") for bs in self.attn_configs.keys():
for mean_seq_len in self.attn_configs[bs].keys():
best_config = get_config(bs, mean_seq_len, self.attn_configs)
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits, decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, # config, attn_metadata.num_kv_splits, self.scale, best_config,
PAGE_SIZE) PAGE_SIZE)
return self._v_up_proj_and_o_proj(o) return self._v_up_proj_and_o_proj(o)
...@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped( ...@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped(
# opt # opt
@triton.autotune( # @triton.autotune(
configs=[ # configs=[
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
], # ],
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"] # key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
) # )
@triton.jit @triton.jit
def _decode_v1_kernel_stage1_use_tc( def _decode_v1_kernel_stage1_use_tc(
Q, Q,
...@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc( ...@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc(
mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
) )
@triton.autotune( # @triton.autotune(
configs=[ # configs=[
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1), # triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
], # ],
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"] # key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
) # )
@triton.jit @triton.jit
def _decode_v1_kernel_stage2_use_tc( def _decode_v1_kernel_stage2_use_tc(
logits, logits,
...@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc( ...@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc(
page_size, page_size,
num_kv_splits, num_kv_splits,
logit_cap, logit_cap,
best_config,
): ):
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
...@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc( ...@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1] # batch, head_num = B_req_idx.shape[0], q.shape[1]
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
SPLIT_K = num_kv_splits
BLOCK_N = best_config['BLOCK_N']
SPLIT_K = num_kv_splits # best_config['SPLIT_K'] ?
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: ( grid = lambda META: (
batch, batch,
...@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc( ...@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc(
q_head_num=head_num, q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK_N,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
SPLIT_K=SPLIT_K, SPLIT_K=SPLIT_K,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=num_warps,
num_stages=num_stages,
Lk=Lk, Lk=Lk,
kpack=2, kpack=2,
) )
return _decode_v1_kernel_stage1_use_tc.best_config # return _decode_v1_kernel_stage1_use_tc.best_config
def _decode_v1_stage2_use_tc( def _decode_v1_stage2_use_tc(
...@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc( ...@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc(
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
page_size, page_size,
best_config,
): ):
batch, head_num = b_seq_len.shape[0], logits.shape[0] batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2] kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
...@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc( ...@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc(
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
q_head_num=head_num, q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
Lv=Lv, Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
) )
return _decode_v1_kernel_stage2_use_tc.best_config # return _decode_v1_kernel_stage2_use_tc.best_config
def decode_attention_v1( def decode_attention_v1(
...@@ -1003,11 +1019,36 @@ def decode_attention_v1( ...@@ -1003,11 +1019,36 @@ def decode_attention_v1(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config,
page_size, page_size,
logit_cap=0.0, logit_cap=0.0,
): ):
# GQA/MQA/MLA # GQA/MQA/MLA
_decode_v1_stage1_best_config = _decode_v1_stage1_use_tc( # _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
# q,
# k_buffer,
# attn_logits,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# sm_scale,
# page_size,
# num_kv_splits,
# logit_cap,
# )
# _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
# attn_logits,
# v_buffer,
# o,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# page_size,
# )
# return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
_decode_v1_stage1_use_tc(
q, q,
k_buffer, k_buffer,
attn_logits, attn_logits,
...@@ -1019,8 +1060,9 @@ def decode_attention_v1( ...@@ -1019,8 +1060,9 @@ def decode_attention_v1(
page_size, page_size,
num_kv_splits, num_kv_splits,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v1_stage2_best_config = _decode_v1_stage2_use_tc( _decode_v1_stage2_use_tc(
attn_logits, attn_logits,
v_buffer, v_buffer,
o, o,
...@@ -1029,31 +1071,31 @@ def decode_attention_v1( ...@@ -1029,31 +1071,31 @@ def decode_attention_v1(
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
page_size, page_size,
best_config['stage2'],
) )
return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
# @triton.autotune(
@triton.autotune( # configs=[
configs=[ # triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1), # triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1), # ],
], # key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"] # )
)
@triton.jit @triton.jit
def _decode_v2_kernel_stage1_use_tc( def _decode_v2_kernel_stage1_use_tc(
Q, Q,
...@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc( ...@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc(
B_Seqlen, B_Seqlen,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config,
page_size, page_size,
logit_cap, logit_cap,
): ):
BLOCK = best_config['BLOCK_N']
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc( ...@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc(
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H, BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS, NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
num_warps=num_warps,
num_stages=num_stages,
Lk=Lk, Lk=Lk,
Lv=Lv, Lv=Lv,
kpack=2, kpack=2,
) )
return _decode_v2_kernel_stage1_use_tc.best_config # return _decode_v2_kernel_stage1_use_tc.best_config
@triton.autotune( # @triton.autotune(
configs=[ # configs=[
triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
triton.Config({}, num_warps=2, num_stages=1), # triton.Config({}, num_warps=2, num_stages=1),
triton.Config({}, num_warps=4, num_stages=1), # triton.Config({}, num_warps=4, num_stages=1),
triton.Config({}, num_warps=8, num_stages=1), # triton.Config({}, num_warps=8, num_stages=1),
], # ],
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"] # key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
) # )
@triton.jit @triton.jit
def _decode_v2_kernel_stage2( def _decode_v2_kernel_stage2(
Mid_O, Mid_O,
...@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc( ...@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc(
v_buffer, v_buffer,
b_seq_len, b_seq_len,
num_kv_splits, num_kv_splits,
best_config,
): ):
num_stages = best_config['num_stages']
num_warps = best_config['num_warps']
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc( ...@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS=NUM_KV_SPLITS, NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
Lv=Lv, Lv=Lv,
num_warps=num_warps,
num_stages=num_stages,
) )
return _decode_v2_kernel_stage2.best_config # return _decode_v2_kernel_stage2.best_config
def decode_attention_v2( def decode_attention_v2(
...@@ -1401,10 +1456,26 @@ def decode_attention_v2( ...@@ -1401,10 +1456,26 @@ def decode_attention_v2(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config,
page_size, page_size,
logit_cap=0.0, logit_cap=0.0,
): ):
_decode_v2_stage1_best_config = _decode_v2_stage1_use_tc( # _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
# q,
# k_buffer,
# v_buffer,
# attn_logits,
# req_to_token,
# # b_req_idx,
# b_seq_len,
# num_kv_splits,
# sm_scale,
# page_size,
# logit_cap,
# )
# _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
# return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
_decode_v2_stage1_use_tc(
q, q,
k_buffer, k_buffer,
v_buffer, v_buffer,
...@@ -1416,9 +1487,9 @@ def decode_attention_v2( ...@@ -1416,9 +1487,9 @@ def decode_attention_v2(
sm_scale, sm_scale,
page_size, page_size,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config['stage2'])
return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
def decode_attention_fwd( def decode_attention_fwd(
...@@ -1431,7 +1502,7 @@ def decode_attention_fwd( ...@@ -1431,7 +1502,7 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
# config, best_config,
page_size=1, page_size=1,
logit_cap=0.0, logit_cap=0.0,
): ):
...@@ -1456,6 +1527,7 @@ def decode_attention_fwd( ...@@ -1456,6 +1527,7 @@ def decode_attention_fwd(
else: else:
# GQA/MQA/MLA # GQA/MQA/MLA
if envs.VLLM_USE_TRITON_OPT_MLA: if envs.VLLM_USE_TRITON_OPT_MLA:
'''
decode_attention_v2( decode_attention_v2(
q, q,
k_buffer, k_buffer,
...@@ -1469,63 +1541,62 @@ def decode_attention_fwd( ...@@ -1469,63 +1541,62 @@ def decode_attention_fwd(
page_size, page_size,
logit_cap, logit_cap,
) )
# attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size), (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16, dtype=torch.float16,
# device="cuda") device="cuda")
# decode_attention_v1( decode_attention_v1(
# q, q,
# k_buffer, k_buffer,
# v_buffer, v_buffer,
# o, o,
# req_to_token, req_to_token,
# b_start_loc, b_start_loc,
# b_seq_len, b_seq_len,
# attn_logits_v1, attn_logits_v1,
# num_kv_splits, # sub num_kv_splits, # sub
# sm_scale, sm_scale,
# page_size, page_size,
# logit_cap, logit_cap,
# ) )'''
# TODO if best_config['kernel_kind'] == 'v1_2stages_tc':
# if best_config['kernel_kind'] == 'v1_2stages_tc': attn_logits_v1 = torch.empty(
# attn_logits_v1 = torch.empty( (q.shape[1],k_buffer.shape[0]*page_size),
# (q.shape[1],k_buffer.shape[0]*page_size), dtype=torch.float16,
# dtype=torch.float16, device="cuda")
# device="cuda") decode_attention_v1(
# decode_attention_v1( q,
# q, k_buffer,
# k_buffer, v_buffer,
# v_buffer, o,
# o, req_to_token,
# req_to_token, b_start_loc,
# b_start_loc, b_seq_len,
# b_seq_len, attn_logits_v1,
# attn_logits_v1, num_kv_splits,
# num_kv_splits, sm_scale,
# sm_scale, best_config=best_config['best_config'],
# config, page_size=page_size,
# page_size, logit_cap=logit_cap,
# logit_cap, )
# ) elif best_config['kernel_kind'] == 'v2_tc':
# elif best_config['kernel_kind'] == 'v2_tc': decode_attention_v2(
# decode_attention_v2( q,
# q, k_buffer,
# k_buffer, v_buffer,
# v_buffer, o,
# o, req_to_token,
# req_to_token, b_seq_len,
# b_seq_len, attn_logits,
# attn_logits, num_kv_splits,
# num_kv_splits, sm_scale,
# sm_scale, best_config=best_config['best_config'],
# config, page_size=page_size,
# page_size, logit_cap=logit_cap,
# logit_cap, )
# ) else:
# else: print("Unknown mla kernel kind: ", best_config['kernel_kind'])
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
else: else:
decode_attention_fwd_grouped( decode_attention_fwd_grouped(
q, q,
......
...@@ -89,7 +89,7 @@ def get_model_architecture( ...@@ -89,7 +89,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' ) # TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', support_nn_architectures = ['LlamaForCausalLM', 'Qwen2ForCausalLM', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment