change / sglang / Commits / df191254

Optimize MLA/GQA/MQA Triton decoding (#1138)

Unverified commit df191254, authored Aug 19, 2024 by Ke Bao, committed via GitHub on Aug 19, 2024.
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Parent: b997a18d
Showing 1 changed file with 337 additions and 49 deletions

python/sglang/srt/layers/decode_attention.py  (+337 / -49)

...
@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
...
@@ -78,10 +77,6 @@ def _fwd_kernel_stage1(
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
     block_stard_index = start_n * BLOCK_N
...
@@ -106,19 +101,6 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
...
@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}

     batch, head_num = B_req_idx.shape[0], q.shape[1]
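The 576 special case removed here is the DeepSeek-V2-style MLA head dim, and it moves into the grouped path added below: each 576-dim head is 512 compressed ("NoPE") dims plus a 64-dim RoPE tail, so one attention logit is the sum of two partial dot products. A minimal sketch of that split (PyTorch, illustrative values; the 512/64 numbers match the `Lk == 576` branch kept in `_decode_grouped_att_m_fwd`):

import torch

BLOCK_DMODEL, BLOCK_DPE = 512, 64          # the Lk == 576 branch below
q = torch.randn(576)                       # one 576-dim MLA query head
k = torch.randn(576)                       # one 576-dim cached key

full = torch.dot(q, k)
split = torch.dot(q[:BLOCK_DMODEL], k[:BLOCK_DMODEL]) + torch.dot(
    q[BLOCK_DMODEL:], k[BLOCK_DMODEL:]
)
assert torch.allclose(full, split, atol=1e-4)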
...
@@ -249,8 +224,7 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
...
@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
     )

+
+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
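Compared with `_fwd_kernel_stage1`, this grouped kernel launches one program per KV head rather than per query head: the whole group of query heads sharing that KV head is stacked along `BLOCK_H`, so the logits come from a single `tl.dot` (a [BLOCK_H, d] x [d, BLOCK_N] matmul that can use tensor cores) and K is loaded once per group instead of once per query head. A sketch of the equivalence (PyTorch, illustrative shapes):

import torch

BLOCK_H, d, BLOCK_N = 16, 128, 32
q_group = torch.randn(BLOCK_H, d)    # all query heads sharing one KV head
k_block = torch.randn(d, BLOCK_N)    # one block of cached keys, laid out [d, n]

# Per-head style, as in the old kernel: broadcast-multiply + sum per query head.
per_head = torch.stack(
    [(q_group[h][None, :] * k_block.T).sum(dim=1) for h in range(BLOCK_H)]
)
# Grouped style, as in _fwd_grouped_kernel_stage1's tl.dot(q, k).
grouped = q_group @ k_block

assert torch.allclose(per_head, grouped, atol=1e-4)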
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
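Stage 2 is an online (flash-style) softmax over the stage-1 logits: for each `BLOCK_N` block it rescales the running accumulator by `exp(e_max - n_e_max)`, so the final `acc / e_sum` equals the exact softmax-weighted sum over the whole sequence. A one-head sketch of the same update (PyTorch, illustrative sizes):

import torch

seq_len, BLOCK_N, d = 96, 32, 64
qk = torch.randn(seq_len)            # stage-1 logits for one query head
v = torch.randn(seq_len, d)          # the matching cached values

e_max = torch.tensor(float("-inf"))
e_sum = torch.tensor(0.0)
acc = torch.zeros(d)
for start_n in range(0, seq_len, BLOCK_N):
    blk = qk[start_n : start_n + BLOCK_N]
    n_e_max = torch.maximum(blk.max(), e_max)   # new running max
    old_scale = torch.exp(e_max - n_e_max)      # rescales past contributions
    p = torch.exp(blk - n_e_max)
    e_sum = e_sum * old_scale + p.sum()
    acc = acc * old_scale + p @ v[start_n : start_n + BLOCK_N]
    e_max = n_e_max

assert torch.allclose(acc / e_sum, torch.softmax(qk, dim=0) @ v, atol=1e-4)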
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
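To see how the launch shape works out, here is the same grid arithmetic on two illustrative head configurations (the head counts are assumptions for the example, not taken from this diff):

import triton

def stage1_grid(batch, head_num, kv_group_num, max_len_in_batch, BLOCK=32):
    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
    grid = (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        triton.cdiv(max_len_in_batch, BLOCK),
    )
    return BLOCK_H, grid

# GQA-style: 64 query heads over 8 KV heads -> one program per KV head on axis 1.
print(stage1_grid(batch=4, head_num=64, kv_group_num=8, max_len_in_batch=1024))
# -> (16, (4, 8, 32))
# MQA/MLA-style: 16 query heads over 1 KV head -> a single program on axis 1.
print(stage1_grid(batch=4, head_num=16, kv_group_num=16, max_len_in_batch=1024))
# -> (16, (4, 1, 32))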
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
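Note that this wrapper takes `BLOCK_DMODEL` from `v_buffer.shape[-1]` rather than from the key dim: under MLA the cached values are narrower than the keys. A shape sketch (DeepSeek-V2-style numbers, an assumption for illustration, not read from this diff):

# Assumed MLA decode-cache shapes, for illustration only:
Lk, Lv = 576, 512                          # keys: 512 compressed + 64 RoPE dims; values: 512
BLOCK_DMODEL_stage1, BLOCK_DPE = 512, 64   # the Lk == 576 branch in stage 1
BLOCK_DMODEL_stage2 = Lv                   # stage 2 uses BLOCK_DMODEL=v_buffer.shape[-1]
assert BLOCK_DMODEL_stage1 + BLOCK_DPE == Lk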
 def decode_attention_fwd(
     q,
     k_buffer,
...
@@ -316,24 +577,51 @@ def decode_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
+
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
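The new dispatch boils down to one ratio: if every query head has its own KV head, the original MHA kernels run; otherwise the grouped kernels do. A small sketch of the rule (head counts are illustrative assumptions):

def decode_path(num_q_heads: int, num_kv_heads: int) -> str:
    kv_group_num = num_q_heads // num_kv_heads
    return "MHA" if kv_group_num == 1 else "grouped (GQA/MQA/MLA)"

assert decode_path(32, 32) == "MHA"                     # classic multi-head
assert decode_path(64, 8) == "grouped (GQA/MQA/MLA)"    # GQA
assert decode_path(16, 1) == "grouped (GQA/MQA/MLA)"    # MQA / MLA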