Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7dc66fcb
Unverified
Commit
7dc66fcb
authored
Dec 08, 2024
by
Ke Bao
Committed by
GitHub
Dec 08, 2024
Browse files
Optimize Triton decoding kernel for long context (#2394)
parent
1f09e84b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
328 additions
and
360 deletions
+328
-360
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+13
-8
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
...glang/srt/layers/attention/triton_ops/decode_attention.py
+287
-342
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
test/srt/test_triton_attention_kernels.py
test/srt/test_triton_attention_kernels.py
+21
-10
No files found.
python/sglang/srt/layers/attention/triton_backend.py
View file @
7dc66fcb
...
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
else
:
else
:
self
.
reduce_dtype
=
torch
.
float16
self
.
reduce_dtype
=
torch
.
float16
self
.
num_kv_splits
=
model_runner
.
server_args
.
triton_attention_num_kv_splits
self
.
v_head_dim
=
model_runner
.
token_to_kv_pool
.
get_value_buffer
(
0
).
shape
[
-
1
]
self
.
forward_metadata
=
None
self
.
forward_metadata
=
None
self
.
cuda_graph_max_seq_len
=
model_runner
.
model_config
.
context_len
self
.
cuda_graph_max_seq_len
=
model_runner
.
model_config
.
context_len
...
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
start_loc
=
torch
.
zeros_like
(
forward_batch
.
seq_lens
,
dtype
=
torch
.
int32
)
start_loc
=
torch
.
zeros_like
(
forward_batch
.
seq_lens
,
dtype
=
torch
.
int32
)
start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
[:
-
1
],
dim
=
0
)
start_loc
[
1
:]
=
torch
.
cumsum
(
forward_batch
.
seq_lens
[:
-
1
],
dim
=
0
)
total_num_tokens
=
forward_batch
.
seq_lens_sum
attn_logits
=
torch
.
empty
(
attn_logits
=
torch
.
empty
(
(
self
.
num_head
,
total_num_tokens
),
(
dtype
=
self
.
reduce_dtype
,
forward_batch
.
batch_size
,
self
.
num_head
,
self
.
num_kv_splits
,
self
.
v_head_dim
+
1
,
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
...
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
(
max_bs
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
(
max_bs
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
cuda_graph_attn_logits
=
torch
.
empty
(
self
.
cuda_graph_attn_logits
=
torch
.
empty
(
(
(
max_bs
,
self
.
num_head
,
self
.
num_kv_splits
,
self
.
v_head_dim
+
1
),
self
.
num_head
,
dtype
=
torch
.
float32
,
self
.
cuda_graph_max_total_num_tokens
,
),
dtype
=
self
.
reduce_dtype
,
device
=
"cuda"
,
device
=
"cuda"
,
)
)
...
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch
.
seq_lens
,
forward_batch
.
seq_lens
,
attn_logits
,
attn_logits
,
max_seq_len
,
max_seq_len
,
self
.
num_kv_splits
,
layer
.
scaling
,
layer
.
scaling
,
layer
.
logit_cap
,
layer
.
logit_cap
,
)
)
...
...
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
View file @
7dc66fcb
...
@@ -17,8 +17,8 @@ It supports page size = 1.
...
@@ -17,8 +17,8 @@ It supports page size = 1.
"""
"""
# Adapted from
# Adapted from
# https://github.com/ModelTC/lightllm/blob/
f2a54f0912293f683bf1d1695fd12c4098a5bf82
/lightllm/models/
llama
/triton_kernel/
token_attention_nopad_att
1.py
# https://github.com/ModelTC/lightllm/blob/
96353e868a840db4d103138caf15ed9dbea8c186
/lightllm/models/
deepseek2
/triton_kernel/
gqa_flash_decoding_stage
1.py
# https://github.com/ModelTC/lightllm/blob/
f2a54f0912293f683bf1d1695fd12c4098a5bf82
/lightllm/models/
llama
/triton_kernel/
token_attention_softmax_and_reducev
.py
# https://github.com/ModelTC/lightllm/blob/
96353e868a840db4d103138caf15ed9dbea8c186
/lightllm/models/
deepseek2
/triton_kernel/
gqa_flash_decoding_stage2
.py
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
...
@@ -37,10 +37,10 @@ def tanh(x):
...
@@ -37,10 +37,10 @@ def tanh(x):
def
_fwd_kernel_stage1
(
def
_fwd_kernel_stage1
(
Q
,
Q
,
K_Buffer
,
K_Buffer
,
V_Buffer
,
sm_scale
,
sm_scale
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
Att_Out
,
Att_Out
,
stride_req_to_tokens_b
,
stride_req_to_tokens_b
,
...
@@ -48,152 +48,137 @@ def _fwd_kernel_stage1(
...
@@ -48,152 +48,137 @@ def _fwd_kernel_stage1(
stride_qh
,
stride_qh
,
stride_buf_kbs
,
stride_buf_kbs
,
stride_buf_kh
,
stride_buf_kh
,
att_stride_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
kv_group_num
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SPLIT
_K
:
tl
.
constexpr
,
NUM_KV_
SPLIT
S
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_head
=
tl
.
program_id
(
1
)
split_k_id
=
tl
.
program_id
(
2
)
split_k
v
_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
cur_kv_head
=
cur_head
//
kv_group_num
cur_kv_head
=
cur_head
//
kv_group_num
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
mask_d
=
offs_d
<
Lk
mask_dv
=
offs_dv
<
Lv
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
q
=
tl
.
load
(
Q
+
off_q
).
to
(
reduce_dtype
)
q
=
tl
.
load
(
Q
+
off_q
,
mask
=
mask_d
,
other
=
0.0
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
)
offs_buf_k
=
(
k_loc
[:,
None
]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[
None
,
:]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[:,
None
]
<
split_k_end
)
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
,
).
to
(
reduce_dtype
)
att_value
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
att_value
*=
sm_scale
if
logit_cap
>
0
:
att_value
=
logit_cap
*
tanh
(
att_value
/
logit_cap
)
off_o
=
cur_head
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
<
split_k_end
)
split_kv_start
=
kv_len_per_split
*
split_kv_id
split_kv_end
=
tl
.
minimum
(
split_kv_start
+
kv_len_per_split
,
cur_batch_seq_len
)
e_max
=
-
float
(
"inf"
)
e_sum
=
0.0
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
if
split_kv_end
>
split_kv_start
:
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_kv_end
,
other
=
0
,
)
offs_buf_k
=
(
kv_loc
[:,
None
]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[
None
,
:]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_d
[
None
,
:]),
other
=
0.0
,
)
qk
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
qk
*=
sm_scale
@
triton
.
jit
if
logit_cap
>
0
:
def
_fwd_kernel_stage2
(
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
logits
,
V_Buffer
,
Out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
stride_logic_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_obs
,
stride_oh
,
stride_req_to_token_b
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head
//
kv_group_num
qk
=
tl
.
where
(
offs_n
<
split_kv_end
,
qk
,
float
(
"-inf"
))
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
offs_buf_v
=
(
cur_batch_start_loc
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
kv_loc
[:,
None
]
*
stride_buf_vbs
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
)
acc
*=
re_scale
acc
+=
tl
.
sum
(
p
[:,
None
]
*
v
,
0
)
offs_buf_v
=
cur_kv_head
*
stride_buf_vh
+
offs_d
[
None
,
:]
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
0
)
v_ptrs
=
V_Buffer
+
offs_buf_v
e_max
=
n_e_max
e_max
=
float
(
"-inf"
)
offs_mid_o
=
(
e_sum
=
0.0
cur_batch
*
stride_mid_ob
acc
=
tl
.
zeros
([
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
+
cur_head
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
for
start_n
in
range
(
0
,
cur_batch_seq_len
,
BLOCK_N
):
+
offs_dv
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
v_index
=
tl
.
load
(
Req_to_tokens
+
cur_batch_req_idx
*
stride_req_to_token_b
+
(
start_n
+
offs_n
),
mask
=
(
start_n
+
offs_n
)
<
cur_batch_seq_len
,
other
=
0
,
)
)
qk
=
tl
.
load
(
tl
.
store
(
logits
Att_Out
+
offs_mid_o
,
+
cur_head
*
stride_logic_h
acc
/
e_sum
,
+
(
cur_batch_start_loc
+
start_n
+
offs_n
),
mask
=
(
mask_dv
),
mask
=
start_n
+
offs_n
<
cur_batch_seq_len
,
other
=
float
(
"-inf"
),
)
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
0
),
e_max
)
offs_mid_o_1
=
(
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
cur_batch
*
stride_mid_ob
p
=
tl
.
exp
(
qk
-
n_e_max
)
+
cur_head
*
stride_mid_oh
e_sum
=
e_sum
*
old_scale
+
tl
.
sum
(
p
,
0
)
+
split_kv_id
*
stride_mid_os
v
=
tl
.
load
(
+
Lv
v_ptrs
+
v_index
[:,
None
]
*
stride_buf_vbs
,
mask
=
(
offs_d
[
None
,
:]
<
Lv
)
)
)
acc
=
acc
*
old_scale
+
tl
.
sum
(
p
[:,
None
]
*
v
,
0
)
e_max
=
n_e_max
acc
=
acc
/
e_sum
tl
.
store
(
off_o
=
cur_batch
*
stride_obs
+
cur_head
*
stride_oh
+
offs_d
Att_Out
+
offs_mid_o_1
,
out_ptrs
=
Out
+
off_o
e_max
+
tl
.
log
(
e_sum
),
tl
.
store
(
out_ptrs
,
acc
,
mask
=
(
offs_d
<
Lv
)
)
)
def
_decode_att_m_fwd
(
def
_decode_att_m_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
att_out
,
att_out
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
):
):
BLOCK
=
32
BLOCK
=
64
SPLIT
_K
=
8
NUM_KV_
SPLIT
S
=
num_kv_splits
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
grid
=
(
batch
,
head_num
,
SPLIT
_K
)
grid
=
(
batch
,
head_num
,
NUM_KV_
SPLIT
S
)
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
...
@@ -202,14 +187,15 @@ def _decode_att_m_fwd(
...
@@ -202,14 +187,15 @@ def _decode_att_m_fwd(
num_warps
=
2
num_warps
=
2
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
_fwd_kernel_stage1
[
grid
](
_fwd_kernel_stage1
[
grid
](
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
sm_scale
,
sm_scale
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
att_out
,
att_out
,
Req_to_tokens
.
stride
(
0
),
Req_to_tokens
.
stride
(
0
),
...
@@ -217,56 +203,20 @@ def _decode_att_m_fwd(
...
@@ -217,56 +203,20 @@ def _decode_att_m_fwd(
q
.
stride
(
1
),
q
.
stride
(
1
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
1
),
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
att_out
.
stride
(
0
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SPLIT_K
=
SPLIT
_K
,
NUM_KV_SPLITS
=
NUM_KV_
SPLIT
S
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
2
,
Lk
=
Lk
,
Lk
=
Lk
,
)
def
_decode_softmax_reducev_fwd
(
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
):
BLOCK
=
64
batch
,
head
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
grid
=
(
batch
,
head
,
1
)
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
1
]
num_warps
=
1
Lv
=
v_buffer
.
shape
[
-
1
]
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
_fwd_kernel_stage2
[
grid
](
logits
,
v_buffer
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
logits
.
stride
(
0
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
o
.
stride
(
0
),
o
.
stride
(
1
),
req_to_tokens
.
stride
(
0
),
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
3
,
Lv
=
Lv
,
Lv
=
Lv
,
)
)
...
@@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd(
...
@@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd(
def
_fwd_grouped_kernel_stage1
(
def
_fwd_grouped_kernel_stage1
(
Q
,
Q
,
K_Buffer
,
K_Buffer
,
V_Buffer
,
sm_scale
,
sm_scale
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
Att_Out
,
Att_Out
,
stride_req_to_tokens_b
,
stride_req_to_tokens_b
,
...
@@ -286,23 +236,27 @@ def _fwd_grouped_kernel_stage1(
...
@@ -286,23 +236,27 @@ def _fwd_grouped_kernel_stage1(
stride_qh
,
stride_qh
,
stride_buf_kbs
,
stride_buf_kbs
,
stride_buf_kh
,
stride_buf_kh
,
att_stride_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_mid_ob
,
stride_mid_oh
,
stride_mid_os
,
kv_group_num
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
SPLIT
_K
:
tl
.
constexpr
,
NUM_KV_
SPLIT
S
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
split_k_id
=
tl
.
program_id
(
2
)
split_kv_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
if
BLOCK_H
<
kv_group_num
:
if
BLOCK_H
<
kv_group_num
:
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
...
@@ -313,171 +267,136 @@ def _fwd_grouped_kernel_stage1(
...
@@ -313,171 +267,136 @@ def _fwd_grouped_kernel_stage1(
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
mask_d
=
offs_d
<
Lk
mask_dv
=
offs_dv
<
Lv
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
q
=
tl
.
load
(
q
=
tl
.
load
(
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_d
[
None
,
:]),
other
=
0.0
)
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
).
to
(
reduce_dtype
)
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
mask_dpe
=
offs_dpe
<
Lk
off_qpe
=
(
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
mask_h
[:,
None
],
other
=
0.0
).
to
(
reduce_dtype
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
(
mask_h
[:,
None
])
&
(
mask_dpe
[
None
,
:]),
other
=
0.0
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
)
offs_buf_k
=
(
k_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
NUM_KV_SPLITS
)
mask
=
(
offs_n
[
None
,
:]
<
split_k_end
)
&
(
offs_d
[:,
None
]
<
Lk
),
split_kv_start
=
kv_len_per_split
*
split_kv_id
other
=
0.0
,
split_kv_end
=
tl
.
minimum
(
split_kv_start
+
kv_len_per_split
,
cur_batch_seq_len
)
).
to
(
reduce_dtype
)
qk
=
tl
.
dot
(
q
,
k
)
e_max
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
if
BLOCK_DPE
>
0
:
e_sum
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
offs_buf_kpe
=
(
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
k_loc
[
None
,
:]
*
stride_buf_kbs
if
split_kv_end
>
split_kv_start
:
for
start_n
in
range
(
split_kv_start
,
split_kv_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_kv_end
,
other
=
0
,
)
offs_buf_k
=
(
kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
+
offs_d
pe
[:,
None
]
+
offs_d
[:,
None
]
)
)
k
pe
=
tl
.
load
(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
pe
,
K_Buffer
+
offs_buf_k
,
mask
=
offs_n
[
None
,
:]
<
split_k_end
,
mask
=
(
offs_n
[
None
,
:]
<
split_k
v
_end
)
&
(
mask_d
[:,
None
])
,
other
=
0.0
,
other
=
0.0
,
).
to
(
reduce_dtype
)
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
qk
=
tl
.
dot
(
q
,
k
.
to
(
q
.
dtype
))
qk
*=
sm_scale
if
BLOCK_DPE
>
0
:
offs_buf_kpe
=
(
if
logit_cap
>
0
:
kv_loc
[
None
,
:]
*
stride_buf_kbs
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
offs_o
=
cur_head
[:,
None
]
*
att_stride_h
+
(
)
cur_batch_in_all_start_index
+
offs_n
[
None
,
:]
kpe
=
tl
.
load
(
)
K_Buffer
+
offs_buf_kpe
,
mask
=
(
offs_n
[
None
,
:]
<
split_kv_end
)
&
(
mask_dpe
[:,
None
]),
tl
.
store
(
other
=
0.0
,
Att_Out
+
offs_o
,
)
qk
,
qk
+=
tl
.
dot
(
qpe
,
kpe
.
to
(
qpe
.
dtype
))
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
qk
*=
sm_scale
)
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
@
triton
.
jit
def
_fwd_grouped_kernel_stage2
(
qk
=
tl
.
where
(
logits
,
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_kv_end
),
qk
,
float
(
"-inf"
)
V_Buffer
,
)
Out
,
Req_to_tokens
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
stride_logic_h
,
stride_buf_vbs
,
stride_buf_vh
,
stride_obs
,
stride_oh
,
stride_req_to_token_b
,
kv_group_num
:
tl
.
constexpr
,
q_head_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
if
BLOCK_H
<
kv_group_num
:
VALID_BLOCK_H
:
tl
.
constexpr
=
BLOCK_H
else
:
VALID_BLOCK_H
:
tl
.
constexpr
=
kv_group_num
cur_head
=
cur_head_id
*
VALID_BLOCK_H
+
tl
.
arange
(
0
,
BLOCK_H
)
mask_h
=
cur_head
<
(
cur_head_id
+
1
)
*
VALID_BLOCK_H
mask_h
=
mask_h
&
(
cur_head
<
q_head_num
)
cur_batch_seq_len
=
tl
.
load
(
B_Seqlen
+
cur_batch
)
offs_buf_v
=
(
cur_batch_start_loc
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
kv_loc
[:,
None
]
*
stride_buf_vbs
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
(
offs_n
[:,
None
]
<
split_kv_end
)
&
(
mask_dv
[
None
,
:]),
other
=
0.0
,
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
acc
*=
re_scale
[:,
None
]
acc
+=
tl
.
dot
(
p
.
to
(
v
.
dtype
),
v
)
offs_buf_v
=
cur_kv_head
*
stride_buf_vh
+
offs_d
[
None
,
:]
e_sum
=
e_sum
*
re_scale
+
tl
.
sum
(
p
,
1
)
v_ptrs
=
V_Buffer
+
offs_buf_v
e_max
=
n_e_max
e_max
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
offs_mid_o
=
(
e_sum
=
tl
.
zeros
([
BLOCK_H
],
dtype
=
tl
.
float32
)
cur_batch
*
stride_mid_ob
acc
=
tl
.
zeros
([
BLOCK_H
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
+
cur_head
[:,
None
]
*
stride_mid_oh
+
split_kv_id
*
stride_mid_os
for
start_n
in
range
(
0
,
cur_batch_seq_len
,
BLOCK_N
):
+
offs_dv
[
None
,
:]
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
v_index
=
tl
.
load
(
Req_to_tokens
+
cur_batch_req_idx
*
stride_req_to_token_b
+
(
start_n
+
offs_n
),
mask
=
(
start_n
+
offs_n
)
<
cur_batch_seq_len
,
other
=
0
,
)
)
offs_qk
=
cur_head
[:,
None
]
*
stride_logic_h
+
(
tl
.
store
(
cur_batch_start_loc
+
start_n
+
offs_n
[
None
,
:]
Att_Out
+
offs_mid_o
,
acc
/
e_sum
[:,
None
],
mask
=
(
mask_h
[:,
None
])
&
(
mask_dv
[
None
,
:]),
)
)
qk
=
tl
.
load
(
offs_mid_o_1
=
(
logits
+
offs_qk
,
cur_batch
*
stride_mid_ob
mask
=
mask_h
[:,
None
]
&
(
start_n
+
offs_n
[
None
,
:]
<
cur_batch_seq_len
),
+
cur_head
*
stride_mid_oh
other
=
float
(
"-inf"
),
+
split_kv_id
*
stride_mid_os
+
Lv
)
)
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
tl
.
store
(
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
Att_Out
+
offs_mid_o_1
,
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
e_max
+
tl
.
log
(
e_sum
),
e_sum
=
e_sum
*
old_scale
+
tl
.
sum
(
p
,
1
)
mask
=
mask_h
,
v
=
tl
.
load
(
v_ptrs
+
v_index
[:,
None
]
*
stride_buf_vbs
,
mask
=
(
offs_d
[
None
,
:]
<
Lv
)
)
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
old_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
acc
=
acc
/
e_sum
[:,
None
]
off_o
=
cur_batch
*
stride_obs
+
cur_head
[:,
None
]
*
stride_oh
+
offs_d
[
None
,
:]
out_ptrs
=
Out
+
off_o
tl
.
store
(
out_ptrs
,
acc
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lv
))
def
_decode_grouped_att_m_fwd
(
def
_decode_grouped_att_m_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
att_out
,
att_out
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
):
):
BLOCK
=
64
BLOCK
=
32
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
if
Lk
==
576
:
if
Lk
==
576
:
BLOCK_DMODEL
=
512
BLOCK_DMODEL
=
512
...
@@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd(
...
@@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd(
else
:
else
:
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lk
)
BLOCK_DPE
=
0
BLOCK_DPE
=
0
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
BLOCK_H
=
16
SPLIT
_K
=
8
NUM_KV_
SPLIT
S
=
num_kv_splits
grid
=
(
grid
=
(
batch
,
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
SPLIT
_K
,
NUM_KV_
SPLIT
S
,
)
)
num_warps
=
4
extra_kargs
=
{}
extra_kargs
=
{}
if
is_hip_
:
if
is_hip_
:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
...
@@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd(
...
@@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd(
_fwd_grouped_kernel_stage1
[
grid
](
_fwd_grouped_kernel_stage1
[
grid
](
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
sm_scale
,
sm_scale
,
Req_to_tokens
,
Req_to_tokens
,
B_req_idx
,
B_req_idx
,
B_Start_Loc
,
B_Seqlen
,
B_Seqlen
,
att_out
,
att_out
,
Req_to_tokens
.
stride
(
0
),
Req_to_tokens
.
stride
(
0
),
...
@@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd(
...
@@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd(
q
.
stride
(
1
),
q
.
stride
(
1
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
1
),
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
att_out
.
stride
(
0
),
att_out
.
stride
(
0
),
att_out
.
stride
(
1
),
att_out
.
stride
(
2
),
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
q_head_num
=
head_num
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT
_K
,
NUM_KV_SPLITS
=
NUM_KV_
SPLIT
S
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_warps
=
4
,
num_stages
=
1
,
num_stages
=
2
,
Lk
=
Lk
,
Lk
=
Lk
,
Lv
=
Lv
,
**
extra_kargs
,
**
extra_kargs
,
)
)
def
_decode_grouped_softmax_reducev_fwd
(
@
triton
.
jit
logits
,
def
_fwd_kernel_stage2
(
v_buffer
,
Mid_O
,
o
,
O
,
req_to_tokens
,
stride_mid_ob
,
b_req_idx
,
stride_mid_oh
,
b_start_loc
,
stride_mid_os
,
b_seq_len
,
stride_obs
,
stride_oh
,
NUM_KV_SPLITS
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
):
):
BLOCK
=
128
cur_batch
=
tl
.
program_id
(
0
)
batch
,
head_num
=
b_seq_len
.
shape
[
0
],
logits
.
shape
[
0
]
cur_head
=
tl
.
program_id
(
1
)
kv_group_num
=
logits
.
shape
[
0
]
//
v_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
offs_d
=
tl
.
arange
(
0
,
BLOCK_DV
)
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
1
)
mask_d
=
offs_d
<
Lv
e_sum
=
0.0
e_max
=
-
float
(
"inf"
)
acc
=
tl
.
zeros
([
BLOCK_DV
],
dtype
=
tl
.
float32
)
offs_v
=
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
offs_d
offs_logic
=
cur_batch
*
stride_mid_ob
+
cur_head
*
stride_mid_oh
+
Lv
for
split_kv_id
in
range
(
0
,
NUM_KV_SPLITS
):
tv
=
tl
.
load
(
Mid_O
+
offs_v
+
split_kv_id
*
stride_mid_os
,
mask
=
mask_d
,
other
=
0.0
)
tlogic
=
tl
.
load
(
Mid_O
+
offs_logic
+
split_kv_id
*
stride_mid_os
)
n_e_max
=
tl
.
maximum
(
tlogic
,
e_max
)
num_warps
=
8
old_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
acc
*=
old_scale
exp_logic
=
tl
.
exp
(
tlogic
-
n_e_max
)
acc
+=
exp_logic
*
tv
e_sum
=
e_sum
*
old_scale
+
exp_logic
e_max
=
n_e_max
tl
.
store
(
O
+
cur_batch
*
stride_obs
+
cur_head
*
stride_oh
+
offs_d
,
acc
/
e_sum
,
mask
=
mask_d
,
)
def
_decode_softmax_reducev_fwd
(
logits
,
q
,
o
,
v_buffer
,
num_kv_splits
,
):
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lv
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
NUM_KV_SPLITS
=
num_kv_splits
extra_kargs
=
{}
extra_kargs
=
{}
if
is_hip_
:
if
is_hip_
:
...
@@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd(
...
@@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd(
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
extra_kargs
=
{
"waves_per_eu"
:
4
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
_fwd_grouped_kernel_stage2
[
grid
](
grid
=
(
batch
,
head_num
)
_fwd_kernel_stage2
[
grid
](
logits
,
logits
,
v_buffer
,
o
,
o
,
req_to_tokens
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
v_buffer
.
stride
(
0
),
logits
.
stride
(
1
),
v_buffer
.
stride
(
1
),
logits
.
stride
(
2
),
o
.
stride
(
0
),
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
1
),
req_to_tokens
.
stride
(
0
),
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
kv_group_num
=
kv_group_num
,
BLOCK_DV
=
BLOCK_DV
,
q_head_num
=
head_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
Lv
=
Lv
,
Lv
=
Lv
,
num_warps
=
num_warps
,
num_warps
=
4
,
num_stages
=
1
,
num_stages
=
2
,
**
extra_kargs
,
**
extra_kargs
,
)
)
...
@@ -597,34 +554,27 @@ def decode_attention_fwd_normal(
...
@@ -597,34 +554,27 @@ def decode_attention_fwd_normal(
o
,
o
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
_decode_att_m_fwd
(
_decode_att_m_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
attn_logits
,
attn_logits
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
_decode_softmax_reducev_fwd
(
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
num_kv_splits
)
attn_logits
,
v_buffer
,
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
)
def
decode_attention_fwd_grouped
(
def
decode_attention_fwd_grouped
(
...
@@ -634,34 +584,27 @@ def decode_attention_fwd_grouped(
...
@@ -634,34 +584,27 @@ def decode_attention_fwd_grouped(
o
,
o
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
_decode_grouped_att_m_fwd
(
_decode_grouped_att_m_fwd
(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
attn_logits
,
attn_logits
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
_decode_grouped_softmax_reducev_fwd
(
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
num_kv_splits
)
attn_logits
,
v_buffer
,
o
,
req_to_token
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
)
def
decode_attention_fwd
(
def
decode_attention_fwd
(
...
@@ -675,9 +618,11 @@ def decode_attention_fwd(
...
@@ -675,9 +618,11 @@ def decode_attention_fwd(
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
assert
num_kv_splits
==
attn_logits
.
shape
[
2
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
...
@@ -689,10 +634,10 @@ def decode_attention_fwd(
...
@@ -689,10 +634,10 @@ def decode_attention_fwd(
o
,
o
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
...
@@ -705,10 +650,10 @@ def decode_attention_fwd(
...
@@ -705,10 +650,10 @@ def decode_attention_fwd(
o
,
o
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
max_len_in_batch
,
max_len_in_batch
,
num_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
python/sglang/srt/server_args.py
View file @
7dc66fcb
...
@@ -141,6 +141,7 @@ class ServerArgs:
...
@@ -141,6 +141,7 @@ class ServerArgs:
enable_nan_detection
:
bool
=
False
enable_nan_detection
:
bool
=
False
enable_p2p_check
:
bool
=
False
enable_p2p_check
:
bool
=
False
triton_attention_reduce_in_fp32
:
bool
=
False
triton_attention_reduce_in_fp32
:
bool
=
False
triton_attention_num_kv_splits
:
int
=
8
num_continuous_decode_steps
:
int
=
1
num_continuous_decode_steps
:
int
=
1
delete_ckpt_after_loading
:
bool
=
False
delete_ckpt_after_loading
:
bool
=
False
...
@@ -753,6 +754,12 @@ class ServerArgs:
...
@@ -753,6 +754,12 @@ class ServerArgs:
help
=
"Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
help
=
"Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels."
,
"This only affects Triton attention kernels."
,
)
)
parser
.
add_argument
(
"--triton-attention-num-kv-splits"
,
type
=
int
,
default
=
ServerArgs
.
triton_attention_num_kv_splits
,
help
=
"The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num-continuous-decode-steps"
,
"--num-continuous-decode-steps"
,
type
=
int
,
type
=
int
,
...
...
test/srt/test_triton_attention_kernels.py
View file @
7dc66fcb
...
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
seq_len
=
10
# This represents the number of tokens already in the sequence
seq_len
=
10
# This represents the number of tokens already in the sequence
total_tokens
=
B
*
seq_len
total_tokens
=
B
*
seq_len
sm_scale
=
1.0
/
(
D
**
0.5
)
sm_scale
=
1.0
/
(
D
**
0.5
)
num_kv_splits
=
8
# q represents the new token being generated, one per batch
# q represents the new token being generated, one per batch
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
...
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
attn_logits
=
torch
.
empty
(
(
H_Q
,
total_tokens
),
(
B
,
H_Q
,
num_kv_splits
,
D
+
1
),
dtype
=
dtype
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
"cuda"
,
)
)
...
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
seq_len
,
seq_len
,
num_kv_splits
,
sm_scale
,
sm_scale
,
)
)
...
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
def
_test_grouped_decode_attention_once
(
self
,
B
,
H_Q
,
H_KV
,
D
,
D_V
):
def
_test_grouped_decode_attention_once
(
self
,
B
,
H_Q
,
H_KV
,
D
,
D_V
):
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
seq_len
=
1
0
# This represents the number of tokens already in the sequence
seq_len
=
1
28
# This represents the number of tokens already in the sequence
total_tokens
=
B
*
seq_len
total_tokens
=
B
*
seq_len
sm_scale
=
1.0
/
(
D
**
0.5
)
sm_scale
=
1.0
/
(
D
**
0.5
)
num_kv_splits
=
8
# q represents the new token being generated, one per batch
# q represents the new token being generated, one per batch
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
q
=
torch
.
randn
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
...
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
v_buffer
=
torch
.
randn
(
total_tokens
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
v_buffer
=
torch
.
randn
(
total_tokens
,
H_KV
,
D_V
,
dtype
=
dtype
,
device
=
"cuda"
)
# o will have the same shape as q
# o will have the same shape as q
o
=
torch
.
zeros
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros
(
B
,
H_Q
,
D
_V
,
dtype
=
dtype
,
device
=
"cuda"
)
o_grouped
=
torch
.
zeros
(
B
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o_grouped
=
torch
.
zeros
(
B
,
H_Q
,
D
_V
,
dtype
=
dtype
,
device
=
"cuda"
)
req_to_token
=
torch
.
arange
(
total_tokens
,
device
=
"cuda"
).
reshape
(
B
,
seq_len
)
req_to_token
=
torch
.
arange
(
total_tokens
,
device
=
"cuda"
).
reshape
(
B
,
seq_len
)
b_req_idx
=
torch
.
arange
(
B
,
device
=
"cuda"
)
b_req_idx
=
torch
.
arange
(
B
,
device
=
"cuda"
)
...
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
b_seq_len
=
torch
.
full
((
B
,),
seq_len
,
device
=
"cuda"
)
attn_logits
=
torch
.
empty
(
attn_logits
=
torch
.
empty
(
(
H_Q
,
total_tokens
),
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
dtype
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
device
=
"cuda"
,
)
)
...
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
o
,
o
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
,
seq_len
,
seq_len
,
num_kv_splits
,
sm_scale
,
sm_scale
,
)
)
attn_logits1
=
torch
.
empty
(
(
B
,
H_Q
,
num_kv_splits
,
D_V
+
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
)
decode_attention_fwd_grouped
(
decode_attention_fwd_grouped
(
q
,
q
,
k_buffer
,
k_buffer
,
...
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
...
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
o_grouped
,
o_grouped
,
req_to_token
,
req_to_token
,
b_req_idx
,
b_req_idx
,
b_start_loc
,
b_seq_len
,
b_seq_len
,
attn_logits
,
attn_logits
1
,
seq_len
,
seq_len
,
num_kv_splits
,
sm_scale
,
sm_scale
,
)
)
cos_sim
=
torch
.
nn
.
functional
.
cosine_similarity
(
cos_sim
=
torch
.
nn
.
functional
.
cosine_similarity
(
o
.
flatten
(),
o_grouped
.
flatten
(),
dim
=
0
o
.
flatten
(),
o_grouped
.
flatten
(),
dim
=
0
)
)
print
(
cos_sim
.
item
())
self
.
assertTrue
(
cos_sim
.
item
()
>
0.99
)
self
.
assertTrue
(
cos_sim
.
item
()
>
0.99
)
self
.
assertTrue
(
torch
.
allclose
(
o
,
o_grouped
,
atol
=
3e-2
))
self
.
assertTrue
(
torch
.
allclose
(
o
,
o_grouped
,
atol
=
3e-2
))
def
test_grouped_decode_attention
(
self
):
def
test_grouped_decode_attention
(
self
):
configs
=
[
configs
=
[
(
2
,
16
,
16
,
64
,
64
),
(
2
,
16
,
1
,
64
,
64
),
(
2
,
16
,
1
,
64
,
64
),
(
2
,
64
,
1
,
13
,
13
),
(
2
,
64
,
1
,
13
,
13
),
(
2
,
128
,
1
,
80
,
80
),
(
2
,
128
,
1
,
80
,
80
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment