OpenDAS / MLAttention · Commits · ebcba9f0

Commit ebcba9f0 authored Feb 26, 2025 by zhangqha

support MLAttention

Parent: cb13d4d8
Pipeline #2431: canceled with stages
Changes: 5 · Pipelines: 1

Showing 5 changed files with 1381 additions and 0 deletions (+1381, -0)
README.md                                   +54    -0
setup.py                                    +12    -0
tests/test_triton_decode_attention.py       +90    -0
triton_mla_op/__init__.py                   +3     -0
triton_mla_op/triton_decode_attention.py    +1222  -0
README.md (view file @ ebcba9f0)
# MLAttention

## Introduction

MLAttention is an efficient MLA decoding kernel, optimized for serving variable-length sequences.

Currently supported precisions:
- BF16, FP16

Currently supported implementations:
- OpenAI Triton
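
The tensor layout expected by `decode_attention_fwd` can be read off the unit test added in this commit (`tests/test_triton_decode_attention.py`); the summary below follows the test's naming and is for orientation only:

```python
# Shapes used by the unit test (page_size = 1 case):
#   q            : (B, H_Q, D_QK)                    BF16/FP16, one query token per request
#   k_buffer     : (CACHE_SIZE, H_KV, D_QK)          BF16/FP16 key cache
#   v_buffer     : (CACHE_SIZE, H_KV, D_V)           BF16/FP16 value cache
#   o            : (B, H_Q, D_V)                     output, same dtype as q
#   req_to_token : (B, seq_len) integer              cache slots per request (a page table of
#                                                    shape (B, num_pages) when page_size > 1)
#   b_seq_len    : (B,) integer                      current length of each request
#   attn_logits  : (B, H_Q, num_kv_splits, D_V + 1)  float32 scratch; the extra channel
#                                                    stores each split's log-sum-exp
```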
## Installation

### Install from source

```bash
python3 -m pip install .
```

### Unit test verification

```bash
pytest -s tests/test_triton_decode_attention.py
```
## Usage

```python
import triton
from triton_mla_op.triton_decode_attention import decode_attention_fwd

...

decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)

...
```
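
The `...` placeholders above elide tensor setup. A self-contained sketch, condensed from the unit test in this commit (illustrative sizes, `page_size = 1`, CUDA device assumed):

```python
import torch
from triton_mla_op.triton_decode_attention import decode_attention_fwd

B, H_Q, H_KV, D_QK, D_V = 3, 32, 8, 576, 512
CACHE_SIZE, seq_len, num_kv_splits = 16384, 1027, 8
dtype, device = torch.bfloat16, "cuda"

q = torch.randn(B, H_Q, D_QK, dtype=dtype, device=device)
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device=device)
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device=device)
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)

# Each request reads seq_len (possibly shared) slots of the cache; page_size = 1.
req_to_token = torch.randint(0, CACHE_SIZE, (B, seq_len), device=device)
b_seq_len = torch.full((B,), seq_len, device=device)
attn_logits = torch.empty(B, H_Q, num_kv_splits, D_V + 1,
                          dtype=torch.float32, device=device)

sm_scale = 1.0 / (D_QK ** 0.5)
decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_seq_len,
                     attn_logits, num_kv_splits, sm_scale)
```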
## MLAttention development status

A Cutlass-based version of MLAttention is under active development. Progress updates will be posted in this repository; follow our developer community for the latest news.

setup.py (new file 0 → 100644, view file @ ebcba9f0)
from setuptools import setup, find_packages

setup(
    name='MLAttention',
    packages=find_packages("triton_mla_op"),
    package_dir={"": "triton_mla_op"},
    include_package_data=True,
    install_requires=['triton', 'torch']
)

tests/test_triton_decode_attention.py (new file 0 → 100644, view file @ ebcba9f0)
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import triton

from triton_mla_op.triton_decode_attention import decode_attention_fwd


def cdiv(a, b):
    return (a + b - 1) // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Ceiling division, e.g. (1027 + 16 - 1) // 16 = 65
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # Tensor of shape (B, num_pages_per_batch, 1) with values in [0, CACHE_SIZE // PAGE_SIZE)
    req_to_page = torch.randint(
        0,
        CACHE_SIZE // PAGE_SIZE,
        (B, num_pages_per_batch, 1),
        device="cuda",
    )
    req_to_token = req_to_page * PAGE_SIZE
    # Expand the last dimension, e.g. from torch.Size([3, 65, 1]) to torch.Size([3, 65, 16])
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o will have the same shape as q
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B, ), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)

triton_mla_op/__init__.py (new file 0 → 100644, view file @ ebcba9f0)
from .triton_decode_attention import decode_attention_fwd

__all__ = ['decode_attention_fwd']

triton_mla_op/triton_decode_attention.py (new file 0 → 100644, view file @ ebcba9f0)
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import os
import logging

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform

is_hip_ = current_platform.is_rocm()

os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = "0"

logger = logging.getLogger(__name__)

# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
    "The following error message 'operation scheduled before its operands' "
    "can be ignored.")


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[None, :])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                other=0.0,
            )
            qk = tl.sum(q[None, :] * k, 1)
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max)
            acc *= re_scale
            acc += tl.sum(p[:, None] * v, 0)

            e_sum = e_sum * re_scale + tl.sum(p, 0)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv)
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum,
            mask=mask_dv,
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
        )
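

# Note (exposition added here, not part of the upstream kernel): the loop in
# _fwd_kernel_stage1 above maintains an online softmax over the KV chunk owned
# by this split. With running max m, running sum s and running weighted value
# sum a, each BLOCK_N iteration applies
#     m_new = max(m, max(qk))
#     a     = a * exp(m - m_new) + sum(exp(qk - m_new)[:, None] * v, axis=0)
#     s     = s * exp(m - m_new) + sum(exp(qk - m_new))
# and the split finally stores a / s together with the log-sum-exp m + log(s),
# which the stage-2 kernel uses to merge partial results across splits.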


def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    BLOCK = 64
    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    batch, head_num = q.shape[0], q.shape[1]

    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    num_warps = 4 if kv_group_num == 1 else 2

    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=num_warps,
        Lk=Lk,
        Lv=Lv,
    )


@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2}

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        **extra_kargs,
    )


def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
                                num_kv_splits)
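

# Illustrative reference (added for exposition; not part of the upstream kernel):
# a PyTorch sketch of what the stage-2 reduction computes from the stage-1 output.
# attn_logits[..., :Lv] holds each split's normalized partial output and
# attn_logits[..., Lv] its log-sum-exp; splits that covered no tokens are skipped
# by the kernel, so this sketch assumes every split saw at least one token.
def _reference_merge_kv_splits(attn_logits: torch.Tensor) -> torch.Tensor:
    partial = attn_logits[..., :-1]         # (batch, heads, num_kv_splits, Lv)
    lse = attn_logits[..., -1]              # (batch, heads, num_kv_splits)
    weights = torch.softmax(lse, dim=-1)    # renormalize across the KV splits
    return (weights.unsqueeze(-1) * partial).sum(dim=-2)  # (batch, heads, Lv)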


# opt
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
    ],
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh"],
)
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    SPLIT_K: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_k_id = tl.program_id(2)

    reduce_dtype = Att_Out.dtype.element_ty

    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(
        Q + offs_q,
        mask=(mask_h[:, None]) & (offs_d[None, :] < Lk),
        other=0.0,
    ).to(reduce_dtype)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None],
                      other=0.0).to(reduce_dtype)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
    split_k_start = kv_len_per_split * split_k_id
    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)

    for start_n in range(split_k_start, split_k_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
            offs_n // PAGE_SIZE,
            mask=offs_n < split_k_end,
            other=0,
        )
        k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_buf_k = (k_loc[None, :] * stride_buf_kbs +
                      cur_kv_head * stride_buf_kh + offs_d[:, None])
        k = tl.load(
            K_Buffer + offs_buf_k,
            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
            other=0.0,
        ).to(reduce_dtype)
        qk = tl.dot(q, k)
        if BLOCK_DPE > 0:
            offs_buf_kpe = (k_loc[None, :] * stride_buf_kbs +
                            cur_kv_head * stride_buf_kh + offs_dpe[:, None])
            kpe = tl.load(
                K_Buffer + offs_buf_kpe,
                mask=offs_n[None, :] < split_k_end,
                other=0.0,
            ).to(reduce_dtype)
            qk += tl.dot(qpe, kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        offs_o = cur_head[:, None] * att_stride_h + (
            cur_batch_in_all_start_index + offs_n[None, :])

        tl.store(
            Att_Out + offs_o,
            qk,
            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
        )


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
    ],
    key=["B_Seqlen", "stride_logic_h", "stride_buf_vbs", "stride_buf_vh"],
)
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
    logits,
    V_Buffer,
    Out,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_kv_head = tl.program_id(1)

    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
    mask_h = mask_h & (cur_head < q_head_num)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = cur_batch  # tl.load(B_req_idx + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
    v_ptrs = V_Buffer + offs_buf_v

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        v_page_number = tl.load(
            Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b +
            (start_n + offs_n) // PAGE_SIZE,
            mask=(start_n + offs_n) < cur_batch_seq_len,
            other=0,
        )
        v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE
        offs_qk = cur_head[:, None] * stride_logic_h + (
            cur_batch_start_loc + start_n + offs_n[None, :])
        qk = tl.load(
            logits + offs_qk,
            mask=mask_h[:, None] &
            (start_n + offs_n[None, :] < cur_batch_seq_len),
            other=float("-inf"),
        )  # [head, block_n]

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + tl.sum(p, 1)
        v = tl.load(v_ptrs + v_loc[:, None] * stride_buf_vbs,
                    mask=(offs_d[None, :] < Lv))  # [block_n, head_dim]
        p = p.to(v.dtype)
        acc = acc * old_scale[:, None] + tl.dot(p, v)
        e_max = n_e_max

    acc = acc / e_sum[:, None]
    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))


def _decode_v1_stage1_use_tc(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    sm_scale,
    page_size,
    num_kv_splits,
    logit_cap,
):
    Lk = k_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    SPLIT_K = num_kv_splits
    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        SPLIT_K,
    )

    _decode_v1_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        # B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_H=BLOCK_H,
        SPLIT_K=SPLIT_K,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        kpack=2,
    )
    return _decode_v1_kernel_stage1_use_tc.best_config


def _decode_v1_stage2_use_tc(
    logits,
    v_buffer,
    o,
    req_to_tokens,
    # b_req_idx,
    b_start_loc,
    b_seq_len,
    page_size,
):
    batch, head_num = b_seq_len.shape[0], logits.shape[0]
    kv_group_num = logits.shape[0] // v_buffer.shape[-2]
    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

    Lv = v_buffer.shape[-1]
    BLOCK_DMODEL = triton.next_power_of_2(Lv)

    _decode_v1_kernel_stage2_use_tc[grid](
        logits,
        v_buffer,
        o,
        req_to_tokens,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        logits.stride(0),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_H=BLOCK_H,
        PAGE_SIZE=page_size,
        Lv=Lv,
    )
    return _decode_v1_kernel_stage2_use_tc.best_config


def decode_attention_v1(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    # b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    # GQA/MQA/MLA
    _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
        q,
        k_buffer,
        attn_logits,
        req_to_token,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        sm_scale,
        page_size,
        num_kv_splits,
        logit_cap,
    )
    _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
        attn_logits,
        v_buffer,
        o,
        req_to_token,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        page_size,
    )
    return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh",
         "stride_buf_vbs", "stride_buf_vh"],
)
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob +
                      cur_head[:, None] * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )


def _decode_v2_stage1_use_tc(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    _decode_v2_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        # B_req_idx,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        Lv=Lv,
        kpack=2,
    )
    return _decode_v2_kernel_stage1_use_tc.best_config


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1, num_stages=1),
        triton.Config({}, num_warps=2, num_stages=1),
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
    ],
    key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"],
)
@triton.jit
def _decode_v2_kernel_stage2(
    Mid_O,
    O,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_v2_stage2_use_tc(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    grid = (batch, head_num)
    _decode_v2_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
    )
    return _decode_v2_kernel_stage2.best_config


def decode_attention_v2(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    # b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        # b_req_idx,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(
        attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
    return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    # b_start_loc is only consumed by the commented-out decode_attention_v1 path below.
    b_start_loc = torch.arange(
        0,
        k_buffer.shape[0] * page_size,
        k_buffer.shape[0] * page_size // q.shape[0],
        device="cuda").to(torch.int32)

    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_v2(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
        # attn_logits_v1 = torch.empty(
        #     (q.shape[1], k_buffer.shape[0] * page_size),
        #     dtype=torch.float16,
        #     device="cuda")
        # decode_attention_v1(
        #     q,
        #     k_buffer,
        #     v_buffer,
        #     o,
        #     req_to_token,
        #     b_start_loc,
        #     b_seq_len,
        #     attn_logits_v1,
        #     num_kv_splits,
        #     sm_scale,
        #     page_size,
        #     logit_cap,
        # )