Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c5f86501
"megatron/legacy/data/biencoder_dataset_utils.py" did not exist on "8eff2a996736d1632595b4420cac008c85e39c78"
Unverified
Commit
c5f86501
authored
Nov 23, 2024
by
Ke Bao
Committed by
GitHub
Nov 23, 2024
Browse files
Fix grid size in Triton decoding kernel (#2134)
parent
d98fa1e9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
38 deletions
+34
-38
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
...glang/srt/layers/attention/triton_ops/decode_attention.py
+34
-38
No files found.
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
View file @
c5f86501
...
...
@@ -50,12 +50,13 @@ def _fwd_kernel_stage1(
kv_group_num
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
s
tart_n
=
tl
.
program_id
(
2
)
s
plit_k_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
cur_kv_head
=
cur_head
//
kv_group_num
...
...
@@ -65,22 +66,18 @@ def _fwd_kernel_stage1(
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_start_index
=
0
cur_batch_end_index
=
cur_batch_seq_len
off_q
=
cur_batch
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
q
=
tl
.
load
(
Q
+
off_q
).
to
(
reduce_dtype
)
offs_n
=
start_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
block_stard_index
=
start_n
*
BLOCK_N
block_mask
=
tl
.
where
(
block_stard_index
<
cur_batch_seq_len
,
1
,
0
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
for
start_mark
in
range
(
0
,
block_mask
,
1
):
q
=
tl
.
load
(
Q
+
off_q
+
start_mark
).
to
(
reduce_dtype
)
offs_n_new
=
cur_batch_start_index
+
offs_n
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
_new
,
mask
=
offs_n
_new
<
cur_batch_end_index
,
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
)
offs_buf_k
=
(
...
...
@@ -90,7 +87,7 @@ def _fwd_kernel_stage1(
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
_new
[:,
None
]
<
cur_batch_end_index
)
&
(
offs_d
[
None
,
:]
<
Lk
),
mask
=
(
offs_n
[:,
None
]
<
split_k_end
)
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
,
).
to
(
reduce_dtype
)
att_value
=
tl
.
sum
(
q
[
None
,
:]
*
k
,
1
)
...
...
@@ -100,7 +97,7 @@ def _fwd_kernel_stage1(
att_value
=
logit_cap
*
tanh
(
att_value
/
logit_cap
)
off_o
=
cur_head
*
att_stride_h
+
(
cur_batch_in_all_start_index
+
offs_n
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
_new
<
cur_batch_end_index
)
tl
.
store
(
Att_Out
+
off_o
,
att_value
,
mask
=
offs_n
<
split_k_end
)
@
triton
.
jit
...
...
@@ -189,11 +186,12 @@ def _decode_att_m_fwd(
logit_cap
,
):
BLOCK
=
32
SPLIT_K
=
8
Lk
=
k_buffer
.
shape
[
-
1
]
batch
,
head_num
=
B_req_idx
.
shape
[
0
],
q
.
shape
[
1
]
grid
=
(
batch
,
head_num
,
triton
.
cdiv
(
max_len_in_batch
,
BLOC
K
)
)
grid
=
(
batch
,
head_num
,
SPLIT_
K
)
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
if
kv_group_num
==
1
:
...
...
@@ -221,6 +219,7 @@ def _decode_att_m_fwd(
kv_group_num
=
kv_group_num
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_N
=
BLOCK
,
SPLIT_K
=
SPLIT_K
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
1
,
...
...
@@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1(
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_H
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
Lk
:
tl
.
constexpr
,
):
cur_batch
=
tl
.
program_id
(
0
)
cur_head_id
=
tl
.
program_id
(
1
)
cur_kv_head
=
cur_head_id
//
tl
.
cdiv
(
kv_group_num
,
BLOCK_H
)
s
tart_n
=
tl
.
program_id
(
2
)
s
plit_k_id
=
tl
.
program_id
(
2
)
reduce_dtype
=
Att_Out
.
dtype
.
element_ty
...
...
@@ -315,30 +315,27 @@ def _fwd_grouped_kernel_stage1(
cur_batch_in_all_start_index
=
tl
.
load
(
B_Start_Loc
+
cur_batch
)
cur_batch_req_idx
=
tl
.
load
(
B_req_idx
+
cur_batch
)
cur_batch_start_index
=
0
cur_batch_end_index
=
cur_batch_seq_len
offs_q
=
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
q
=
tl
.
load
(
Q
+
offs_q
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
),
other
=
0.0
).
to
(
reduce_dtype
)
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
off_qpe
=
(
cur_batch
*
stride_qbs
+
cur_head
[:,
None
]
*
stride_qh
+
offs_dpe
[
None
,
:]
)
qpe
=
tl
.
load
(
Q
+
off_qpe
,
mask
=
mask_h
[:,
None
],
other
=
0.0
).
to
(
reduce_dtype
)
offs_n
=
start_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
kv_len_per_split
=
tl
.
cdiv
(
cur_batch_seq_len
,
SPLIT_K
)
split_k_start
=
kv_len_per_split
*
split_k_id
split_k_end
=
tl
.
minimum
(
split_k_start
+
kv_len_per_split
,
cur_batch_seq_len
)
block_stard_index
=
start_n
*
BLOCK_N
block_mask
=
tl
.
where
(
block_stard_index
<
cur_batch_seq_len
,
1
,
0
)
for
start_mark
in
range
(
0
,
block_mask
,
1
):
q
=
tl
.
load
(
Q
+
offs_q
+
start_mark
,
mask
=
(
mask_h
[:,
None
])
&
(
offs_d
[
None
,
:]
<
Lk
)
).
to
(
reduce_dtype
)
offs_n_new
=
cur_batch_start_index
+
offs_n
for
start_n
in
range
(
split_k_start
,
split_k_end
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
k_loc
=
tl
.
load
(
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
_new
,
mask
=
offs_n
_new
<
cur_batch_end_index
,
Req_to_tokens
+
stride_req_to_tokens_b
*
cur_batch_req_idx
+
offs_n
,
mask
=
offs_n
<
split_k_end
,
other
=
0
,
)
offs_buf_k
=
(
...
...
@@ -348,14 +345,11 @@ def _fwd_grouped_kernel_stage1(
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
)
&
(
offs_d
[:,
None
]
<
Lk
),
mask
=
(
offs_n
[
None
,
:]
<
split_k_end
)
&
(
offs_d
[:,
None
]
<
Lk
),
other
=
0.0
,
).
to
(
reduce_dtype
)
qk
=
tl
.
dot
(
q
,
k
)
if
BLOCK_DPE
>
0
:
qpe
=
tl
.
load
(
Q
+
off_qpe
+
start_mark
,
mask
=
mask_h
[:,
None
]).
to
(
reduce_dtype
)
offs_buf_kpe
=
(
k_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
...
...
@@ -363,7 +357,7 @@ def _fwd_grouped_kernel_stage1(
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_buf_kpe
,
mask
=
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
,
mask
=
offs_n
[
None
,
:]
<
split_k_end
,
other
=
0.0
,
).
to
(
reduce_dtype
)
qk
+=
tl
.
dot
(
qpe
,
kpe
)
...
...
@@ -379,7 +373,7 @@ def _fwd_grouped_kernel_stage1(
tl
.
store
(
Att_Out
+
offs_o
,
qk
,
mask
=
mask_h
[:,
None
]
&
(
offs_n
_new
[
None
,
:]
<
cur_batch_end_index
),
mask
=
mask_h
[:,
None
]
&
(
offs_n
[
None
,
:]
<
split_k_end
),
)
...
...
@@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd(
kv_group_num
=
q
.
shape
[
1
]
//
k_buffer
.
shape
[
1
]
BLOCK_H
=
max
(
16
,
min
(
64
,
triton
.
next_power_of_2
(
kv_group_num
)))
SPLIT_K
=
8
grid
=
(
batch
,
triton
.
cdiv
(
head_num
,
min
(
BLOCK_H
,
kv_group_num
)),
triton
.
cdiv
(
max_len_in_batch
,
BLOCK
)
,
SPLIT_K
,
)
num_warps
=
4
...
...
@@ -532,6 +527,7 @@ def _decode_grouped_att_m_fwd(
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_N
=
BLOCK
,
BLOCK_H
=
BLOCK_H
,
SPLIT_K
=
SPLIT_K
,
logit_cap
=
logit_cap
,
num_warps
=
num_warps
,
num_stages
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment