Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7dc66fcb
Unverified
Commit
7dc66fcb
authored
Dec 08, 2024
by
Ke Bao
Committed by
GitHub
Dec 08, 2024
Browse files
Optimize Triton decoding kernel for long context (#2394)
parent
1f09e84b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
328 additions
and
360 deletions
+328
-360
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+13
-8
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
...glang/srt/layers/attention/triton_ops/decode_attention.py
+287
-342
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
test/srt/test_triton_attention_kernels.py
test/srt/test_triton_attention_kernels.py
+21
-10
No files found.
python/sglang/srt/layers/attention/triton_backend.py
View file @
7dc66fcb
...
...
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
else
:
self
.
reduce_dtype
=
torch
.
float16
self
.
num_kv_splits
=
model_runner
.
server_args
.
triton_attention_num_kv_splits
self
.
v_head_dim
=
model_runner
.
token_to_kv_pool
.
get_value_buffer
(
0
).
shape
[
-
1
]
self
.
forward_metadata
=
None
self
.
cuda_graph_max_seq_len
=
model_runner
.
model_config
.
context_len
...
...
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
start_loc
=
torch
.
zeros_like
(
forward_batch
.
seq_lens
,
dtype
=
torch
.
int32
)
start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
[:
-
1
],
dim
=
0
)
total_num_tokens
=
forward_batch
.
seq_lens_sum
attn_logits
=
torch
.
empty
(
(
self
.
num_head
,
total_num_tokens
),
dtype
=
self
.
reduce_dtype
,
(
forward_batch
.
batch_size
,
self
.
num_head
,
self
.
num_kv_splits
,
self
.
v_head_dim
+
1
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
...
...
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
(
max_bs
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
cuda_graph_attn_logits
=
torch
.
empty
(
(
self
.
num_head
,
self
.
cuda_graph_max_total_num_tokens
,
),
dtype
=
self
.
reduce_dtype
,
(
max_bs
,
self
.
num_head
,
self
.
num_kv_splits
,
self
.
v_head_dim
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
...
...
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch
.
seq_lens
,
attn_logits
,
max_seq_len
,
self
.
num_kv_splits
,
layer
.
scaling
,
layer
.
logit_cap
,
)
...
...
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
View file @
7dc66fcb
...
...
@@ -17,8 +17,8 @@ It supports page size = 1.
"""
# Adapted from
# https://github.com/ModelTC/lightllm/blob/
f2a54f0912293f683bf1d1695fd12c4098a5bf82
/lightllm/models/
llama
/triton_kernel/
token_attention_nopad_att
1.py
# https://github.com/ModelTC/lightllm/blob/
f2a54f0912293f683bf1d1695fd12c4098a5bf82
/lightllm/models/
llama
/triton_kernel/
token_attention_softmax_and_reducev
.py
# https://github.com/ModelTC/lightllm/blob/
96353e868a840db4d103138caf15ed9dbea8c186
/lightllm/models/
deepseek2
/triton_kernel/
gqa_flash_decoding_stage
1.py
# https://github.com/ModelTC/lightllm/blob/
96353e868a840db4d103138caf15ed9dbea8c186
/lightllm/models/
deepseek2
/triton_kernel/
gqa_flash_decoding_stage2
.py
import
triton
import
triton.language
as
tl
...
...
@@ -37,10 +37,10 @@ def tanh(x):
def
_fwd_kernel_stage1
(
Q
,
K_Buffer
,
V_Buffer
,
sm_scale
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
Att_Out
,
stride_req_to_tokens_b
,
...
...
@@ -48,152 +48,137 @@ def _fwd_kernel_stage1(
stride_qh
,
stride_buf_kbs
,
stride_buf_kh
,
att_stride_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SPLIT
_K
:
tl
.
constexpr
,
NUM_KV_
SPLIT
S
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
split_k_id
=
tl
.
program_id
(
2
)
split_k
v
_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
cur_kv_head
=
cur_head
//
kv_group_num
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
mask_d
=
offs_d
<
Lk
mask_dv
=
offs_dv
<
Lv
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
q
=
tl
.
load
(
Q
+
off_q
).
to
(
reduce_dtype
)
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
mask_d
,
other
=
0.0
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT
_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_
SPLIT
S
)
split_k
v
_start
=
kv_len_per_split
*
split_k
v
_id
split_k
v
_end
=
tl
.
minimum
(
split_k
v
_start
+
kv_len_per_split
,
cur_batch_seq_len
)
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
e_max
=
-
float
(
"inf"
)
e_sum
=
0.0
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
if
split_kv_end
>
split_kv_start
:
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k
_loc
=
tl
.
load
(
kv
_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
mask
=
offs_n
<
split_k
v
_end
,
other
=
0
,
)
offs_buf_k
=
(
k
_loc
[:,
None
]
*
stride_buf_kbs
kv
_loc
[:,
None
]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[
None
,
:]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[:,
None
]
<
split_k_end
)
&
(
offs
_d
[
None
,
:]
<
Lk
),
mask
=
(
offs_n
[:,
None
]
<
split_k
v
_end
)
&
(
mask
_d
[
None
,
:]),
other
=
0.0
,
).
to
(
reduce_dtype
)
att_value
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
att_value
*=
sm_scale
)
qk
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
qk
*=
sm_scale
if
logit_cap
>
0
:
att_value
=
logit_cap
*
tanh
(
att_value
/
logit_cap
)
off_o
=
cur_head
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
<
split_k_end
)
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
@
triton
.
jit
def
_fwd_kernel_stage2
(
logits
,
V_Buffer
,
Out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
stride_logic_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_obs
,
stride_oh
,
stride_req_to_token_b
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
qk
=
tl
.
where
(
offs_n
<
split_kv_end
,
qk
,
float
(
"-inf"
))
cur_kv_head
=
cur_head
//
kv_group_num
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_start_loc
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
offs_buf_v
=
(
kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
)
acc
*=
re_scale
acc
+=
tl
.
sum
(
p
[:,
None
]
*
v
,
0
)
offs_buf_v
=
cur_kv_head
*
stride_buf_vh
+
offs_d
[
None
,
:]
v_ptrs
=
V_Buffer
+
offs_buf_v
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
0
)
e_max
=
n_e_max
e_max
=
float
(
"-inf"
)
e_sum
=
0.0
acc
=
tl
.
zeros
([
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
for
start_n
in
range
(
0
,
cur_batch_seq_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
v_index
=
tl
.
load
(
Req_to_tokens
+
cur_batch_req_idx
*
stride_req_to_token_b
+
(
start_n
+
offs_n
),
mask
=
(
start_n
+
offs_n
)
<
cur_batch_seq_len
,
other
=
0
,
offs_mid_o
=
(
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
+
offs_dv
)
qk
=
tl
.
load
(
logits
+
cur_head
*
stride_logic_h
+
(
cur_batch_start_loc
+
start_n
+
offs_n
),
mask
=
start_n
+
offs_n
<
cur_batch_seq_len
,
other
=
float
(
"-inf"
),
tl
.
store
(
Att_Out
+
offs_mid_o
,
acc
/
e_sum
,
mask
=
(
mask_dv
),
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
)
e_sum
=
e_sum
*
old_scale
+
tl
.
sum
(
p
,
0
)
v
=
tl
.
load
(
v_ptrs
+
v_index
[:,
None
]
*
stride_buf_vbs
,
mask
=
(
offs_d
[
None
,
:]
<
Lv
)
offs_mid_o_1
=
(
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
+
Lv
)
acc
=
acc
*
old_scale
+
tl
.
sum
(
p
[:,
None
]
*
v
,
0
)
e_max
=
n_e_max
acc
=
acc
/
e_sum
off_o
=
cur_batch
*
stride_obs
+
cur_head
*
stride_oh
+
offs_d
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
(
offs_d
<
Lv
)
)
tl
.
store
(
Att_Out
+
offs_mid_o_1
,
e_max
+
tl
.
log
(
e_sum
),
)
def
_decode_att_m_fwd
(
q
,
k_buffer
,
v_buffer
,
att_out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
):
BLOCK
=
32
SPLIT
_K
=
8
BLOCK
=
64
NUM_KV_
SPLIT
S
=
num_kv_splits
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
grid
=
(
batch
,
head_num
,
SPLIT
_K
)
grid
=
(
batch
,
head_num
,
NUM_KV_
SPLIT
S
)
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
...
...
@@ -202,14 +187,15 @@ def _decode_att_m_fwd(
num_warps
=
2
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
_fwd_kernel_stage1
[
grid
](
q
,
k_buffer
,
v_buffer
,
sm_scale
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
att_out
,
Req_to_tokens
.
stride
(
0
),
...
...
@@ -217,56 +203,20 @@ def _decode_att_m_fwd(
q
.
stride
(
1
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
SPLIT_K
=
SPLIT
_K
,
NUM_KV_SPLITS
=
NUM_KV_
SPLIT
S
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
2
,
Lk
=
Lk
,
)
def
_decode_softmax_reducev_fwd
(
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
):
BLOCK
=
64
batch
,
head
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
grid
=
(
batch
,
head
,
1
)
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
1
]
num_warps
=
1
Lv
=
v_buffer
.
shape
[
-
1
]
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
_fwd_kernel_stage2
[
grid
](
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
logits
.
stride
(
0
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
o
.
stride
(
0
),
o
.
stride
(
1
),
req_to_tokens
.
stride
(
0
),
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
3
,
Lv
=
Lv
,
)
...
...
@@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd(
def
_fwd_grouped_kernel_stage1
(
Q
,
K_Buffer
,
V_Buffer
,
sm_scale
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
Att_Out
,
stride_req_to_tokens_b
,
...
...
@@ -286,23 +236,27 @@ def _fwd_grouped_kernel_stage1(
stride_qh
,
stride_buf_kbs
,
stride_buf_kh
,
att_stride_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
kv_group_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
SPLIT
_K
:
tl
.
constexpr
,
NUM_KV_
SPLIT
S
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
split_k_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
split_kv_id
=
tl
.
program_id
(
2
)
if
BLOCK_H
<
kv_group_num
:
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
...
...
@@ -313,171 +267,136 @@ def _fwd_grouped_kernel_stage1(
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
mask_d
=
offs_d
<
Lk
mask_dv
=
offs_dv
<
Lv
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
q
=
tl
.
load
(
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
).
to
(
reduce_dtype
)
q
=
tl
.
load
(
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_d
[
None
,
:]),
other
=
0.0
)
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
mask_h
[:,
None
],
other
=
0.0
).
to
(
reduce_dtype
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT
_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_
SPLIT
S
)
split_k
v
_start
=
kv_len_per_split
*
split_k
v
_id
split_k
v
_end
=
tl
.
minimum
(
split_k
v
_start
+
kv_len_per_split
,
cur_batch_seq_len
)
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
e_max
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
e_sum
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
if
split_kv_end
>
split_kv_start
:
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k
_loc
=
tl
.
load
(
kv
_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
mask
=
offs_n
<
split_k
v
_end
,
other
=
0
,
)
offs_buf_k
=
(
k
_loc
[
None
,
:]
*
stride_buf_kbs
kv
_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[
None
,
:]
<
split_k_end
)
&
(
offs
_d
[:,
None
]
<
Lk
),
mask
=
(
offs_n
[
None
,
:]
<
split_k
v
_end
)
&
(
mask
_d
[:,
None
]),
other
=
0.0
,
).
to
(
reduce_dtype
)
qk
=
tl
.
dot
(
q
,
k
)
)
qk
=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
)
)
if
BLOCK_DPE
>
0
:
offs_buf_kpe
=
(
k
_loc
[
None
,
:]
*
stride_buf_kbs
kv
_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
mask
=
offs_n
[
None
,
:]
<
split_k_end
,
mask
=
(
offs_n
[
None
,
:]
<
split_k
v
_end
)
&
(
mask_dpe
[:,
None
])
,
other
=
0.0
,
).
to
(
reduce_dtype
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
)
qk
+=
tl
.
dot
(
qpe
,
kpe
.
to
(
qpe
.
dtype
)
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
offs_o
=
cur_head
[:,
None
]
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
[
None
,
:]
qk
=
tl
.
where
(
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_kv_end
),
qk
,
float
(
"-inf"
)
)
tl
.
store
(
Att_Out
+
offs_o
,
qk
,
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
offs_buf_v
=
(
kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
acc
*=
re_scale
[:,
None
]
acc
+=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
)
@
triton
.
jit
def
_fwd_grouped_kernel_stage2
(
logits
,
V_Buffer
,
Out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
stride_logic_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_obs
,
stride_oh
,
stride_req_to_token_b
,
kv_group_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
if
BLOCK_H
<
kv_group_num
:
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
else
:
VALID_BLOCK_H
:
tl
.
constexpr
=
kv_group_num
cur_head
=
cur_head_id
*
VALID_BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
mask_h
=
cur_head
<
(
cur_head_id
+
1
)
*
VALID_BLOCK_H
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_start_loc
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_buf_v
=
cur_kv_head
*
stride_buf_vh
+
offs_d
[
None
,
:]
v_ptrs
=
V_Buffer
+
offs_buf_v
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
e_max
=
n_e_max
e_max
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
e_sum
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
for
start_n
in
range
(
0
,
cur_batch_seq_len
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
v_index
=
tl
.
load
(
Req_to_tokens
+
cur_batch_req_idx
*
stride_req_to_token_b
+
(
start_n
+
offs_n
),
mask
=
(
start_n
+
offs_n
)
<
cur_batch_seq_len
,
other
=
0
,
offs_mid_o
=
(
cur_batch
*
stride_mid_ob
+
cur_head
[:,
None
]
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
+
offs_dv
[
None
,
:]
)
offs_qk
=
cur_head
[:,
None
]
*
stride_logic_h
+
(
cur_batch_start_loc
+
start_n
+
offs_n
[
None
,
:]
tl
.
store
(
Att_Out
+
offs_mid_o
,
acc
/
e_sum
[:,
None
],
mask
=
(
mask_h
[:,
None
])
&
(
mask_dv
[
None
,
:]),
)
qk
=
tl
.
load
(
logits
+
offs_qk
,
mask
=
mask_h
[:,
None
]
&
(
start_n
+
offs_n
[
None
,
:]
<
cur_batch_seq_len
),
other
=
float
(
"-inf"
),
offs_mid_o_1
=
(
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
+
Lv
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
e_sum
=
e_sum
*
old_scale
+
tl
.
sum
(
p
,
1
)
v
=
tl
.
load
(
v_ptrs
+
v_index
[:,
None
]
*
stride_buf_vbs
,
mask
=
(
offs_d
[
None
,
:]
<
Lv
)
tl
.
store
(
Att_Out
+
offs_mid_o_1
,
e_max
+
tl
.
log
(
e_sum
),
mask
=
mask_h
,
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
old_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
acc
=
acc
/
e_sum
[:,
None
]
off_o
=
cur_batch
*
stride_obs
+
cur_head
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lv
))
def
_decode_grouped_att_m_fwd
(
q
,
k_buffer
,
v_buffer
,
att_out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
):
BLOCK
=
64
BLOCK
=
32
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
if
Lk
==
576
:
BLOCK_DMODEL
=
512
...
...
@@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd(
else
:
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DPE
=
0
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
SPLIT
_K
=
8
BLOCK_H
=
16
NUM_KV_
SPLIT
S
=
num_kv_splits
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
SPLIT
_K
,
NUM_KV_
SPLIT
S
,
)
num_warps
=
4
extra_kargs
=
{}
if
is_hip_
:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
...
...
@@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd(
_fwd_grouped_kernel_stage1
[
grid
](
q
,
k_buffer
,
v_buffer
,
sm_scale
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
att_out
,
Req_to_tokens
.
stride
(
0
),
...
...
@@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd(
q
.
stride
(
1
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT
_K
,
NUM_KV_SPLITS
=
NUM_KV_
SPLIT
S
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_warps
=
4
,
num_stages
=
2
,
Lk
=
Lk
,
Lv
=
Lv
,
**
extra_kargs
,
)
def
_decode_grouped_softmax_reducev_fwd
(
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
@
triton
.
jit
def
_fwd_kernel_stage2
(
Mid_O
,
O
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
stride_obs
,
stride_oh
,
NUM_KV_SPLITS
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
BLOCK
=
128
batch
,
head_num
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
1
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DV
)
mask_d
=
offs_d
<
Lv
e_sum
=
0.0
e_max
=
-
float
(
"inf"
)
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
num_warps
=
8
offs_v
=
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
offs_d
offs_logic
=
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
Lv
for
split_kv_id
in
range
(
0
,
NUM_KV_SPLITS
):
tv
=
tl
.
load
(
Mid_O
+
offs_v
+
split_kv_id
*
stride_mid_os
,
mask
=
mask_d
,
other
=
0.0
)
tlogic
=
tl
.
load
(
Mid_O
+
offs_logic
+
split_kv_id
*
stride_mid_os
)
n_e_max
=
tl
.
maximum
(
tlogic
,
e_max
)
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
acc
*=
old_scale
exp_logic
=
tl
.
exp
(
tlogic
-
n_e_max
)
acc
+=
exp_logic
*
tv
e_sum
=
e_sum
*
old_scale
+
exp_logic
e_max
=
n_e_max
tl
.
store
(
O
+
cur_batch
*
stride_obs
+
cur_head
*
stride_oh
+
offs_d
,
acc
/
e_sum
,
mask
=
mask_d
,
)
def
_decode_softmax_reducev_fwd
(
logits
,
q
,
o
,
v_buffer
,
num_kv_splits
,
):
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
NUM_KV_SPLITS
=
num_kv_splits
extra_kargs
=
{}
if
is_hip_
:
...
...
@@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd(
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
_fwd_grouped_kernel_stage2
[
grid
](
grid
=
(
batch
,
head_num
)
_fwd_kernel_stage2
[
grid
](
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
logits
.
stride
(
0
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
logits
.
stride
(
1
),
logits
.
stride
(
2
),
o
.
stride
(
0
),
o
.
stride
(
1
),
req_to_tokens
.
stride
(
0
),
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_warps
=
4
,
num_stages
=
2
,
**
extra_kargs
,
)
...
...
@@ -597,34 +554,27 @@ def decode_attention_fwd_normal(
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
=
0.0
,
):
_decode_att_m_fwd
(
q
,
k_buffer
,
v_buffer
,
attn_logits
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
)
_decode_softmax_reducev_fwd
(
attn_logits
,
v_buffer
,
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
num_kv_splits
)
def
decode_attention_fwd_grouped
(
...
...
@@ -634,34 +584,27 @@ def decode_attention_fwd_grouped(
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
=
0.0
,
):
_decode_grouped_att_m_fwd
(
q
,
k_buffer
,
v_buffer
,
attn_logits
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
)
_decode_grouped_softmax_reducev_fwd
(
attn_logits
,
v_buffer
,
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
num_kv_splits
)
def
decode_attention_fwd
(
...
...
@@ -675,9 +618,11 @@ def decode_attention_fwd(
b_seq_len
,
attn_logits
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
=
0.0
,
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
...
...
@@ -689,10 +634,10 @@ def decode_attention_fwd(
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
)
...
...
@@ -705,10 +650,10 @@ def decode_attention_fwd(
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
logit_cap
,
)
python/sglang/srt/server_args.py
View file @
7dc66fcb
...
...
@@ -141,6 +141,7 @@ class ServerArgs:
enable_nan_detection
:
bool
=
False
enable_p2p_check
:
bool
=
False
triton_attention_reduce_in_fp32
:
bool
=
False
triton_attention_num_kv_splits
:
int
=
8
num_continuous_decode_steps
:
int
=
1
delete_ckpt_after_loading
:
bool
=
False
...
...
@@ -753,6 +754,12 @@ class ServerArgs:
help
=
"Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels."
,
)
parser
.
add_argument
(
"--triton-attention-num-kv-splits"
,
type
=
int
,
default
=
ServerArgs
.
triton_attention_num_kv_splits
,
help
=
"The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8."
,
)
parser
.
add_argument
(
"--num-continuous-decode-steps"
,
type
=
int
,
...
...
test/srt/test_triton_attention_kernels.py
View file @
7dc66fcb
...
...
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
seq_len
=
10
# This represents the number of tokens already in the sequence
total_tokens
=
B
*
seq_len
sm_scale
=
1.0
/
(
D
**
0.5
)
num_kv_splits
=
8
# q represents the new token being generated, one per batch
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
...
...
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
(
H_Q
,
total_tokens
),
dtype
=
dtype
,
(
B
,
H_Q
,
num_kv_splits
,
D
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
...
...
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
,
attn_logits
,
seq_len
,
num_kv_splits
,
sm_scale
,
)
...
...
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
def
_test_grouped_decode_attention_once
(
self
,
B
,
H_Q
,
H_KV
,
D
,
D_V
):
dtype
=
torch
.
bfloat16
seq_len
=
1
0
# This represents the number of tokens already in the sequence
seq_len
=
1
28
# This represents the number of tokens already in the sequence
total_tokens
=
B
*
seq_len
sm_scale
=
1.0
/
(
D
**
0.5
)
num_kv_splits
=
8
# q represents the new token being generated, one per batch
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
...
...
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
v_buffer
=
torch
.
randn
(
total_tokens
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
# o will have the same shape as q
o
=
torch
.
zeros
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o_grouped
=
torch
.
zeros
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros
(
B
,
H_Q
,
D
_V
,
dtype
=
dtype
,
device
=
"cuda"
)
o_grouped
=
torch
.
zeros
(
B
,
H_Q
,
D
_V
,
dtype
=
dtype
,
device
=
"cuda"
)
req_to_token
=
torch
.
arange
(
total_tokens
,
device
=
"cuda"
).
reshape
(
B
,
seq_len
)
b_req_idx
=
torch
.
arange
(
B
,
device
=
"cuda"
)
...
...
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
(
H_Q
,
total_tokens
),
dtype
=
dtype
,
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
...
...
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
seq_len
,
num_kv_splits
,
sm_scale
,
)
attn_logits1
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
decode_attention_fwd_grouped
(
q
,
k_buffer
,
...
...
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
o_grouped
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
attn_logits
,
attn_logits
1
,
seq_len
,
num_kv_splits
,
sm_scale
,
)
cos_sim
=
torch
.
nn
.
functional
.
cosine_similarity
(
o
.
flatten
(),
o_grouped
.
flatten
(),
dim
=
0
)
print
(
cos_sim
.
item
())
self
.
assertTrue
(
cos_sim
.
item
()
>
0.99
)
self
.
assertTrue
(
torch
.
allclose
(
o
,
o_grouped
,
atol
=
3e-2
))
def
test_grouped_decode_attention
(
self
):
configs
=
[
(
2
,
16
,
16
,
64
,
64
),
(
2
,
16
,
1
,
64
,
64
),
(
2
,
64
,
1
,
13
,
13
),
(
2
,
128
,
1
,
80
,
80
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment