Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eee6148a
Commit
eee6148a
authored
Mar 13, 2025
by
zhuwenwen
Browse files
update mla to obtain the optimal configuration from config
parent
abac3adc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
287 additions
and
200 deletions
+287
-200
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+19
-3
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+267
-196
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
No files found.
vllm/attention/backends/triton_mla.py
View file @
eee6148a
...
@@ -40,6 +40,18 @@ from vllm.logger import init_logger
...
@@ -40,6 +40,18 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
get_config
(
bs_key
,
mean_kv_seqlen_key
,
config
):
# 转换参数为字符串以匹配字典的键
bs_key_str
=
str
(
bs_key
)
mean_kv_seqlen_key_str
=
str
(
mean_kv_seqlen_key
)
# 检查字典中是否存在对应的配置
if
bs_key_str
in
config
and
mean_kv_seqlen_key_str
in
config
[
bs_key_str
]:
return
config
[
bs_key_str
][
mean_kv_seqlen_key_str
]
else
:
raise
ValueError
(
f
"No matching configuration found for bs key:
{
bs_key
}
and mean kv seq key:
{
mean_kv_seqlen_key
}
when init decode attention db"
)
def
get_mla_config_file_name
(
QH
:
int
,
KVH
:
int
,
QKD
:
int
,
VD
:
int
,
cache_dtype
:
Optional
[
str
])
->
str
:
def
get_mla_config_file_name
(
QH
:
int
,
KVH
:
int
,
QKD
:
int
,
VD
:
int
,
cache_dtype
:
Optional
[
str
])
->
str
:
if
cache_dtype
==
"default"
:
if
cache_dtype
==
"default"
:
return
f
"QH=
{
QH
}
_KVH=
{
KVH
}
_QKD=
{
QKD
}
_VD=
{
VD
}
_default.json"
return
f
"QH=
{
QH
}
_KVH=
{
KVH
}
_QKD=
{
QKD
}
_VD=
{
VD
}
_default.json"
...
@@ -737,6 +749,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -737,6 +749,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for "
"are not implemented for "
"TritonMLAImpl"
)
"TritonMLAImpl"
)
self
.
attn_configs
=
get_attention_mla_configs
(
self
.
num_heads
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
self
.
kv_lora_rank
,
"fp16"
)
def
_forward_prefill
(
def
_forward_prefill
(
self
,
self
,
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
...
@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
...
@@ -791,13 +805,15 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# TODO
# TODO
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
for
bs
in
self
.
attn_configs
.
keys
():
for
mean_seq_len
in
self
.
attn_configs
[
bs
].
keys
():
best_config
=
get_config
(
bs
,
mean_seq_len
,
self
.
attn_configs
)
# Run MQA
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_meta
.
block_tables
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
#
config,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
best_
config
,
PAGE_SIZE
)
PAGE_SIZE
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/attention/ops/triton_decode_attention.py
View file @
eee6148a
...
@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped(
...
@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped(
# opt
# opt
@
triton
.
autotune
(
#
@triton.autotune(
configs
=
[
#
configs=[
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
#
],
key
=
[
"B_Seqlen"
,
"stride_qbs"
,
"stride_buf_kbs"
,
"stride_buf_kh"
]
#
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
)
#
)
@
triton
.
jit
@
triton
.
jit
def
_decode_v1_kernel_stage1_use_tc
(
def
_decode_v1_kernel_stage1_use_tc
(
Q
,
Q
,
...
@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc(
...
@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc(
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
)
)
@
triton
.
autotune
(
#
@triton.autotune(
configs
=
[
#
configs=[
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
#
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
],
#
],
key
=
[
"B_Seqlen"
,
"stride_logic_h"
,
"stride_buf_vbs"
,
"stride_buf_vh"
]
#
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
)
#
)
@
triton
.
jit
@
triton
.
jit
def
_decode_v1_kernel_stage2_use_tc
(
def
_decode_v1_kernel_stage2_use_tc
(
logits
,
logits
,
...
@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc(
...
@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc(
page_size
,
page_size
,
num_kv_splits
,
num_kv_splits
,
logit_cap
,
logit_cap
,
best_config
,
):
):
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
...
@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc(
...
@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1]
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
-
2
]
SPLIT_K
=
num_kv_splits
BLOCK_N
=
best_config
[
'BLOCK_N'
]
SPLIT_K
=
num_kv_splits
# best_config['SPLIT_K'] ?
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
grid
=
lambda
META
:
(
grid
=
lambda
META
:
(
batch
,
batch
,
...
@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc(
...
@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc(
q_head_num
=
head_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_N
=
BLOCK_N
,
BLOCK_H
=
BLOCK_H
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT_K
,
SPLIT_K
=
SPLIT_K
,
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
Lk
=
Lk
,
Lk
=
Lk
,
kpack
=
2
,
kpack
=
2
,
)
)
return
_decode_v1_kernel_stage1_use_tc
.
best_config
#
return _decode_v1_kernel_stage1_use_tc.best_config
def
_decode_v1_stage2_use_tc
(
def
_decode_v1_stage2_use_tc
(
...
@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc(
...
@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc(
b_start_loc
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
page_size
,
page_size
,
best_config
,
):
):
batch
,
head_num
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
batch
,
head_num
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
-
2
]
BLOCK_N
=
best_config
[
'BLOCK_N'
]
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
BLOCK_H
=
max
(
16
,
triton
.
next_power_of_2
(
kv_group_num
))
BLOCK_H
=
max
(
16
,
triton
.
next_power_of_2
(
kv_group_num
))
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
1
)
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
1
)
...
@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc(
...
@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc(
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK_N
,
BLOCK_H
=
BLOCK_H
,
BLOCK_H
=
BLOCK_H
,
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
Lv
=
Lv
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
)
return
_decode_v1_kernel_stage2_use_tc
.
best_config
#
return _decode_v1_kernel_stage2_use_tc.best_config
def
decode_attention_v1
(
def
decode_attention_v1
(
...
@@ -1003,11 +1019,36 @@ def decode_attention_v1(
...
@@ -1003,11 +1019,36 @@ def decode_attention_v1(
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
sm_scale
,
sm_scale
,
best_config
,
page_size
,
page_size
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
# GQA/MQA/MLA
# GQA/MQA/MLA
_decode_v1_stage1_best_config
=
_decode_v1_stage1_use_tc
(
# _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
# q,
# k_buffer,
# attn_logits,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# sm_scale,
# page_size,
# num_kv_splits,
# logit_cap,
# )
# _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
# attn_logits,
# v_buffer,
# o,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# page_size,
# )
# return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
_decode_v1_stage1_use_tc
(
q
,
q
,
k_buffer
,
k_buffer
,
attn_logits
,
attn_logits
,
...
@@ -1019,8 +1060,9 @@ def decode_attention_v1(
...
@@ -1019,8 +1060,9 @@ def decode_attention_v1(
page_size
,
page_size
,
num_kv_splits
,
num_kv_splits
,
logit_cap
,
logit_cap
,
best_config
[
'stage1'
],
)
)
_decode_v1_stage2_best_config
=
_decode_v1_stage2_use_tc
(
_decode_v1_stage2_use_tc
(
attn_logits
,
attn_logits
,
v_buffer
,
v_buffer
,
o
,
o
,
...
@@ -1029,31 +1071,31 @@ def decode_attention_v1(
...
@@ -1029,31 +1071,31 @@ def decode_attention_v1(
b_start_loc
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
page_size
,
page_size
,
best_config
[
'stage2'
],
)
)
return
_decode_v1_stage1_best_config
,
_decode_v1_stage2_best_config
# @triton.autotune(
@
triton
.
autotune
(
# configs=[
configs
=
[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_stages
=
1
),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_stages
=
1
),
# ],
],
# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
key
=
[
"B_Seqlen"
,
"stride_qbs"
,
"stride_buf_kbs"
,
"stride_buf_kh"
,
"stride_buf_vbs"
,
"stride_buf_vh"
]
# )
)
@
triton
.
jit
@
triton
.
jit
def
_decode_v2_kernel_stage1_use_tc
(
def
_decode_v2_kernel_stage1_use_tc
(
Q
,
Q
,
...
@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc(
...
@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc(
B_Seqlen
,
B_Seqlen
,
num_kv_splits
,
num_kv_splits
,
sm_scale
,
sm_scale
,
best_config
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
):
):
BLOCK
=
best_config
[
'BLOCK_N'
]
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
...
@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc(
...
@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc(
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
BLOCK_H
=
BLOCK_H
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
Lv
=
Lv
,
kpack
=
2
,
kpack
=
2
,
)
)
return
_decode_v2_kernel_stage1_use_tc
.
best_config
#
return _decode_v2_kernel_stage1_use_tc.best_config
@
triton
.
autotune
(
#
@triton.autotune(
configs
=
[
#
configs=[
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
#
triton.Config({}, num_warps=1, num_stages=1),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
#
triton.Config({}, num_warps=2, num_stages=1),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
#
triton.Config({}, num_warps=4, num_stages=1),
triton
.
Config
({},
num_warps
=
8
,
num_stages
=
1
),
#
triton.Config({}, num_warps=8, num_stages=1),
],
#
],
key
=
[
"B_Seqlen"
,
"stride_mid_ob"
,
"stride_mid_oh"
,
"stride_mid_os"
]
#
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
)
#
)
@
triton
.
jit
@
triton
.
jit
def
_decode_v2_kernel_stage2
(
def
_decode_v2_kernel_stage2
(
Mid_O
,
Mid_O
,
...
@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc(
...
@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc(
v_buffer
,
v_buffer
,
b_seq_len
,
b_seq_len
,
num_kv_splits
,
num_kv_splits
,
best_config
,
):
):
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
...
@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc(
...
@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
)
return
_decode_v2_kernel_stage2
.
best_config
#
return _decode_v2_kernel_stage2.best_config
def
decode_attention_v2
(
def
decode_attention_v2
(
...
@@ -1401,10 +1456,26 @@ def decode_attention_v2(
...
@@ -1401,10 +1456,26 @@ def decode_attention_v2(
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
sm_scale
,
sm_scale
,
best_config
,
page_size
,
page_size
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
_decode_v2_stage1_best_config
=
_decode_v2_stage1_use_tc
(
# _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
# q,
# k_buffer,
# v_buffer,
# attn_logits,
# req_to_token,
# # b_req_idx,
# b_seq_len,
# num_kv_splits,
# sm_scale,
# page_size,
# logit_cap,
# )
# _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
# return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
_decode_v2_stage1_use_tc
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
v_buffer
,
...
@@ -1416,9 +1487,9 @@ def decode_attention_v2(
...
@@ -1416,9 +1487,9 @@ def decode_attention_v2(
sm_scale
,
sm_scale
,
page_size
,
page_size
,
logit_cap
,
logit_cap
,
best_config
[
'stage1'
],
)
)
_decode_v2_stage2_best_config
=
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
,
best_config
[
'stage2'
])
return
_decode_v2_stage1_best_config
,
_decode_v2_stage2_best_config
def
decode_attention_fwd
(
def
decode_attention_fwd
(
...
@@ -1431,7 +1502,7 @@ def decode_attention_fwd(
...
@@ -1431,7 +1502,7 @@ def decode_attention_fwd(
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
sm_scale
,
sm_scale
,
#
config,
best_
config
,
page_size
=
1
,
page_size
=
1
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
...
@@ -1456,6 +1527,7 @@ def decode_attention_fwd(
...
@@ -1456,6 +1527,7 @@ def decode_attention_fwd(
else
:
else
:
# GQA/MQA/MLA
# GQA/MQA/MLA
if
envs
.
VLLM_USE_TRITON_OPT_MLA
:
if
envs
.
VLLM_USE_TRITON_OPT_MLA
:
'''
decode_attention_v2(
decode_attention_v2(
q,
q,
k_buffer,
k_buffer,
...
@@ -1469,63 +1541,62 @@ def decode_attention_fwd(
...
@@ -1469,63 +1541,62 @@ def decode_attention_fwd(
page_size,
page_size,
logit_cap,
logit_cap,
)
)
# attn_logits_v1 = torch.empty(
attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
(q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
dtype=torch.float16,
# device="cuda")
device="cuda")
# decode_attention_v1(
decode_attention_v1(
# q,
q,
# k_buffer,
k_buffer,
# v_buffer,
v_buffer,
# o,
o,
# req_to_token,
req_to_token,
# b_start_loc,
b_start_loc,
# b_seq_len,
b_seq_len,
# attn_logits_v1,
attn_logits_v1,
# num_kv_splits, # sub
num_kv_splits, # sub
# sm_scale,
sm_scale,
# page_size,
page_size,
# logit_cap,
logit_cap,
# )
)'''
# TODO
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
# if best_config['kernel_kind'] == 'v1_2stages_tc':
attn_logits_v1
=
torch
.
empty
(
# attn_logits_v1 = torch.empty(
(
q
.
shape
[
1
],
k_buffer
.
shape
[
0
]
*
page_size
),
# (q.shape[1],k_buffer.shape[0]*page_size),
dtype
=
torch
.
float16
,
# dtype=torch.float16,
device
=
"cuda"
)
# device="cuda")
decode_attention_v1
(
# decode_attention_v1(
q
,
# q,
k_buffer
,
# k_buffer,
v_buffer
,
# v_buffer,
o
,
# o,
req_to_token
,
# req_to_token,
b_start_loc
,
# b_start_loc,
b_seq_len
,
# b_seq_len,
attn_logits_v1
,
# attn_logits_v1,
num_kv_splits
,
# num_kv_splits,
sm_scale
,
# sm_scale,
best_config
=
best_config
[
'best_config'
],
# config,
page_size
=
page_size
,
# page_size,
logit_cap
=
logit_cap
,
# logit_cap,
)
# )
elif
best_config
[
'kernel_kind'
]
==
'v2_tc'
:
# elif best_config['kernel_kind'] == 'v2_tc':
decode_attention_v2
(
# decode_attention_v2(
q
,
# q,
k_buffer
,
# k_buffer,
v_buffer
,
# v_buffer,
o
,
# o,
req_to_token
,
# req_to_token,
b_seq_len
,
# b_seq_len,
attn_logits
,
# attn_logits,
num_kv_splits
,
# num_kv_splits,
sm_scale
,
# sm_scale,
best_config
=
best_config
[
'best_config'
],
# config,
page_size
=
page_size
,
# page_size,
logit_cap
=
logit_cap
,
# logit_cap,
)
# )
else
:
# else:
print
(
"Unknown mla kernel kind: "
,
best_config
[
'kernel_kind'
])
# print("Unknown mla kernel kind: ", best_config['kernel_kind'])
else
:
else
:
decode_attention_fwd_grouped
(
decode_attention_fwd_grouped
(
q
,
q
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
eee6148a
...
@@ -89,7 +89,7 @@ def get_model_architecture(
...
@@ -89,7 +89,7 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'Qwen2ForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment