Commit 7dc66fcb, authored Dec 08, 2024 by Ke Bao, committed by GitHub on Dec 08, 2024

Optimize Triton decoding kernel for long context (#2394)
Parent: 1f09e84b

Showing 4 changed files with 328 additions and 360 deletions (+328 −360)
python/sglang/srt/layers/attention/triton_backend.py                  +13 −8
python/sglang/srt/layers/attention/triton_ops/decode_attention.py     +287 −342
python/sglang/srt/server_args.py                                      +7 −0
test/srt/test_triton_attention_kernels.py                             +21 −10
python/sglang/srt/layers/attention/triton_backend.py
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             self.reduce_dtype = torch.float16

+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
         self.forward_metadata = None

         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )

@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )

@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 attn_logits,
                 max_seq_len,
+                self.num_kv_splits,
                 layer.scaling,
                 layer.logit_cap,
             )
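For orientation, here is a minimal sketch of the intermediate buffer the backend now allocates (the shape and dtype come from the hunks above; the example sizes are made up). Each (request, query head, KV split) gets one fp32 row: the first v_head_dim entries hold that split's partial output, and the last entry holds its log-sum-exp.

import torch

# Illustrative sizes only; in the backend these come from the batch and the model config.
batch_size, num_head, num_kv_splits, v_head_dim = 4, 32, 8, 128

attn_logits = torch.empty(
    # [..., :v_head_dim] = per-split partial output, [..., v_head_dim] = per-split log-sum-exp
    (batch_size, num_head, num_kv_splits, v_head_dim + 1),
    dtype=torch.float32,  # always fp32 now, independent of reduce_dtype
    device="cuda",
)

Unlike the old (num_head, total_num_tokens) buffer, whose size grew with the total number of cached tokens in the batch, this buffer's size is independent of context length.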
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -17,8 +17,8 @@ It supports page size = 1.
 """

 # Adapted from
-# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
-# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

 import triton
 import triton.language as tl
@@ -37,10 +37,10 @@ def tanh(x):
@@ -48,152 +48,137 @@ def _fwd_kernel_stage1(

The MHA split-KV kernel is rewritten in flash-decoding style: instead of stage 1 writing raw per-token attention scores into a flat (head, total_tokens) buffer and stage 2 re-walking the whole sequence to softmax and reduce over V, stage 1 now performs the softmax and V reduction for its own KV split and emits only a per-split partial result. The signature changes accordingly:

 def _fwd_kernel_stage1(
     Q,
     K_Buffer,
+    V_Buffer,
     sm_scale,
     Req_to_tokens,
     B_req_idx,
-    B_Start_Loc,
     B_Seqlen,
     Att_Out,
     stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
     stride_buf_kh,
-    att_stride_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    SPLIT_K: tl.constexpr,
+    NUM_KV_SPLITS: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
+    Lv: tl.constexpr,
 ):

In the body, `split_k_id = tl.program_id(2)` becomes `split_kv_id`, `reduce_dtype = Att_Out.dtype.element_ty` disappears (Q and K are loaded with explicit `mask_d = offs_d < Lk` / `mask_dv = offs_dv < Lv` masks instead of being cast), and the split range is computed as `kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)` with `split_kv_start` / `split_kv_end`. Guarded by `if split_kv_end > split_kv_start:`, the loop over `range(split_kv_start, split_kv_end, BLOCK_N)` keeps online-softmax state (`e_max`, `e_sum`, `acc = tl.zeros([BLOCK_DV], dtype=tl.float32)`): it loads K and V through `kv_loc`, computes `qk = tl.sum(q[None, :] * k, 1) * sm_scale` (with the optional `logit_cap * tanh(qk / logit_cap)` cap), masks out-of-split positions via `tl.where(offs_n < split_kv_end, qk, float("-inf"))`, rescales the accumulator by `tl.exp(e_max - n_e_max)`, and accumulates `tl.sum(p[:, None] * v, 0)`. Instead of storing raw scores at `cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)`, it stores the split's normalized partial output `acc / e_sum` at `cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv` and the split's log-sum-exp `e_max + tl.log(e_sum)` in the extra slot at offset `Lv`. The old `_fwd_kernel_stage2(logits, V_Buffer, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, ...)` kernel, which re-read the per-token logits and swept the full sequence again, is deleted.

@@ -202,14 +187,15 @@ def _decode_att_m_fwd(
@@ -217,56 +203,20 @@ def _decode_att_m_fwd(

The launcher `_decode_att_m_fwd` gains `v_buffer` and `num_kv_splits` parameters and drops `B_Start_Loc`. Key changes:

-    BLOCK = 32
+    BLOCK = 64
-    SPLIT_K = 8
+    NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
+    Lv = v_buffer.shape[-1]

-    grid = (batch, head_num, SPLIT_K)
+    grid = (batch, head_num, NUM_KV_SPLITS)

     BLOCK_DMODEL = triton.next_power_of_2(Lk)
+    BLOCK_DV = triton.next_power_of_2(Lv)

The kernel call now also passes `v_buffer`, the V-buffer strides, `att_out.stride(1)` and `att_out.stride(2)` alongside `att_out.stride(0)`, `BLOCK_DV=BLOCK_DV`, `NUM_KV_SPLITS=NUM_KV_SPLITS`, `Lv=Lv`, and `num_stages=2` (was 1). The old `_decode_softmax_reducev_fwd(logits, v_buffer, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len)` launcher, which ran the deleted per-token stage-2 kernel over the whole sequence, is removed here; a new reduce launcher with the same name is added further down.
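The math that the rewritten stage-1 kernel implements per (batch, head, KV split) can be restated in a few lines of PyTorch. This is a reference sketch for intuition only (it ignores the paged lookup through Req_to_tokens, the GQA grouping, and empty splits); it is not the Triton kernel itself, and the function name is illustrative.

import torch

def stage1_partial(q, k_split, v_split, sm_scale, logit_cap=0.0):
    """What one KV split contributes: a normalized partial output and its log-sum-exp.

    q: [Lk], k_split: [N, Lk], v_split: [N, Lv] -- this split's slice of the KV cache.
    """
    qk = (k_split.float() @ q.float()) * sm_scale            # [N] attention scores
    if logit_cap > 0:
        qk = logit_cap * torch.tanh(qk / logit_cap)          # optional soft cap
    e_max = qk.max()
    p = torch.exp(qk - e_max)                                # [N]
    e_sum = p.sum()
    partial_out = (p[:, None] * v_split.float()).sum(0) / e_sum  # -> attn_logits[b, h, split, :Lv]
    lse = e_max + torch.log(e_sum)                           # -> attn_logits[b, h, split, Lv]
    return partial_out, lse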
@@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd(
@@ -286,23 +236,27 @@ def _fwd_grouped_kernel_stage1(
@@ -313,171 +267,136 @@ def _fwd_grouped_kernel_stage1(

The grouped (GQA/MQA) stage-1 kernel `_fwd_grouped_kernel_stage1` gets the same treatment. Its signature adds `V_Buffer`, the V-buffer and mid-output strides (`stride_buf_vbs`, `stride_buf_vh`, `stride_mid_ob`, `stride_mid_oh`, `stride_mid_os`), `BLOCK_DV: tl.constexpr`, and `Lv: tl.constexpr`, drops `B_Start_Loc` and `att_stride_h`, and replaces `SPLIT_K: tl.constexpr` with `NUM_KV_SPLITS: tl.constexpr`. In the body, `split_k_id` becomes `split_kv_id`, the `reduce_dtype = Att_Out.dtype.element_ty` cast is removed (Q, K, and the rotary KPE part are loaded with explicit `mask_d` / `mask_dpe` masks instead), and the score computation becomes `qk = tl.dot(q, k.to(q.dtype))` plus `qk += tl.dot(qpe, kpe.to(qpe.dtype))` when `BLOCK_DPE > 0`. Guarded by `if split_kv_end > split_kv_start:`, the loop over `range(split_kv_start, split_kv_end, BLOCK_N)` keeps per-head-group online-softmax state (`e_max`, `e_sum`, `acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)`), loads V for the split, masks out-of-range positions with `tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf"))`, and accumulates `tl.dot(p.to(v.dtype), v)`. After the loop it stores the normalized partial output `acc / e_sum[:, None]` at `cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv[None, :]` and the log-sum-exp `e_max + tl.log(e_sum)` in the extra slot at offset `Lv`, rather than writing raw scores through `att_stride_h`. The old `_fwd_grouped_kernel_stage2(logits, V_Buffer, Out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, ...)` kernel, which re-read the flat logits and swept the full sequence per head group, is deleted.
@@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd(
@@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd(
@@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd(

The grouped launcher `_decode_grouped_att_m_fwd` likewise gains `v_buffer` and `num_kv_splits` and drops `B_Start_Loc`:

-    BLOCK = 64
+    BLOCK = 32
     Lk = k_buffer.shape[-1]
+    Lv = v_buffer.shape[-1]

     BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)

-    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
-    SPLIT_K = 8
+    BLOCK_H = 16
+    NUM_KV_SPLITS = num_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        SPLIT_K,
+        NUM_KV_SPLITS,
     )

The ROCm `extra_kargs` path is unchanged. The kernel call adds `v_buffer`, the V-buffer strides, `att_out.stride(1)` and `att_out.stride(2)`, `BLOCK_DV=BLOCK_DV`, `NUM_KV_SPLITS=NUM_KV_SPLITS` (replacing `SPLIT_K=SPLIT_K`), and `Lv=Lv`, and switches to `num_stages=2` (was 1).
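To make the new launch shape concrete, a small worked example using the formulas above. The head counts here are hypothetical (a 32-query-head, 8-KV-head model); only the formulas come from the diff, and `num_kv_splits` is the value that the new server argument introduced below will supply.

import triton

batch, head_num, kv_head_num, num_kv_splits = 2, 32, 8, 16
kv_group_num = head_num // kv_head_num                    # 4 query heads per KV head
BLOCK_H = 16
grid = (
    batch,
    triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),    # 32 / 4 = 8 head groups
    num_kv_splits,
)
# grid == (2, 8, 16): one program per (request, head group, KV split), so the
# decode-time parallelism now scales with the configured number of KV splits
# instead of being fixed at SPLIT_K = 8.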
@@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd(

The old `_decode_grouped_softmax_reducev_fwd(logits, v_buffer, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len)` launcher (and the grouped stage-2 kernel it drove) is removed. In its place the file adds a single split-KV reduce kernel plus one launcher shared by the MHA and grouped paths:

@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    O,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0)
        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
        n_e_max = tl.maximum(tlogic, e_max)

        old_scale = tl.exp(e_max - n_e_max)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - n_e_max)
        acc += exp_logic * tv

        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max

    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
    )
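The merge performed by the new `_fwd_kernel_stage2` can be checked against a plain PyTorch restatement of the same running-max update. Again a sketch for intuition, with an illustrative function name; it assumes every split covered at least one token, since this version of the reduce kernel does not consult the sequence length.

import math
import torch

def stage2_merge(partials, lses):
    """Combine per-split results for one (batch, head).

    partials: [num_kv_splits, Lv]  -- attn_logits[b, h, :, :Lv]
    lses:     [num_kv_splits]      -- attn_logits[b, h, :, Lv]
    """
    e_max, e_sum = float("-inf"), 0.0
    acc = torch.zeros_like(partials[0])
    for tv, tlogic in zip(partials, lses.tolist()):
        n_e_max = max(tlogic, e_max)
        old_scale = math.exp(e_max - n_e_max)   # 0.0 on the first iteration
        exp_logic = math.exp(tlogic - n_e_max)
        acc = acc * old_scale + exp_logic * tv
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum                          # the final attention output row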
@@ -597,34 +554,27 @@ def decode_attention_fwd_normal(
@@ -634,34 +584,27 @@ def decode_attention_fwd_grouped(

`decode_attention_fwd_normal` and `decode_attention_fwd_grouped` both drop `b_start_loc` from their signatures and gain `num_kv_splits`, which they forward (together with `v_buffer`) to `_decode_att_m_fwd` / `_decode_grouped_att_m_fwd`. Their reduce step changes from the old sequence-walking call

-    _decode_softmax_reducev_fwd(
-        attn_logits,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )

(the grouped variant previously called `_decode_grouped_softmax_reducev_fwd` with the same arguments) to the new split-merge call

+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)

@@ -675,9 +618,11 @@ def decode_attention_fwd(
@@ -689,10 +634,10 @@ def decode_attention_fwd(
@@ -705,10 +650,10 @@ def decode_attention_fwd(

The top-level `decode_attention_fwd` gains a `num_kv_splits` parameter and a sanity check before dispatching:

+    assert num_kv_splits == attn_logits.shape[2]
     kv_group_num = q.shape[1] // v_buffer.shape[1]

     if kv_group_num == 1:

and both the MHA and grouped branches now pass `num_kv_splits` instead of `b_start_loc` to the wrappers above.
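Putting the new API together, a hedged end-to-end sketch of calling the grouped path with the updated buffer layout and argument order. It mirrors the updated test further below; the shapes are illustrative, a CUDA device is assumed, and the module path is assumed to match the file path of this diff.

import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd_grouped,
)

B, H_Q, H_KV, D, D_V, seq_len, num_kv_splits = 2, 16, 1, 64, 64, 128, 8
total_tokens = B * seq_len
dtype = torch.bfloat16

q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")          # one new token per request
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
attn_logits = torch.empty(
    (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda"
)

decode_attention_fwd_grouped(
    q, k_buffer, v_buffer, o,
    req_to_token, b_req_idx, b_seq_len,
    attn_logits, seq_len, num_kv_splits,
    1.0 / D**0.5,
)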
python/sglang/srt/server_args.py
@@ -141,6 +141,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False

@@ -753,6 +754,12 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--triton-attention-num-kv-splits",
+            type=int,
+            default=ServerArgs.triton_attention_num_kv_splits,
+            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
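The help text's claim that larger values help longer contexts follows from how the work is divided: each split processes roughly ceil(seq_len / num_kv_splits) cached tokens, so for short sequences extra splits sit idle while for long ones they add parallelism. A throwaway calculation with example numbers only:

import math

for seq_len in (512, 8192, 65536):
    for num_kv_splits in (8, 16, 32):
        tokens_per_split = math.ceil(seq_len / num_kv_splits)
        print(f"seq_len={seq_len:6d} splits={num_kv_splits:2d} -> {tokens_per_split} tokens/split")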
test/srt/test_triton_attention_kernels.py
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")

         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D + 1),
+            dtype=torch.float32,
             device="cuda",
         )

@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )

@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
     def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 10  # This represents the number of tokens already in the sequence
+        seq_len = 128  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
         v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")

         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")

@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")

         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
             device="cuda",
         )

@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )

+        attn_logits1 = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+            device="cuda",
+        )
+
         decode_attention_fwd_grouped(
             q,
             k_buffer,

@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
-            attn_logits,
+            attn_logits1,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )

         cos_sim = torch.nn.functional.cosine_similarity(
             o.flatten(), o_grouped.flatten(), dim=0
         )
+        print(cos_sim.item())
         self.assertTrue(cos_sim.item() > 0.99)
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
         configs = [
+            (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
             (2, 64, 1, 13, 13),
             (2, 128, 1, 80, 80),
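For completeness, a naive PyTorch reference for the decode attention these tests exercise: paged KV lookup through req_to_token, grouped query heads, and a single new token per request. The repo's own reference helper is not shown in this diff, so treat this as an independent oracle sketch rather than the test's actual comparison code.

import torch

def naive_decode_attention(q, k_buffer, v_buffer, req_to_token, b_req_idx, b_seq_len, sm_scale):
    B, H_Q, D = q.shape
    H_KV = k_buffer.shape[1]
    D_V = v_buffer.shape[-1]
    group = H_Q // H_KV
    o = torch.zeros(B, H_Q, D_V, dtype=q.dtype, device=q.device)
    for b in range(B):
        L = int(b_seq_len[b])
        idx = req_to_token[b_req_idx[b], :L]            # token locations for this request
        k = k_buffer[idx]                               # [L, H_KV, D]
        v = v_buffer[idx]                               # [L, H_KV, D_V]
        for h in range(H_Q):
            kv_h = h // group                           # KV head serving this query head
            scores = (k[:, kv_h] @ q[b, h]) * sm_scale  # [L]
            probs = torch.softmax(scores.float(), dim=0).to(v.dtype)
            o[b, h] = probs @ v[:, kv_h]
    return o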