sglang · Commits · 2f9bd0fa

Unverified commit 2f9bd0fa, authored Dec 14, 2024 by Ke Bao, committed by GitHub on Dec 14, 2024
Fix correctness issue for triton decoding kernel (#2479)
parent 5282a473

Showing 2 changed files with 30 additions and 18 deletions (+30 −18):

  python/sglang/srt/layers/attention/triton_ops/decode_attention.py  (+24 −14)
  test/srt/test_triton_attention_kernels.py  (+6 −4)
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -32,7 +32,7 @@ is_hip_ = is_hip()
 logger = logging.getLogger(__name__)

 # TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
-logger.warn(
+logger.warning(
     "The following error message 'operation scheduled before its operands' can be ignored."
 )
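A side note on this hunk: `Logger.warn` is a deprecated alias for `Logger.warning` in Python's standard `logging` module, so the rename here silences a deprecation warning without changing behavior. A minimal illustration:

```python
import logging

logger = logging.getLogger(__name__)
logger.warning("preferred spelling")  # Logger.warn is a deprecated alias of this
```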
@@ -474,6 +474,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
+    B_Seqlen,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -486,6 +487,8 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -497,19 +500,24 @@ def _fwd_kernel_stage2(
     offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

     for split_kv_id in range(0, NUM_KV_SPLITS):
-        tv = tl.load(
-            Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
-        )
-        tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
-        n_e_max = tl.maximum(tlogic, e_max)
-
-        old_scale = tl.exp(e_max - n_e_max)
-        acc *= old_scale
-        exp_logic = tl.exp(tlogic - n_e_max)
-        acc += exp_logic * tv
-
-        e_sum = e_sum * old_scale + exp_logic
-        e_max = n_e_max
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
+
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
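This hunk is the heart of the fix. Stage 2 merges the per-split partial attention results with an online log-sum-exp update, and previously it folded in every one of the NUM_KV_SPLITS slots even when a short sequence left some splits with no KV tokens, so whatever stale values sat in those Mid_O slots leaked into the output. Below is a minimal NumPy model of that merge, my own sketch rather than the Triton kernel itself: `merge_splits`, `mid_v`, `mid_logic`, and `guard_empty` are hypothetical names, with `mid_v`/`mid_logic` standing in for the value and log-sum-exp parts of `Mid_O`, and `guard_empty` toggling the new `split_kv_end > split_kv_start` check.

```python
import numpy as np

def merge_splits(mid_v, mid_logic, seq_len, num_kv_splits, guard_empty=True):
    """Online log-sum-exp merge of per-split partial attention results.

    mid_v:     (num_kv_splits, d_v) per-split partial outputs, normalized
               within their split (the value part of Mid_O).
    mid_logic: (num_kv_splits,) per-split log-sum-exp of the attention
               logits (the extra Lv slot of Mid_O).
    """
    e_max, e_sum = -np.inf, 0.0
    acc = np.zeros(mid_v.shape[1])
    kv_len_per_split = -(seq_len // -num_kv_splits)  # ceil division, like tl.cdiv
    for split_kv_id in range(num_kv_splits):
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = min(split_kv_start + kv_len_per_split, seq_len)
        if guard_empty and split_kv_end <= split_kv_start:
            continue  # split covers no KV tokens; its Mid_O slot is garbage
        tlogic = mid_logic[split_kv_id]
        n_e_max = max(tlogic, e_max)
        old_scale = np.exp(e_max - n_e_max)
        exp_logic = np.exp(tlogic - n_e_max)
        acc = acc * old_scale + exp_logic * mid_v[split_kv_id]
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum

# seq_len=5 with 8 splits leaves splits 5..7 empty; poison their slots to
# show why merging them unguarded corrupts the result.
rng = np.random.default_rng(0)
mid_v = rng.normal(size=(8, 4))
mid_logic = rng.normal(size=8)
mid_logic[5:] = 40.0  # simulated stale memory: huge log-sum-exp values
print(merge_splits(mid_v, mid_logic, 5, 8, guard_empty=False))  # dominated by garbage
print(merge_splits(mid_v, mid_logic, 5, 8, guard_empty=True))   # correct merge
```

In the unguarded run the fake log-sum-exp of 40 dwarfs the real splits, so the output is essentially the garbage slots; with the guard, only splits that actually cover KV tokens contribute, which is what the kernel change achieves.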
@@ -523,6 +531,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
+    b_seq_len,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -541,6 +550,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
+        b_seq_len,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -580,7 +590,7 @@ def decode_attention_fwd_normal(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


 def decode_attention_fwd_grouped(
@@ -608,7 +618,7 @@ def decode_attention_fwd_grouped(
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)


 def decode_attention_fwd(
test/srt/test_triton_attention_kernels.py
@@ -232,9 +232,9 @@ class TestTritonAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)

-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+    def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 128  # This represents the number of tokens already in the sequence
+        seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
         num_kv_splits = 8
@@ -300,6 +300,7 @@ class TestTritonAttention(unittest.TestCase):
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))

     def test_grouped_decode_attention(self):
+        seq_lens = [5, 100, 128, 500]
         configs = [
             (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
@@ -309,8 +310,9 @@ class TestTritonAttention(unittest.TestCase):
             (2, 128, 1, 576, 512),
         ]

         for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+            for S in seq_lens:
+                self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V)


 if __name__ == "__main__":
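Why the new `seq_lens` list includes 5: the test uses `num_kv_splits = 8`, and with the kernel's ceil-division split sizing a 5-token sequence leaves three splits covering no KV tokens, which is exactly the case the new guard in `_fwd_kernel_stage2` protects against. A quick self-contained check (plain Python mirroring `tl.cdiv`/`tl.minimum`; `split_ranges` is a hypothetical helper for illustration):

```python
import math

def split_ranges(seq_len: int, num_kv_splits: int):
    """[start, end) KV range covered by each split, as the kernel computes it."""
    kv_len_per_split = math.ceil(seq_len / num_kv_splits)  # like tl.cdiv
    return [
        (i * kv_len_per_split, min((i + 1) * kv_len_per_split, seq_len))
        for i in range(num_kv_splits)
    ]

print(split_ranges(5, 8))
# -> [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 5), (6, 5), (7, 5)]
# The last three ranges are empty (end <= start): stage 2 must skip them.
```

The previous fixed `seq_len = 128` never triggered empty splits, so the bug went unnoticed; sweeping sequence lengths both below and above `num_kv_splits` now exercises the regression.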