Commit eee6148a authored by zhuwenwen's avatar zhuwenwen
Browse files

Update MLA to obtain the optimal configuration from config

parent abac3adc
...@@ -40,6 +40,18 @@ from vllm.logger import init_logger ...@@ -40,6 +40,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Look up one entry of a two-level configuration table.

    The entry is addressed as ``config[str(bs_key)][str(mean_kv_seqlen_key)]``.

    Args:
        bs_key: First-level key; converted with ``str()`` before the lookup.
        mean_kv_seqlen_key: Second-level key; converted with ``str()`` before
            the lookup.
        config: Mapping of string keys to mappings of string keys. ``None``
            or an empty mapping is treated as "no entry available".

    Returns:
        The value stored under the two stringified keys.

    Raises:
        ValueError: If ``config`` is missing/empty or does not contain an
            entry for the given pair of keys.
    """
    # Keys are matched as strings (presumably because the table is loaded
    # from JSON, whose object keys are always strings -- TODO confirm).
    bs_key_str = str(bs_key)
    mean_kv_seqlen_key_str = str(mean_kv_seqlen_key)

    # A missing table is reported the same way as a missing key, instead of
    # failing with an unrelated TypeError on the membership test.
    per_bs_configs = config.get(bs_key_str) if config else None
    if per_bs_configs is not None and mean_kv_seqlen_key_str in per_bs_configs:
        return per_bs_configs[mean_kv_seqlen_key_str]

    raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str: def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default": if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json" return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
...@@ -736,6 +748,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -736,6 +748,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
"are not implemented for " "are not implemented for "
"TritonMLAImpl") "TritonMLAImpl")
self.attn_configs = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill( def _forward_prefill(
self, self,
...@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]): ...@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE = kv_c_and_k_pe_cache.size(1) PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# TODO # TODO
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16") for bs in self.attn_configs.keys():
for mean_seq_len in self.attn_configs[bs].keys():
best_config = get_config(bs, mean_seq_len, self.attn_configs)
# Run MQA # Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables, decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits, decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale, # config, attn_metadata.num_kv_splits, self.scale, best_config,
PAGE_SIZE) PAGE_SIZE)
return self._v_up_proj_and_o_proj(o) return self._v_up_proj_and_o_proj(o)
...@@ -89,7 +89,7 @@ def get_model_architecture( ...@@ -89,7 +89,7 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' ) # TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', support_nn_architectures = ['LlamaForCausalLM', 'Qwen2ForCausalLM', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment