Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c5f86501
Unverified
Commit
c5f86501
authored
Nov 23, 2024
by
Ke Bao
Committed by
GitHub
Nov 23, 2024
Browse files
Fix grid size in Triton decoding kernel (#2134)
parent
d98fa1e9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
38 deletions
+34
-38
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
...glang/srt/layers/attention/triton_ops/decode_attention.py
+34
-38
No files found.
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
View file @
c5f86501
...
@@ -50,12 +50,13 @@ def _fwd_kernel_stage1(
...
@@ -50,12 +50,13 @@ def _fwd_kernel_stage1(
kv_group_num
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_head
=
tl
.
program_id
(
1
)
s
tart_n
=
tl
.
program_id
(
2
)
s
plit_k_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
cur_kv_head
=
cur_head
//
kv_group_num
cur_kv_head
=
cur_head
//
kv_group_num
...
@@ -65,22 +66,18 @@ def _fwd_kernel_stage1(
...
@@ -65,22 +66,18 @@ def _fwd_kernel_stage1(
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_start_index
=
0
cur_batch_end_index
=
cur_batch_seq_len
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
q
=
tl
.
load
(
Q
+
off_q
).
to
(
reduce_dtype
)
offs_n
=
start_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
block_stard_index
=
start_n
*
BLOCK_N
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
block_mask
=
tl
.
where
(
block_stard_index
<
cur_batch_seq_len
,
1
,
0
)
for
start_mark
in
range
(
0
,
block_mask
,
1
):
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
q
=
tl
.
load
(
Q
+
off_q
+
start_mark
).
to
(
reduce_dtype
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
offs_n_new
=
cur_batch_start_index
+
offs_n
k_loc
=
tl
.
load
(
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
_new
,
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
_new
<
cur_batch_end_index
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
other
=
0
,
)
)
offs_buf_k
=
(
offs_buf_k
=
(
...
@@ -90,7 +87,7 @@ def _fwd_kernel_stage1(
...
@@ -90,7 +87,7 @@ def _fwd_kernel_stage1(
)
)
k
=
tl
.
load
(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
_new
[:,
None
]
<
cur_batch_end_index
)
&
(
offs_d
[
None
,
:]
<
Lk
),
mask
=
(
offs_n
[:,
None
]
<
split_k_end
)
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
,
other
=
0.0
,
).
to
(
reduce_dtype
)
).
to
(
reduce_dtype
)
att_value
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
att_value
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
...
@@ -100,7 +97,7 @@ def _fwd_kernel_stage1(
...
@@ -100,7 +97,7 @@ def _fwd_kernel_stage1(
att_value
=
logit_cap
*
tanh
(
att_value
/
logit_cap
)
att_value
=
logit_cap
*
tanh
(
att_value
/
logit_cap
)
off_o
=
cur_head
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
)
off_o
=
cur_head
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
_new
<
cur_batch_end_index
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
<
split_k_end
)
@
triton
.
jit
@
triton
.
jit
...
@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
...
@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
logit_cap
,
logit_cap
,
):
):
BLOCK
=
32
BLOCK
=
32
SPLIT_K
=
8
Lk
=
k_buffer
.
shape
[
-
1
]
Lk
=
k_buffer
.
shape
[
-
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
grid
=
(
batch
,
head_num
,
triton
.
cdiv
(
max_len_in_batch
,
BLOC
K
)
)
grid
=
(
batch
,
head_num
,
SPLIT_
K
)
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
if
kv_group_num
==
1
:
...
@@ -221,6 +219,7 @@ def _decode_att_m_fwd(
...
@@ -221,6 +219,7 @@ def _decode_att_m_fwd(
kv_group_num
=
kv_group_num
,
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SPLIT_K
=
SPLIT_K
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
...
@@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1(
...
@@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1(
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
):
):
cur_batch
=
tl
.
program_id
(
0
)
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
s
tart_n
=
tl
.
program_id
(
2
)
s
plit_k_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
...
@@ -315,30 +315,27 @@ def _fwd_grouped_kernel_stage1(
...
@@ -315,30 +315,27 @@ def _fwd_grouped_kernel_stage1(
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_start_index
=
0
cur_batch_end_index
=
cur_batch_seq_len
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
q
=
tl
.
load
(
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
).
to
(
reduce_dtype
)
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
off_qpe
=
(
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
mask_h
[:,
None
],
other
=
0.0
).
to
(
reduce_dtype
)
offs_n
=
start_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
block_stard_index
=
start_n
*
BLOCK_N
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
block_mask
=
tl
.
where
(
block_stard_index
<
cur_batch_seq_len
,
1
,
0
)
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
for
start_mark
in
range
(
0
,
block_mask
,
1
):
q
=
tl
.
load
(
Q
+
offs_q
+
start_mark
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
)
).
to
(
reduce_dtype
)
offs_n_new
=
cur_batch_start_index
+
offs_n
k_loc
=
tl
.
load
(
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
_new
,
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
_new
<
cur_batch_end_index
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
other
=
0
,
)
)
offs_buf_k
=
(
offs_buf_k
=
(
...
@@ -348,14 +345,11 @@ def _fwd_grouped_kernel_stage1(
...
@@ -348,14 +345,11 @@ def _fwd_grouped_kernel_stage1(
)
)
k
=
tl
.
load
(
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
)
&
(
offs_d
[:,
None
]
<
Lk
),
mask
=
(
offs_n
[
None
,
:]
<
split_k_end
)
&
(
offs_d
[:,
None
]
<
Lk
),
other
=
0.0
,
other
=
0.0
,
).
to
(
reduce_dtype
)
).
to
(
reduce_dtype
)
qk
=
tl
.
dot
(
q
,
k
)
qk
=
tl
.
dot
(
q
,
k
)
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
qpe
=
tl
.
load
(
Q
+
off_qpe
+
start_mark
,
mask
=
mask_h
[:,
None
]).
to
(
reduce_dtype
)
offs_buf_kpe
=
(
offs_buf_kpe
=
(
k_loc
[
None
,
:]
*
stride_buf_kbs
k_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
...
@@ -363,7 +357,7 @@ def _fwd_grouped_kernel_stage1(
...
@@ -363,7 +357,7 @@ def _fwd_grouped_kernel_stage1(
)
)
kpe
=
tl
.
load
(
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
K_Buffer
+
offs_buf_kpe
,
mask
=
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
,
mask
=
offs_n
[
None
,
:]
<
split_k_end
,
other
=
0.0
,
other
=
0.0
,
).
to
(
reduce_dtype
)
).
to
(
reduce_dtype
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
...
@@ -379,7 +373,7 @@ def _fwd_grouped_kernel_stage1(
...
@@ -379,7 +373,7 @@ def _fwd_grouped_kernel_stage1(
tl
.
store
(
tl
.
store
(
Att_Out
+
offs_o
,
Att_Out
+
offs_o
,
qk
,
qk
,
mask
=
mask_h
[:,
None
]
&
(
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
),
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
)
)
...
@@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd(
...
@@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd(
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
SPLIT_K
=
8
grid
=
(
grid
=
(
batch
,
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
triton
.
cdiv
(
max_len_in_batch
,
BLOCK
)
,
SPLIT_K
,
)
)
num_warps
=
4
num_warps
=
4
...
@@ -532,6 +527,7 @@ def _decode_grouped_att_m_fwd(
...
@@ -532,6 +527,7 @@ def _decode_grouped_att_m_fwd(
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT_K
,
logit_cap
=
logit_cap
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment