Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
146eb9d3
Commit
146eb9d3
authored
Mar 13, 2025
by
zhuwenwen
Browse files
update mla to obtain the optimal configuration from config
parent
c1370857
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
288 additions
and
200 deletions
+288
-200
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+20
-4
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+267
-195
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
No files found.
vllm/attention/backends/triton_mla.py
View file @
146eb9d3
...
...
@@ -39,6 +39,17 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata
)
def get_config(bs_key, mean_kv_seqlen_key, config):
    """Look up the tuned configuration for a batch-size / mean-KV-length pair.

    Both keys are converted with ``str()`` before the lookup, so callers may
    pass either the raw values or their string forms.

    Args:
        bs_key: Batch-size bucket used as the outer key.
        mean_kv_seqlen_key: Mean KV sequence-length bucket used as the
            inner key.
        config: Two-level mapping ``{str(bs): {str(seq_len): entry}}``.

    Returns:
        The entry stored at ``config[str(bs_key)][str(mean_kv_seqlen_key)]``.

    Raises:
        ValueError: If either key is absent from ``config``.
    """
    # Convert the arguments to strings so they match the mapping's keys.
    bs_key_str = str(bs_key)
    mean_kv_seqlen_key_str = str(mean_kv_seqlen_key)
    # Attempt the lookup directly (EAFP) instead of testing membership first
    # and then repeating the same subscripts.
    try:
        return config[bs_key_str][mean_kv_seqlen_key_str]
    except KeyError:
        # The KeyError adds nothing beyond the message below, so drop it
        # from the traceback context.
        raise ValueError(
            f"No matching configuration found for bs key: {bs_key} "
            f"and mean kv seq key: {mean_kv_seqlen_key} "
            f"when init decode attention db"
        ) from None
def
get_mla_config_file_name
(
QH
:
int
,
KVH
:
int
,
QKD
:
int
,
VD
:
int
,
cache_dtype
:
Optional
[
str
])
->
str
:
if
cache_dtype
==
"default"
:
return
f
"QH=
{
QH
}
_KVH=
{
KVH
}
_QKD=
{
QKD
}
_VD=
{
VD
}
_default.json"
...
...
@@ -736,6 +747,8 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for "
"TritonMLAImpl"
)
self
.
attn_configs
=
get_attention_mla_configs
(
self
.
num_heads
,
1
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
self
.
kv_lora_rank
,
"fp16"
)
def
_forward_prefill
(
self
,
q
:
torch
.
Tensor
,
...
...
@@ -789,13 +802,16 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# config = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
# TODO
for
bs
in
self
.
attn_configs
.
keys
():
for
mean_seq_len
in
self
.
attn_configs
[
bs
].
keys
():
best_config
=
get_config
(
bs
,
mean_seq_len
,
self
.
attn_configs
)
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
#
config,
attn_metadata
.
num_kv_splits
,
self
.
scale
,
best_
config
,
PAGE_SIZE
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/attention/ops/triton_decode_attention.py
View file @
146eb9d3
...
...
@@ -623,26 +623,26 @@ def decode_attention_fwd_grouped(
# opt
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
],
key
=
[
"B_Seqlen"
,
"stride_qbs"
,
"stride_buf_kbs"
,
"stride_buf_kh"
]
)
#
@triton.autotune(
#
configs=[
#
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
],
#
key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"]
#
)
@
triton
.
jit
def
_decode_v1_kernel_stage1_use_tc
(
Q
,
...
...
@@ -754,59 +754,59 @@ def _decode_v1_kernel_stage1_use_tc(
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
8
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
1
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
2
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
4
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
8
,
num_ldmatrixes
=
1
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
1
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
2
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
4
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
512
},
num_warps
=
8
,
num_ldmatrixes
=
0
,
num_stages
=
1
),
],
key
=
[
"B_Seqlen"
,
"stride_logic_h"
,
"stride_buf_vbs"
,
"stride_buf_vh"
]
)
#
@triton.autotune(
#
configs=[
#
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
#
triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
#
],
#
key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"]
#
)
@
triton
.
jit
def
_decode_v1_kernel_stage2_use_tc
(
logits
,
...
...
@@ -898,6 +898,7 @@ def _decode_v1_stage1_use_tc(
page_size
,
num_kv_splits
,
logit_cap
,
best_config
,
):
Lk
=
k_buffer
.
shape
[
-
1
]
...
...
@@ -914,7 +915,11 @@ def _decode_v1_stage1_use_tc(
# batch, head_num = B_req_idx.shape[0], q.shape[1]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
-
2
]
SPLIT_K
=
num_kv_splits
BLOCK_N
=
best_config
[
'BLOCK_N'
]
SPLIT_K
=
num_kv_splits
# best_config['SPLIT_K'] ?
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
grid
=
lambda
META
:
(
batch
,
...
...
@@ -940,14 +945,17 @@ def _decode_v1_stage1_use_tc(
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_N
=
BLOCK_N
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT_K
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
Lk
=
Lk
,
kpack
=
2
,
)
return
_decode_v1_kernel_stage1_use_tc
.
best_config
#
return _decode_v1_kernel_stage1_use_tc.best_config
def
_decode_v1_stage2_use_tc
(
...
...
@@ -959,9 +967,14 @@ def _decode_v1_stage2_use_tc(
b_start_loc
,
b_seq_len
,
page_size
,
best_config
,
):
batch
,
head_num
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
-
2
]
BLOCK_N
=
best_config
[
'BLOCK_N'
]
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
BLOCK_H
=
max
(
16
,
triton
.
next_power_of_2
(
kv_group_num
))
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
1
)
...
...
@@ -984,11 +997,14 @@ def _decode_v1_stage2_use_tc(
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK_N
,
BLOCK_H
=
BLOCK_H
,
PAGE_SIZE
=
page_size
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
_decode_v1_kernel_stage2_use_tc
.
best_config
#
return _decode_v1_kernel_stage2_use_tc.best_config
def
decode_attention_v1
(
...
...
@@ -1003,11 +1019,36 @@ def decode_attention_v1(
attn_logits
,
num_kv_splits
,
sm_scale
,
best_config
,
page_size
,
logit_cap
=
0.0
,
):
# GQA/MQA/MLA
_decode_v1_stage1_best_config
=
_decode_v1_stage1_use_tc
(
# _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
# q,
# k_buffer,
# attn_logits,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# sm_scale,
# page_size,
# num_kv_splits,
# logit_cap,
# )
# _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
# attn_logits,
# v_buffer,
# o,
# req_to_token,
# #b_req_idx,
# b_start_loc,
# b_seq_len,
# page_size,
# )
# return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config
_decode_v1_stage1_use_tc
(
q
,
k_buffer
,
attn_logits
,
...
...
@@ -1019,8 +1060,9 @@ def decode_attention_v1(
page_size
,
num_kv_splits
,
logit_cap
,
best_config
[
'stage1'
],
)
_decode_v1_stage2_best_config
=
_decode_v1_stage2_use_tc
(
_decode_v1_stage2_use_tc
(
attn_logits
,
v_buffer
,
o
,
...
...
@@ -1029,31 +1071,31 @@ def decode_attention_v1(
b_start_loc
,
b_seq_len
,
page_size
,
best_config
[
'stage2'
],
)
return
_decode_v1_stage1_best_config
,
_decode_v1_stage2_best_config
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
16
},
num_warps
=
8
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
32
},
num_warps
=
8
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
64
},
num_warps
=
8
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
128
},
num_warps
=
8
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({
"BLOCK_N"
:
256
},
num_warps
=
8
,
num_stages
=
1
),
],
key
=
[
"B_Seqlen"
,
"stride_qbs"
,
"stride_buf_kbs"
,
"stride_buf_kh"
,
"stride_buf_vbs"
,
"stride_buf_vh"
]
)
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
# ],
# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
# )
@
triton
.
jit
def
_decode_v2_kernel_stage1_use_tc
(
Q
,
...
...
@@ -1227,10 +1269,15 @@ def _decode_v2_stage1_use_tc(
B_Seqlen
,
num_kv_splits
,
sm_scale
,
best_config
,
page_size
,
logit_cap
,
):
BLOCK
=
best_config
[
'BLOCK_N'
]
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
...
...
@@ -1281,26 +1328,29 @@ def _decode_v2_stage1_use_tc(
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
Lk
=
Lk
,
Lv
=
Lv
,
kpack
=
2
,
)
return
_decode_v2_kernel_stage1_use_tc
.
best_config
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
1
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
2
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
4
,
num_stages
=
1
),
triton
.
Config
({},
num_warps
=
8
,
num_stages
=
1
),
],
key
=
[
"B_Seqlen"
,
"stride_mid_ob"
,
"stride_mid_oh"
,
"stride_mid_os"
]
)
#
return _decode_v2_kernel_stage1_use_tc.best_config
#
@triton.autotune(
#
configs=[
#
triton.Config({}, num_warps=1, num_stages=1),
#
triton.Config({}, num_warps=2, num_stages=1),
#
triton.Config({}, num_warps=4, num_stages=1),
#
triton.Config({}, num_warps=8, num_stages=1),
#
],
#
key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
#
)
@
triton
.
jit
def
_decode_v2_kernel_stage2
(
Mid_O
,
...
...
@@ -1364,7 +1414,10 @@ def _decode_v2_stage2_use_tc(
v_buffer
,
b_seq_len
,
num_kv_splits
,
best_config
,
):
num_stages
=
best_config
[
'num_stages'
]
num_warps
=
best_config
[
'num_warps'
]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
...
...
@@ -1385,9 +1438,11 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
_decode_v2_kernel_stage2
.
best_config
#
return _decode_v2_kernel_stage2.best_config
def
decode_attention_v2
(
...
...
@@ -1401,10 +1456,26 @@ def decode_attention_v2(
attn_logits
,
num_kv_splits
,
sm_scale
,
best_config
,
page_size
,
logit_cap
=
0.0
,
):
_decode_v2_stage1_best_config
=
_decode_v2_stage1_use_tc
(
# _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
# q,
# k_buffer,
# v_buffer,
# attn_logits,
# req_to_token,
# # b_req_idx,
# b_seq_len,
# num_kv_splits,
# sm_scale,
# page_size,
# logit_cap,
# )
# _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
# return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config
_decode_v2_stage1_use_tc
(
q
,
k_buffer
,
v_buffer
,
...
...
@@ -1416,9 +1487,9 @@ def decode_attention_v2(
sm_scale
,
page_size
,
logit_cap
,
best_config
[
'stage1'
],
)
_decode_v2_stage2_best_config
=
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
)
return
_decode_v2_stage1_best_config
,
_decode_v2_stage2_best_config
_decode_v2_stage2_use_tc
(
attn_logits
,
q
,
o
,
v_buffer
,
b_seq_len
,
num_kv_splits
,
best_config
[
'stage2'
])
def
decode_attention_fwd
(
...
...
@@ -1431,7 +1502,7 @@ def decode_attention_fwd(
attn_logits
,
num_kv_splits
,
sm_scale
,
#
config,
best_
config
,
page_size
=
1
,
logit_cap
=
0.0
,
):
...
...
@@ -1456,6 +1527,7 @@ def decode_attention_fwd(
else
:
# GQA/MQA/MLA
if
envs
.
VLLM_USE_TRITON_OPT_MLA
:
'''
decode_attention_v2(
q,
k_buffer,
...
...
@@ -1469,62 +1541,62 @@ def decode_attention_fwd(
page_size,
logit_cap,
)
#
attn_logits_v1 = torch.empty(
#
(q.shape[1],k_buffer.shape[0]*page_size),
#
dtype=torch.float16,
#
device="cuda")
#
decode_attention_v1(
#
q,
#
k_buffer,
#
v_buffer,
#
o,
#
req_to_token,
#
b_start_loc,
#
b_seq_len,
#
attn_logits_v1,
#
num_kv_splits, # sub
#
sm_scale,
#
page_size,
#
logit_cap,
#
)
#
if best_config['kernel_kind'] == 'v1_2stages_tc':
#
attn_logits_v1 = torch.empty(
#
(q.shape[1],k_buffer.shape[0]*page_size),
#
dtype=torch.float16,
#
device="cuda")
#
decode_attention_v1(
#
q,
#
k_buffer,
#
v_buffer,
#
o,
#
req_to_token,
#
b_start_loc,
#
b_seq_len,
#
attn_logits_v1,
#
num_kv_splits,
#
sm_scale,
#
config,
#
page_size,
#
logit_cap,
#
)
#
elif best_config['kernel_kind'] == 'v2_tc':
#
decode_attention_v2(
#
q,
#
k_buffer,
#
v_buffer,
#
o,
#
req_to_token,
#
b_seq_len,
#
attn_logits,
#
num_kv_splits,
#
sm_scale,
#
config,
#
page_size,
#
logit_cap,
#
)
#
else:
#
print("Unknown mla kernel kind: ", best_config['kernel_kind'])
attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16,
device="cuda")
decode_attention_v1(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_start_loc,
b_seq_len,
attn_logits_v1,
num_kv_splits, # sub
sm_scale,
page_size,
logit_cap,
)
'''
if
best_config
[
'kernel_kind'
]
==
'v1_2stages_tc'
:
attn_logits_v1
=
torch
.
empty
(
(
q
.
shape
[
1
],
k_buffer
.
shape
[
0
]
*
page_size
),
dtype
=
torch
.
float16
,
device
=
"cuda"
)
decode_attention_v1
(
q
,
k_buffer
,
v_buffer
,
o
,
req_to_token
,
b_start_loc
,
b_seq_len
,
attn_logits_v1
,
num_kv_splits
,
sm_scale
,
best_
config
=
best_config
[
'best_config'
]
,
page_size
=
page_size
,
logit_cap
=
logit_cap
,
)
elif
best_config
[
'kernel_kind'
]
==
'v2_tc'
:
decode_attention_v2
(
q
,
k_buffer
,
v_buffer
,
o
,
req_to_token
,
b_seq_len
,
attn_logits
,
num_kv_splits
,
sm_scale
,
best_
config
=
best_config
[
'best_config'
]
,
page_size
=
page_size
,
logit_cap
=
logit_cap
,
)
else
:
print
(
"Unknown mla kernel kind: "
,
best_config
[
'kernel_kind'
])
else
:
decode_attention_fwd_grouped
(
q
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
146eb9d3
...
...
@@ -80,7 +80,7 @@ def get_model_architecture(
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
visions
=
getattr
(
model_config
.
hf_config
,
"visual"
,
[])
or
getattr
(
model_config
.
hf_config
,
"vision_config"
,
[])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures
=
[
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
support_nn_architectures
=
[
'LlamaForCausalLM'
,
'Qwen2ForCausalLM'
,
'QWenLMHeadModel'
,
'Qwen2VLForConditionalGeneration'
,
'Qwen2_5_VLForConditionalGeneration'
,
'Qwen2MoeForCausalLM'
,
'ChatGLMModel'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'MedusaModel'
,
'MixtralForCausalLM'
,
'MLPSpeculatorPreTrainedModel'
,
'FalconForCausalLM'
,
'DeepseekV2ForCausalLM'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment