Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
40083064
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "a8f12a63fde4765dffe53f7bf1482d52ac80af33"
Commit
40083064
authored
Feb 28, 2025
by
zhuwenwen
Browse files
add VLLM_USE_TRITON_OPT_MLA to use optimized MLA attention
parent
0a130908
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1064 additions
and
18 deletions
+1064
-18
vllm/attention/ops/triton_decode_attention.py
vllm/attention/ops/triton_decode_attention.py
+1059
-18
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
vllm/attention/ops/triton_decode_attention.py
View file @
40083064
...
@@ -30,11 +30,12 @@ It supports page size >= 1.
...
@@ -30,11 +30,12 @@ It supports page size >= 1.
import
os
import
os
import
logging
import
logging
import
torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm
import
envs
is_hip_
=
current_platform
.
is_rocm
()
is_hip_
=
current_platform
.
is_rocm
()
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"0"
os
.
environ
[
"TRITON_HIP_USE_NEW_STREAM_PIPELINE"
]
=
f
"0"
...
@@ -221,7 +222,6 @@ def _decode_att_m_fwd(
...
@@ -221,7 +222,6 @@ def _decode_att_m_fwd(
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
2
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
Lv
=
Lv
,
)
)
...
@@ -458,7 +458,6 @@ def _decode_grouped_att_m_fwd(
...
@@ -458,7 +458,6 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
4
,
num_warps
=
4
,
num_stages
=
2
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
Lv
=
Lv
,
**
extra_kargs
,
**
extra_kargs
,
...
@@ -560,7 +559,6 @@ def _decode_softmax_reducev_fwd(
...
@@ -560,7 +559,6 @@ def _decode_softmax_reducev_fwd(
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
Lv
=
Lv
,
num_warps
=
4
,
num_warps
=
4
,
num_stages
=
2
,
**
extra_kargs
,
**
extra_kargs
,
)
)
...
@@ -623,6 +621,1017 @@ def decode_attention_fwd_grouped(
...
@@ -623,6 +621,1017 @@ def decode_attention_fwd_grouped(
num_kv_splits
)
num_kv_splits
)
# opt
@triton.autotune(
    # Exhaustive cross product of the tuning knobs, generated in the order
    # SPLIT_K (outermost) -> BLOCK_N -> num_warps (innermost):
    # 9 split factors x 5 tile widths x 3 warp counts = 135 candidates.
    configs=[
        triton.Config(
            {"SPLIT_K": split_k, "BLOCK_N": block_n},
            num_warps=warps,
            num_ldmatrixes=0,
            num_stages=1,
        )
        for split_k in (1, 2, 4, 8, 16, 32, 64, 128, 256)
        for block_n in (16, 32, 64, 128, 256)
        for warps in (2, 4, 8)
    ],
    # NOTE(review): "B_Seqlen" is used as a pointer inside the kernel
    # (tl.load(B_Seqlen + ...)); confirm the autotuner is meant to key on it.
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh"],
)
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    SPLIT_K: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
):
    # Stage 1 of the two-stage decode: for one (batch, head block, KV split)
    # program, compute the scaled -- and optionally tanh-capped -- Q.K logits
    # for the tokens of this split and store them into Att_Out.
    #
    # Program axes: 0 = batch entry, 1 = block of query heads, 2 = KV split.
    # The request row in Req_to_tokens is the batch index itself.
    batch_idx = tl.program_id(0)
    head_block = tl.program_id(1)
    split_idx = tl.program_id(2)
    kv_head = head_block // tl.cdiv(kv_group_num, BLOCK_H)

    # Operands are cast to the element type of the output buffer.
    acc_dtype = Att_Out.dtype.element_ty

    # Number of query heads this program is actually responsible for.
    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num

    heads = head_block * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    head_mask = (heads < (head_block + 1) * VALID_BLOCK_H) & (heads < q_head_num)

    dims = tl.arange(0, BLOCK_DMODEL)
    seq_len = tl.load(B_Seqlen + batch_idx)
    out_start = tl.load(B_Start_Loc + batch_idx)
    req_idx = batch_idx

    # Query tile: [BLOCK_H, BLOCK_DMODEL], zero-filled outside head/dim range.
    q_row_base = batch_idx * stride_qbs + heads[:, None] * stride_qh
    q = tl.load(
        Q + (q_row_base + dims[None, :]),
        mask=head_mask[:, None] & (dims[None, :] < Lk),
        other=0.0,
    ).to(acc_dtype)

    if BLOCK_DPE > 0:
        # Extra BLOCK_DPE feature columns stored right after BLOCK_DMODEL.
        dims_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        qpe = tl.load(
            Q + (q_row_base + dims_pe[None, :]),
            mask=head_mask[:, None],
            other=0.0,
        ).to(acc_dtype)

    # Contiguous token range [split_begin, split_end) owned by this split.
    tokens_per_split = tl.cdiv(seq_len, SPLIT_K)
    split_begin = tokens_per_split * split_idx
    split_end = tl.minimum(split_begin + tokens_per_split, seq_len)

    for blk_start in range(split_begin, split_end, BLOCK_N):
        tok = blk_start + tl.arange(0, BLOCK_N)
        tok_ok = tok < split_end

        # Translate logical token positions to physical cache slots.
        page = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * req_idx + tok // PAGE_SIZE,
            mask=tok_ok,
            other=0,
        )
        slot = page * PAGE_SIZE + tok % PAGE_SIZE

        # Key tile is loaded transposed: [BLOCK_DMODEL, BLOCK_N].
        k_col_base = slot[None, :] * stride_buf_kbs + kv_head * stride_buf_kh
        k = tl.load(
            K_Buffer + (k_col_base + dims[:, None]),
            mask=tok_ok[None, :] & (dims[:, None] < Lk),
            other=0.0,
        ).to(acc_dtype)
        qk = tl.dot(q, k)

        if BLOCK_DPE > 0:
            kpe = tl.load(
                K_Buffer + (k_col_base + dims_pe[:, None]),
                mask=tok_ok[None, :],
                other=0.0,
            ).to(acc_dtype)
            qk = qk + tl.dot(qpe, kpe)

        qk = qk * sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        # Logits land at [head, out_start + token position].
        out_offs = heads[:, None] * att_stride_h + (out_start + tok[None, :])
        tl.store(
            Att_Out + out_offs,
            qk,
            mask=head_mask[:, None] & tok_ok[None, :],
        )
@triton.autotune(
    # Twelve (BLOCK_N, num_ldmatrixes) groups, each tried with 1/2/4/8 warps,
    # generated in the same order as the hand-written list they replace.
    configs=[
        triton.Config(
            {"BLOCK_N": block_n},
            num_warps=warps,
            num_ldmatrixes=ldmatrixes,
            num_stages=1,
        )
        for block_n, ldmatrixes in (
            (32, 1),
            (64, 1),
            (8, 0),
            (16, 0),
            (32, 0),
            (64, 0),
            (128, 1),
            (256, 1),
            (512, 1),
            (128, 0),
            (256, 0),
            (512, 0),
        )
        for warps in (1, 2, 4, 8)
    ],
    # NOTE(review): "B_Seqlen" is used as a pointer inside the kernel
    # (tl.load(B_Seqlen + ...)); confirm the autotuner is meant to key on it.
    key=["B_Seqlen", "stride_logic_h", "stride_buf_vbs", "stride_buf_vh"],
)
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
    logits,
    V_Buffer,
    Out,
    Req_to_tokens,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Stage 2 of the two-stage decode: for one (batch, KV head) program, read
    # the logits stored by stage 1, apply a softmax over the sequence while
    # accumulating the weighted sum of V rows, and write the result to Out.
    #
    # Program axes: 0 = batch entry, 1 = KV head.  The request row in
    # Req_to_tokens is the batch index itself.
    batch_idx = tl.program_id(0)
    kv_head = tl.program_id(1)

    # Query heads that share this KV head.
    heads = kv_head * kv_group_num + tl.arange(0, BLOCK_H)
    head_mask = (heads < (kv_head + 1) * kv_group_num) & (heads < q_head_num)

    seq_len = tl.load(B_Seqlen + batch_idx)
    logits_start = tl.load(B_Start_Loc + batch_idx)
    req_idx = batch_idx

    lane_n = tl.arange(0, BLOCK_N)
    dims = tl.arange(0, BLOCK_DMODEL)
    v_base = V_Buffer + (kv_head * stride_buf_vh + dims[None, :])

    # Running softmax statistics and output accumulator, all in float32.
    run_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    run_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        tok = start_n + lane_n
        tok_ok = tok < seq_len

        # Translate logical token positions to physical cache slots.
        page = tl.load(
            Req_to_tokens + req_idx * stride_req_to_token_b + tok // PAGE_SIZE,
            mask=tok_ok,
            other=0,
        )
        slot = page * PAGE_SIZE + tok % PAGE_SIZE

        # Logits tile [BLOCK_H, BLOCK_N]; out-of-range entries become -inf so
        # they contribute exp(-inf) = 0 to the softmax.
        qk_offs = heads[:, None] * stride_logic_h + (
            logits_start + start_n + lane_n[None, :]
        )
        qk = tl.load(
            logits + qk_offs,
            mask=head_mask[:, None] & tok_ok[None, :],
            other=float("-inf"),
        )

        # Rescale previous partial results to the new running maximum.
        new_max = tl.maximum(tl.max(qk, 1), run_max)
        rescale = tl.exp(run_max - new_max)
        p = tl.exp(qk - new_max[:, None])
        run_sum = run_sum * rescale + tl.sum(p, 1)

        # Value tile [BLOCK_N, BLOCK_DMODEL].
        v = tl.load(
            v_base + slot[:, None] * stride_buf_vbs,
            mask=(dims[None, :] < Lv),
        )
        p = p.to(v.dtype)
        acc = acc * rescale[:, None] + tl.dot(p, v)
        run_max = new_max

    # Normalise by the softmax denominator and write the valid region.
    acc = acc / run_sum[:, None]
    out_offs = batch_idx * stride_obs + heads[:, None] * stride_oh + dims[None, :]
    tl.store(
        Out + out_offs,
        acc,
        mask=head_mask[:, None] & (dims[None, :] < Lv),
    )
def _decode_v1_stage1_use_tc(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    B_Start_Loc,
    B_Seqlen,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch the stage-1 kernel that writes Q.K logits into ``att_out``.

    The last dimension of ``k_buffer`` selects how the head dimension is
    split into a main tile (``BLOCK_DMODEL``) and an extra tile
    (``BLOCK_DPE``): 576 -> 512 + 64, 288 -> 256 + 32, anything else -> the
    next power of two with no extra tile.

    Returns:
        The ``best_config`` attribute of the autotuned stage-1 kernel.
    """
    Lk = k_buffer.shape[-1]

    # Head dims with a dedicated (main, extra) split; others are padded up.
    dim_split = {576: (512, 64), 288: (256, 32)}
    if Lk in dim_split:
        block_dmodel, block_dpe = dim_split[Lk]
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(Lk), 0

    batch = q.shape[0]
    head_num = q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    # Head tile is a power of two clamped to the range [16, 64].
    block_h = max(16, min(64, triton.next_power_of_2(kv_group_num)))

    def grid(META):
        # (batch, head blocks, KV splits); SPLIT_K comes from the autotuner.
        return (
            batch,
            triton.cdiv(head_num, min(block_h, kv_group_num)),
            META['SPLIT_K'],
        )

    _decode_v1_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_H=block_h,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        kpack=2,
    )
    return _decode_v1_kernel_stage1_use_tc.best_config
def _decode_v1_stage2_use_tc(
    logits,
    v_buffer,
    o,
    req_to_tokens,
    b_start_loc,
    b_seq_len,
    page_size,
):
    """Launch the stage-2 kernel that turns stored logits into outputs.

    One program is launched per (batch entry, head group); the kernel reads
    ``logits`` and ``v_buffer`` and writes its result into ``o``.

    Returns:
        The ``best_config`` attribute of the autotuned stage-2 kernel.
    """
    batch = b_seq_len.shape[0]
    head_num = logits.shape[0]
    kv_group_num = logits.shape[0] // v_buffer.shape[-2]

    # Head tile: next power of two of the group size, but at least 16.
    block_h = max(16, triton.next_power_of_2(kv_group_num))
    num_head_groups = triton.cdiv(head_num, min(block_h, kv_group_num))
    grid = (batch, num_head_groups, 1)

    Lv = v_buffer.shape[-1]
    block_dmodel = triton.next_power_of_2(Lv)

    _decode_v1_kernel_stage2_use_tc[grid](
        logits,
        v_buffer,
        o,
        req_to_tokens,
        b_start_loc,
        b_seq_len,
        logits.stride(0),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_H=block_h,
        PAGE_SIZE=page_size,
        Lv=Lv,
    )
    return _decode_v1_kernel_stage2_use_tc.best_config
def decode_attention_v1(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_start_loc,
    b_seq_len,
    attn_logits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run the two-stage decode attention (GQA/MQA/MLA).

    Stage 1 fills ``attn_logits`` from ``q`` and ``k_buffer``; stage 2 reads
    ``attn_logits`` and ``v_buffer`` and writes the attention output to ``o``.

    Returns:
        A ``(stage1, stage2)`` tuple holding the value returned by each stage
        helper (the autotuned kernel's ``best_config``).
    """
    # Stage 1 must run first: stage 2 consumes the logits it produces.
    stage1_best = _decode_v1_stage1_use_tc(
        q=q,
        k_buffer=k_buffer,
        att_out=attn_logits,
        Req_to_tokens=req_to_token,
        B_Start_Loc=b_start_loc,
        B_Seqlen=b_seq_len,
        sm_scale=sm_scale,
        page_size=page_size,
        logit_cap=logit_cap,
    )
    stage2_best = _decode_v1_stage2_use_tc(
        logits=attn_logits,
        v_buffer=v_buffer,
        o=o,
        req_to_tokens=req_to_token,
        b_start_loc=b_start_loc,
        b_seq_len=b_seq_len,
        page_size=page_size,
    )
    return stage1_best, stage2_best
# Stage 1 of the v2 split-KV decode attention.
#
# Each program handles one (batch entry, block of query heads, KV split)
# triple: it streams its slice of the paged K/V cache in BLOCK_N-sized steps,
# keeps an online softmax (running max `e_max`, running sum `e_sum`, rescaled
# accumulator `acc`) and writes a normalised partial output plus the
# log-sum-exp of its split into Att_Out.
#
# The autotuner searches NUM_KV_SPLITS x BLOCK_N x num_warps with
# num_stages fixed to 1. The NUM_KV_SPLITS=4 group is not a full cross
# product: (16, 8 warps), all BLOCK_N=32 entries and (64, 8 warps) are absent.
#
# NOTE(review): "B_Seqlen" in `key` names a pointer argument, not a scalar;
# depending on how the installed Triton hashes tensor-valued keys this may
# re-run autotuning for every new tensor object -- confirm.
# NOTE(review): NUM_KV_SPLITS is chosen by the autotuner (up to 128) and
# indexes Att_Out via `split_kv_id * stride_mid_os`; confirm the caller's
# buffer has at least that many split slots.
@triton.autotune(
    configs=[
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 1, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 2, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        # NUM_KV_SPLITS=4: reduced set (see header comment).
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 4, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 8, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 16, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 32, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 64, "BLOCK_N": 256}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"NUM_KV_SPLITS": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"]
)
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    # Grid axes: 0 -> batch entry, 1 -> block of query heads, 2 -> KV split.
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    # Several head blocks may share one KV head when kv_group_num > BLOCK_H.
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    # Number of real heads per program: the tile is BLOCK_H wide, but only
    # min(BLOCK_H, kv_group_num) rows carry distinct query heads.
    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    # Mask off tile rows beyond this program's heads and beyond q_head_num.
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    # The request index is taken to be the batch index itself.
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    # Optional second slice of the head dimension, located right after the
    # first BLOCK_DMODEL elements of each Q/K row.
    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)

    # Token range [split_kv_start, split_kv_end) owned by this split.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    # Online-softmax state, one entry per head row of the tile.
    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            # Translate token positions to cache slots through the page table.
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            # K is loaded as a (BLOCK_DMODEL, BLOCK_N) tile so it can be the
            # right-hand operand of tl.dot without a transpose.
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            # Soft-cap the logits; `tanh` is not defined in this block and is
            # expected to be provided elsewhere in the module.
            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            # Padding rows/columns get -inf so they contribute exp(...) = 0.
            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online softmax update: rescale the previous state to the new max.
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        # Normalised partial output for this split: elements [0, Lv) of the
        # (batch, head, split) row of Att_Out.
        offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        # Log-sum-exp of this split, stored at element Lv of the same row.
        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )
def _decode_v2_stage1_use_tc(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch the autotuned v2 stage-1 kernel and return its best config.

    The key head size ``k_buffer.shape[-1]`` decides how the head dimension
    is tiled: 576 is split into 512 + 64 and 288 into 256 + 32 (main block
    plus ``BLOCK_DPE`` remainder); any other size is padded to the next power
    of two with no remainder block.

    Args:
        q: Queries; dim 0 is used as the batch size, dim 1 as the head count.
        k_buffer: Key cache; dim -2 is used as the KV-head count.
        v_buffer: Value cache; its last dim is the value head size.
        att_out: Buffer the kernel writes its per-split results into.
        Req_to_tokens: Page table handed to the kernel.
        B_Seqlen: Per-sequence lengths handed to the kernel.
        sm_scale: Scale applied to the QK product inside the kernel.
        page_size: Forwarded as ``PAGE_SIZE``.
        logit_cap: Forwarded as ``logit_cap``.

    Returns:
        ``_decode_v2_kernel_stage1_use_tc.best_config`` after the launch.
    """
    key_dim = k_buffer.shape[-1]
    value_dim = v_buffer.shape[-1]

    # Head sizes with a dedicated (main block, remainder block) split.
    special_splits = {576: (512, 64), 288: (256, 32)}
    if key_dim in special_splits:
        main_block, extra_block = special_splits[key_dim]
    else:
        main_block, extra_block = triton.next_power_of_2(key_dim), 0
    value_block = triton.next_power_of_2(value_dim)

    num_seqs = q.shape[0]
    num_q_heads = q.shape[1]
    group_size = q.shape[1] // k_buffer.shape[-2]

    head_tile = 16

    def launch_grid(META):
        # The third grid axis follows whichever split count is being tuned.
        return (
            num_seqs,
            triton.cdiv(num_q_heads, min(head_tile, group_size)),
            META['NUM_KV_SPLITS'],
        )

    _decode_v2_kernel_stage1_use_tc[launch_grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=group_size,
        q_head_num=num_q_heads,
        BLOCK_DMODEL=main_block,
        BLOCK_DPE=extra_block,
        BLOCK_DV=value_block,
        BLOCK_H=head_tile,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=key_dim,
        Lv=value_dim,
        kpack=2,
    )
    return _decode_v2_kernel_stage1_use_tc.best_config
# Stage 2 of the v2 split-KV decode attention.
#
# One program per (batch entry, query head). It walks over the NUM_KV_SPLITS
# partial results in Mid_O -- each a value vector in elements [0, Lv) plus a
# log-sum-exp at element Lv -- merges them with a running max / running sum,
# and writes the normalised result to O. Only num_warps is autotuned.
#
# NOTE(review): "B_Seqlen" in `key` names a pointer argument, not a scalar;
# depending on how the installed Triton hashes tensor-valued keys this may
# re-run autotuning for every new tensor object -- confirm.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1, num_stages=1),
        triton.Config({}, num_warps=2, num_stages=1),
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
    ],
    key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"]
)
@triton.jit
def _decode_v2_kernel_stage2(
    Mid_O,
    O,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    # Grid axes: 0 -> batch entry, 1 -> query head.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    # Running softmax denominator, running max and merged output.
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Base offsets of the value vector and of the log-sum-exp slot (element
    # Lv) for split 0; the per-split stride is added inside the loop.
    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        # Recompute the split's token range to skip splits that own no tokens.
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            # Rescale what has been merged so far to the new max, then add
            # this split weighted by exp(lse - max).
            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )
def _decode_v2_stage2_use_tc(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch the v2 stage-2 reduction kernel and return its best config.

    One program is launched per ``(q.shape[0], q.shape[1])`` pair, i.e. per
    batch entry and query head.

    Args:
        logits: Stage-1 output buffer; its first three strides are passed on.
        q: Queries; only ``shape[0]`` and ``shape[1]`` are read.
        o: Output tensor; its first two strides are passed on.
        v_buffer: Value cache; only its last dim (value head size) is read.
        b_seq_len: Per-sequence lengths handed to the kernel.
        num_kv_splits: Forwarded to the kernel as ``NUM_KV_SPLITS``.

    Returns:
        ``_decode_v2_kernel_stage2.best_config`` after the launch.
    """
    value_dim = v_buffer.shape[-1]
    launch_grid = (q.shape[0], q.shape[1])

    mid_strides = (logits.stride(0), logits.stride(1), logits.stride(2))
    out_strides = (o.stride(0), o.stride(1))

    _decode_v2_kernel_stage2[launch_grid](
        logits,
        o,
        b_seq_len,
        *mid_strides,
        *out_strides,
        NUM_KV_SPLITS=num_kv_splits,
        BLOCK_DV=triton.next_power_of_2(value_dim),
        Lv=value_dim,
    )
    return _decode_v2_kernel_stage2.best_config
def decode_attention_v2(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run the two-stage v2 (autotuned split-KV) decode attention.

    Stage 1 fills ``attn_logits`` from ``q``, ``k_buffer`` and ``v_buffer``;
    stage 2 reduces ``attn_logits`` into ``o`` using the ``NUM_KV_SPLITS``
    value that stage 1's autotuner selected.

    NOTE(review): the split count comes from the autotuner, not from the
    caller; confirm ``attn_logits`` is allocated with enough split slots for
    every candidate configuration.

    Returns:
        A ``(stage1_best_config, stage2_best_config)`` tuple.
    """
    stage1_config = _decode_v2_stage1_use_tc(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        sm_scale,
        page_size,
        logit_cap,
    )

    # Stage 2 must iterate over exactly as many splits as stage 1 produced.
    tuned_splits = stage1_config.kwargs["NUM_KV_SPLITS"]
    stage2_config = _decode_v2_stage2_use_tc(
        attn_logits,
        q,
        o,
        v_buffer,
        b_seq_len,
        tuned_splits,
    )
    return stage1_config, stage2_config
def
decode_attention_fwd
(
def
decode_attention_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
...
@@ -638,7 +1647,7 @@ def decode_attention_fwd(
...
@@ -638,7 +1647,7 @@ def decode_attention_fwd(
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
-
2
]
b_start_loc
=
torch
.
arange
(
0
,
k_buffer
.
shape
[
0
]
*
page_size
,
k_buffer
.
shape
[
0
]
*
page_size
//
q
.
shape
[
0
],
device
=
"cuda"
).
to
(
torch
.
int32
)
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
# MHA
# MHA
decode_attention_fwd_normal
(
decode_attention_fwd_normal
(
...
@@ -656,16 +1665,48 @@ def decode_attention_fwd(
...
@@ -656,16 +1665,48 @@ def decode_attention_fwd(
)
)
else
:
else
:
# GQA/MQA/MLA
# GQA/MQA/MLA
decode_attention_fwd_grouped
(
if
not
envs
.
VLLM_USE_TRITON_OPT_MLA
:
q
,
decode_attention_fwd_grouped
(
k_buffer
,
q
,
v_buffer
,
k_buffer
,
o
,
v_buffer
,
req_to_token
,
o
,
b_seq_len
,
req_to_token
,
attn_logits
,
b_seq_len
,
num_kv_splits
,
attn_logits
,
sm_scale
,
num_kv_splits
,
page_size
,
sm_scale
,
logit_cap
,
page_size
,
)
logit_cap
,
)
else
:
decode_attention_v2
(
q
,
k_buffer
,
v_buffer
,
o
,
req_to_token
,
b_seq_len
,
attn_logits
,
sm_scale
,
page_size
,
logit_cap
,
)
# attn_logits_v1 = torch.empty(
# (q.shape[1],k_buffer.shape[0]*page_size),
# dtype=torch.float16,
# device="cuda")
# decode_attention_v1(
# q,
# k_buffer,
# v_buffer,
# o,
# req_to_token,
# b_start_loc,
# b_seq_len,
# attn_logits_v1,
# #num_kv_splits, # sub
# sm_scale,
# page_size,
# logit_cap,
# )
\ No newline at end of file
vllm/envs.py
View file @
40083064
...
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
...
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
VLLM_NCCL_SO_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
LD_LIBRARY_PATH
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
@@ -564,6 +565,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -564,6 +565,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the MLA attention optimizations.
# If set, vLLM will disable the MLA attention optimizations.
"VLLM_MLA_DISABLE"
:
"VLLM_MLA_DISABLE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MLA_DISABLE"
,
"0"
))),
# If set, vLLM will use the optimized Triton MLA attention kernels.
"VLLM_USE_TRITON_OPT_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_OPT_MLA"
,
"0"
))),
# Flag that can control whether or not we perform matrix-absorption for MLA
# Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment