Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eee6148a
Commit
eee6148a
authored
Mar 13, 2025
by
zhuwenwen
Browse files
update mla to obtain the optimal configuration from config
parent
abac3adc
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
287 additions
and
200 deletions
+287
-200
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+19
-3
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+267
-196
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
No files found.
vllm/attention/backends/triton_mla.py
View file @
eee6148a
...
@@ -40,6 +40,18 @@ from vllm.logger import init_logger
...
@@ -40,6 +40,18 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
get_config
(
bs_key
,
mean_kv_seqlen_key
,
config
):
# 转换参数为字符串以匹配字典的键
bs_key_str
=
str
(
bs_key
)
mean_kv_seqlen_key_str
=
str
(
mean_kv_seqlen_key
)
# 检查字典中是否存在对应的配置
if
bs_key_str
in
config
and
mean_kv_seqlen_key_str
in
config
[
bs_key_str
]:
return
config
[
bs_key_str
][
mean_kv_seqlen_key_str
]
else
:
raise
ValueError
(
f
"No matching configuration found for bs key:
{
bs_key
}
and mean kv seq key:
{
mean_kv_seqlen_key
}
when init decode attention db"
)
def
get_mla_config_file_name
(
QH
:
int
,
KVH
:
int
,
QKD
:
int
,
VD
:
int
,
cache_dtype
:
Optional
[
str
])
->
str
:
def
get_mla_config_file_name
(
QH
:
int
,
KVH
:
int
,
QKD
:
int
,
VD
:
int
,
cache_dtype
:
Optional
[
str
])
->
str
:
if
cache_dtype
==
"default"
:
if
cache_dtype
==
"default"
:
return
f
"QH=
{
QH
}
_KVH=
{
KVH
}
_QKD=
{
QKD
}
_VD=
{
VD
}
_default.json"
return
f
"QH=
{
QH
}
_KVH=
{
KVH
}
_QKD=
{
QKD
}
_VD=
{
VD
}
_default.json"
...
@@ -736,6 +748,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -736,6 +748,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
"TritonMLAImpl"
)
"TritonMLAImpl"
)
self
.
attn_configs
=
get_attention_mla_configs
(
self
.
num_heads
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
self
.
kv_lora_rank
,
"fp16"
)
def
_forward_prefill
(
def
_forward_prefill
(
self
,
self
,
...
@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# TODO
# TODO
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
for
bs
in
self
.
attn_configs
.
keys
():
for
mean_seq_len
in
self
.
attn_configs
[
bs
].
keys
():
best_config
=
get_config
(
bs
,
mean_seq_len
,
self
.
attn_configs
)
# Run MQA
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
#
config,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
best_
config
,
PAGE_SIZE
)
PAGE_SIZE
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/attention/ops/triton_decode_attention.py
View file @
eee6148a
This diff is collapsed.
Click to expand it.
vllm/model_executor/model_loader/utils.py
View file @
eee6148a
...
@@ -89,7 +89,7 @@ def get_model_architecture(
...
@@ -89,7 +89,7 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'Qwen2ForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment