OpenDAS / MLAttention · Commit ebcba9f0
authored Feb 26, 2025 by zhangqha
support MLAttention
parent cb13d4d8 · Pipeline #2431 canceled with stages · Changes 5 · Pipelines 1

Showing 5 changed files with 1381 additions and 0 deletions:

- README.md (+54 / -0)
- setup.py (+12 / -0)
- tests/test_triton_decode_attention.py (+90 / -0)
- triton_mla_op/__init__.py (+3 / -0)
- triton_mla_op/triton_decode_attention.py (+1222 / -0)

README.md

# MLAttention

## Introduction

MLAttention is an efficient MLA decoding kernel optimized for serving variable-length sequences.

Currently supported precisions:

- BF16, FP16

Currently supported implementations:

- OpenAI Triton

## Installation

### Install from source

```bash
python3 -m pip install .
```

### Run the unit tests

```bash
pytest -s tests/test_triton_decode_attention.py
```

## Usage

```python
import triton
from triton_mla_op.triton_decode_attention import decode_attention_fwd
...
decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)
...
```
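
For reference, the tensor shapes expected by `decode_attention_fwd` follow the unit test shipped in this commit. The sketch below is a minimal, illustrative allocation for the default `page_size=1` path; the sizes (`B`, `H_Q`, `H_KV`, `D_QK`, `D_V`, `CACHE_SIZE`, `seq_len`) are example values taken from the test, not requirements:

```python
import torch
from triton_mla_op.triton_decode_attention import decode_attention_fwd

# Example sizes, mirroring tests/test_triton_decode_attention.py
B, H_Q, H_KV, D_QK, D_V = 3, 32, 8, 576, 512
CACHE_SIZE, seq_len, num_kv_splits = 16384, 1027, 8
sm_scale = 1.0 / (D_QK ** 0.5)

q = torch.randn(B, H_Q, D_QK, dtype=torch.bfloat16, device="cuda")        # one new token per request
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=torch.bfloat16, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=torch.bfloat16, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=torch.bfloat16, device="cuda")         # output buffer
req_to_token = torch.randint(0, CACHE_SIZE, (B, seq_len), device="cuda")  # KV-cache slot of each cached token
b_seq_len = torch.full((B,), seq_len, device="cuda")                      # cached length per request
attn_logits = torch.empty(B, H_Q, num_kv_splits, D_V + 1,                 # split-KV workspace (+1 for log-sum-exp)
                          dtype=torch.float32, device="cuda")

decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_seq_len,
                     attn_logits, num_kv_splits, sm_scale)
```

For page sizes larger than 1, reshape `k_buffer`/`v_buffer` to `(num_pages, PAGE_SIZE, H_KV, D)`, pass a page table instead of a per-token table, and pass `PAGE_SIZE` as the final argument, as done in the unit test.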

## MLAttention Development Progress

A Cutlass-based version of MLAttention is under active development. Progress will be posted in this repository; follow our developer community for the latest updates.

setup.py (new file)

from setuptools import setup, find_packages

setup(
    name='MLAttention',
    packages=find_packages("triton_mla_op"),
    package_dir={"": "triton_mla_op"},
    include_package_data=True,
    install_requires=['triton', 'torch']
)

tests/test_triton_decode_attention.py (new file)

# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import triton

from triton_mla_op.triton_decode_attention import decode_attention_fwd


def cdiv(a, b):
    return (a + b - 1) // b


@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    # Ceiling division, e.g. (1027 + 16 - 1) // 16 = 65
    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    # Tensor of shape (B, num_pages_per_batch, 1) with values in [0, CACHE_SIZE // PAGE_SIZE)
    req_to_page = torch.randint(
        0,
        CACHE_SIZE // PAGE_SIZE,
        (B, num_pages_per_batch, 1),
        device="cuda",
    )
    req_to_token = req_to_page * PAGE_SIZE
    # Dimension expansion, e.g. from torch.Size([3, 65, 1]) to torch.Size([3, 65, 16])
    req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_token.view(B, -1)
    req_to_token = req_to_token[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens
    # Page size is 1.
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o will have the same shape as q
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B,), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)

triton_mla_op/__init__.py (new file)

from .triton_decode_attention import decode_attention_fwd

__all__ = ['decode_attention_fwd']

triton_mla_op/triton_decode_attention.py (new file)

# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
#
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import os
import logging

import torch
import triton
import triton.language as tl
from vllm.platforms import current_platform

is_hip_ = current_platform.is_rocm()

os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = "0"

logger = logging.getLogger(__name__)

# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
    "The following error message 'operation scheduled before its operands' "
    "can be ignored.")
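
# Module structure:
# - _fwd_kernel_stage1 / _fwd_kernel_stage2: flash-decoding style split-KV
#   attention for the MHA path (kv_group_num == 1). Stage 1 writes per-split
#   partial outputs plus a log-sum-exp term into attn_logits; stage 2 merges
#   the splits with a numerically stable softmax reduction.
# - _decode_v2_*: grouped (GQA/MQA/MLA) variants of the same two stages that
#   use tl.dot and a separate BLOCK_DPE block for the extra head dimensions
#   (e.g. Lk == 576 is split into 512 + 64).
# - _decode_v1_*: an alternative two-stage implementation kept for reference;
#   it is only invoked from commented-out code in decode_attention_fwd.
# - decode_attention_fwd: entry point that dispatches between the MHA and
#   grouped paths based on kv_group_num and forwards page_size for paged KV.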

@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1

@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[None, :])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                other=0.0,
            )
            qk = tl.sum(q[None, :] * k, 1)
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max)
            acc *= re_scale
            acc += tl.sum(p[:, None] * v, 0)

            e_sum = e_sum * re_scale + tl.sum(p, 0)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv)
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum,
            mask=(mask_dv),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
        )
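
# Note on the Att_Out layout used above: attn_logits has shape
# (batch, head, NUM_KV_SPLITS, Lv + 1). Slots [0:Lv] hold the split's partial
# output (already normalized by the split-local softmax denominator), and
# slot Lv holds e_max + tl.log(e_sum), the split's log-sum-exp, which the
# stage-2 reduction uses to merge the splits.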

def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    BLOCK = 64
    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    batch, head_num = q.shape[0], q.shape[1]

    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    num_warps = 4 if kv_group_num == 1 else 2

    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=num_warps,
        Lk=Lk,
        Lv=Lv,
    )

@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )

def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2}

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        **extra_kargs,
    )

def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
                                num_kv_splits)

# opt
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
    ],
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh"])
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    SPLIT_K: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_k_id = tl.program_id(2)

    reduce_dtype = Att_Out.dtype.element_ty

    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (offs_d[None, :] < Lk),
                other=0.0).to(reduce_dtype)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None],
                      other=0.0).to(reduce_dtype)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
    split_k_start = kv_len_per_split * split_k_id
    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)

    for start_n in range(split_k_start, split_k_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
            offs_n // PAGE_SIZE,
            mask=offs_n < split_k_end,
            other=0,
        )
        k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_buf_k = (k_loc[None, :] * stride_buf_kbs +
                      cur_kv_head * stride_buf_kh + offs_d[:, None])
        k = tl.load(
            K_Buffer + offs_buf_k,
            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
            other=0.0,
        ).to(reduce_dtype)
        qk = tl.dot(q, k)
        if BLOCK_DPE > 0:
            offs_buf_kpe = (k_loc[None, :] * stride_buf_kbs +
                            cur_kv_head * stride_buf_kh + offs_dpe[:, None])
            kpe = tl.load(
                K_Buffer + offs_buf_kpe,
                mask=offs_n[None, :] < split_k_end,
                other=0.0,
            ).to(reduce_dtype)
            qk += tl.dot(qpe, kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        offs_o = cur_head[:, None] * att_stride_h + (
            cur_batch_in_all_start_index + offs_n[None, :])

        tl.store(
            Att_Out + offs_o,
            qk,
            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
        )

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=1, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1),
        triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1),
    ],
    key=["B_Seqlen", "stride_logic_h", "stride_buf_vbs", "stride_buf_vh"])
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
    logits,
    V_Buffer,
    Out,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_kv_head = tl.program_id(1)

    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
    mask_h = mask_h & (cur_head < q_head_num)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = cur_batch  # tl.load(B_req_idx + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
    v_ptrs = V_Buffer + offs_buf_v

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        v_page_number = tl.load(
            Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b +
            (start_n + offs_n) // PAGE_SIZE,
            mask=(start_n + offs_n) < cur_batch_seq_len,
            other=0,
        )
        v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE
        offs_qk = cur_head[:, None] * stride_logic_h + (
            cur_batch_start_loc + start_n + offs_n[None, :])
        qk = tl.load(
            logits + offs_qk,
            mask=mask_h[:, None] &
            (start_n + offs_n[None, :] < cur_batch_seq_len),
            other=float("-inf"),
        )  # [head, block_n]

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + tl.sum(p, 1)
        v = tl.load(v_ptrs + v_loc[:, None] * stride_buf_vbs,
                    mask=(offs_d[None, :] < Lv))  # [block_n, head_dim]
        p = p.to(v.dtype)
        acc = acc * old_scale[:, None] + tl.dot(p, v)
        e_max = n_e_max

    acc = acc / e_sum[:, None]
    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))

def _decode_v1_stage1_use_tc(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    # B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    sm_scale,
    page_size,
    num_kv_splits,
    logit_cap,
):
    Lk = k_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    SPLIT_K = num_kv_splits
    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        SPLIT_K,
    )

    _decode_v1_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        # B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_H=BLOCK_H,
        SPLIT_K=SPLIT_K,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        kpack=2,
    )
    return _decode_v1_kernel_stage1_use_tc.best_config

def _decode_v1_stage2_use_tc(
    logits,
    v_buffer,
    o,
    req_to_tokens,
    # b_req_idx,
    b_start_loc,
    b_seq_len,
    page_size,
):
    batch, head_num = b_seq_len.shape[0], logits.shape[0]
    kv_group_num = logits.shape[0] // v_buffer.shape[-2]
    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

    Lv = v_buffer.shape[-1]
    BLOCK_DMODEL = triton.next_power_of_2(Lv)

    _decode_v1_kernel_stage2_use_tc[grid](
        logits,
        v_buffer,
        o,
        req_to_tokens,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        logits.stride(0),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_H=BLOCK_H,
        PAGE_SIZE=page_size,
        Lv=Lv,
    )
    return _decode_v1_kernel_stage2_use_tc.best_config

def decode_attention_v1(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    # b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    # GQA/MQA/MLA
    _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
        q,
        k_buffer,
        attn_logits,
        req_to_token,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        sm_scale,
        page_size,
        num_kv_splits,
        logit_cap,
    )
    _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
        attn_logits,
        v_buffer,
        o,
        req_to_token,
        # b_req_idx,
        b_start_loc,
        b_seq_len,
        page_size,
    )
    return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 16}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 32}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 64}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=2, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=[
        "B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh",
        "stride_buf_vbs", "stride_buf_vh"
    ])
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob +
                      cur_head[:, None] * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv[None, :])
        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)
        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )

def _decode_v2_stage1_use_tc(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    _decode_v2_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        # B_req_idx,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        Lv=Lv,
        kpack=2,
    )
    return _decode_v2_kernel_stage1_use_tc.best_config

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1, num_stages=1),
        triton.Config({}, num_warps=2, num_stages=1),
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
    ],
    key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"])
@triton.jit
def _decode_v2_kernel_stage2(
    Mid_O,
    O,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )

def _decode_v2_stage2_use_tc(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    grid = (batch, head_num)
    _decode_v2_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
    )
    return _decode_v2_kernel_stage2.best_config

def decode_attention_v2(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    # b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        # b_req_idx,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(
        attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
    return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config

def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]
    b_start_loc = torch.arange(
        0,
        k_buffer.shape[0] * page_size,
        k_buffer.shape[0] * page_size // q.shape[0],
        device="cuda").to(torch.int32)

    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_v2(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )

    # attn_logits_v1 = torch.empty(
    #     (q.shape[1], k_buffer.shape[0] * page_size),
    #     dtype=torch.float16,
    #     device="cuda")
    # decode_attention_v1(
    #     q,
    #     k_buffer,
    #     v_buffer,
    #     o,
    #     req_to_token,
    #     b_start_loc,
    #     b_seq_len,
    #     attn_logits_v1,
    #     num_kv_splits,
    #     sm_scale,
    #     page_size,
    #     logit_cap,
    # )