sglang · Commits · c77762d5

Unverified commit c77762d5, authored Oct 28, 2024 by Ke Bao and committed by GitHub Oct 27, 2024.

Fix Triton decode kernel & ut (#1819)

Parent: 51c81e33

Showing 5 changed files with 216 additions and 40 deletions (+216 / -40).
python/sglang/srt/layers/attention/triton_ops/decode_attention.py    +101  -30
python/sglang/srt/layers/attention/triton_ops/prefill_attention.py     +1   -1
test/srt/run_suite.py                                                   +2   -1
test/srt/test_triton_attention_backend.py                               +0   -0
test/srt/test_triton_attention_kernels.py                             +112   -8
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
     start_n = tl.program_id(2)

     reduce_dtype = Att_Out.dtype.element_ty

-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
     mask_h = mask_h & (cur_head < q_head_num)

     offs_d = tl.arange(0, BLOCK_DMODEL)
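The core change in this kernel (and in the stage-2 kernel below): axis 1 of the launch grid no longer enumerates KV heads directly; it enumerates blocks of at most BLOCK_H query heads, the owning KV head is recovered from the head-block id, and VALID_BLOCK_H masks off lanes outside the current block. A rough host-side sketch of that indexing in plain Python, with hypothetical shape values that are not taken from the commit:

import math

# Hypothetical MLA-style shapes: 128 query heads sharing one KV head.
q_head_num = 128
kv_group_num = 128                 # q_head_num // num_kv_heads
BLOCK_H = 64                       # capped by the host-side change further down

VALID_BLOCK_H = BLOCK_H if BLOCK_H < kv_group_num else kv_group_num
num_head_blocks = math.ceil(q_head_num / min(BLOCK_H, kv_group_num))  # grid axis 1

for cur_head_id in range(num_head_blocks):
    # Same arithmetic as the kernel: recover the KV head from the head-block id.
    cur_kv_head = cur_head_id // math.ceil(kv_group_num / BLOCK_H)
    heads = [cur_head_id * VALID_BLOCK_H + i for i in range(BLOCK_H)]
    # mask_h in the kernel: stay inside this head block and inside q_head_num.
    valid = [h for h in heads
             if h < (cur_head_id + 1) * VALID_BLOCK_H and h < q_head_num]
    print(f"head block {cur_head_id}: kv_head={cur_kv_head}, heads {valid[0]}..{valid[-1]}")

For the example above, the 128 query heads are processed as two head blocks of 64 that both map back to KV head 0, instead of one 128-wide block per KV head.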
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
     Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
     mask_h = mask_h & (cur_head < q_head_num)

     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
     batch, head_num = B_req_idx.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
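Capping BLOCK_H at 64 bounds how many query heads a single program instance handles, while the grid dimension computed from min(BLOCK_H, kv_group_num) keeps every head covered. A small sketch of the new host-side arithmetic, using pure-Python stand-ins for triton.next_power_of_2 and triton.cdiv and hypothetical head counts:

import math

def next_power_of_2(n):
    # Stand-in for triton.next_power_of_2
    return 1 << (n - 1).bit_length()

def launch_params(head_num, kv_group_num):
    # Mirrors the updated host logic in _decode_grouped_att_m_fwd
    BLOCK_H = max(16, min(64, next_power_of_2(kv_group_num)))
    head_blocks = math.ceil(head_num / min(BLOCK_H, kv_group_num))  # grid axis 1
    return BLOCK_H, head_blocks

# Hypothetical configurations: GQA (group of 4), MQA-like (group of 16), MLA-like (group of 128)
for head_num, kv_group_num in [(32, 4), (16, 16), (128, 128)]:
    print(head_num, kv_group_num, launch_params(head_num, kv_group_num))

Before the cap, a group size of 128 would have produced BLOCK_H = 128 in a single program; with the cap, the same heads are processed as two blocks of 64.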
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK = 128
     batch, head_num = b_seq_len.shape[0], logits.shape[0]
     kv_group_num = logits.shape[0] // v_buffer.shape[1]

-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
     num_warps = 8
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
     )


+def decode_attention_fwd_normal(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
+def decode_attention_fwd_grouped(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_grouped_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_grouped_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
@@ -585,47 +670,33 @@ def decode_attention_fwd(
     if kv_group_num == 1:
         # MHA
-        _decode_att_m_fwd(
+        decode_attention_fwd_normal(
             q,
             k_buffer,
-            attn_logits,
+            v_buffer,
+            o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
             max_len_in_batch,
             sm_scale,
             logit_cap,
         )
-        _decode_softmax_reducev_fwd(
-            attn_logits,
-            v_buffer,
-            o,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-        )
     else:
         # GQA/MQA/MLA
-        _decode_grouped_att_m_fwd(
+        decode_attention_fwd_grouped(
             q,
             k_buffer,
-            attn_logits,
+            v_buffer,
+            o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
             max_len_in_batch,
             sm_scale,
             logit_cap,
         )
-        _decode_grouped_softmax_reducev_fwd(
-            attn_logits,
-            v_buffer,
-            o,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-        )
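A hedged usage sketch of the public entry point after this change, mirroring the shapes the updated unit test constructs (illustrative values only; requires a CUDA device):

import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

B, H_Q, H_KV, D = 2, 16, 1, 64
seq_len = 10                      # tokens already in each sequence
total_tokens = B * seq_len
dtype = torch.bfloat16

q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")   # one new token per request
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
attn_logits = torch.empty((H_Q, total_tokens), dtype=dtype, device="cuda")

decode_attention_fwd(
    q, k_buffer, v_buffer, o, req_to_token, b_req_idx,
    b_start_loc, b_seq_len, attn_logits, seq_len, 1.0 / (D ** 0.5),
)

With H_Q // H_KV > 1 as above, the dispatcher takes the GQA/MQA/MLA branch and runs decode_attention_fwd_grouped; with H_Q == H_KV it would run decode_attention_fwd_normal instead.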
python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -168,7 +168,7 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
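The only change here tightens the block-size heuristic: by replacing >= 8 with > 8, compute capability 8.x devices (Ampere) now fall back to BLOCK = 64, and only major versions above 8 keep BLOCK = 128. Assuming CUDA_CAPABILITY is the (major, minor) pair from torch.cuda.get_device_capability(), which is not shown in this hunk, the behaviour is:

import torch

# Sketch only: reproduces the new block-size choice outside the kernel module.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    BLOCK = 128 if major > 8 else 64   # was: major >= 8
    print(f"sm_{major}{minor}: BLOCK = {BLOCK}")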
test/srt/run_suite.py
@@ -26,7 +26,8 @@ suites = {
"test_srt_endpoint.py"
,
"test_srt_endpoint.py"
,
"test_torch_compile.py"
,
"test_torch_compile.py"
,
"test_torchao.py"
,
"test_torchao.py"
,
"test_triton_attn_backend.py"
,
"test_triton_attention_kernels.py"
,
"test_triton_attention_backend.py"
,
"test_update_weights.py"
,
"test_update_weights.py"
,
"test_vision_openai_server.py"
,
"test_vision_openai_server.py"
,
],
],
test/srt/test_triton_attn_backend.py → test/srt/test_triton_attention_backend.py

File moved, no content changes.
test/srt/test_triton_attention_kernels.py
@@ -3,7 +3,11 @@ import unittest
 import torch

-from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
+from sglang.srt.layers.attention.triton_ops.decode_attention import (
+    decode_attention_fwd,
+    decode_attention_fwd_grouped,
+    decode_attention_fwd_normal,
+)
 from sglang.srt.layers.attention.triton_ops.extend_attention import (
     extend_attention_fwd,
     redundant_attention,
@@ -13,7 +17,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 )


-class TestExtendAttention(unittest.TestCase):
+class TestTritonAttention(unittest.TestCase):
     def _set_all_seeds(self, seed):
         """Set all random seeds for reproducibility."""
@@ -127,7 +131,7 @@ class TestExtendAttention(unittest.TestCase):
         for value in attention_values:
             self._test_extend_attention_once(19, 12331, 12, 4, value)

-    def _test_context_attention_once(self, head_dim):
+    def _test_context_attention_once(self, head_dim, is_causal):
         # Set up a simple test case
         num_heads = 4
         seq_lens = [8, 12]
@@ -143,15 +147,35 @@ class TestExtendAttention(unittest.TestCase):
         b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda")
         b_seq_len = torch.tensor(seq_lens, device="cuda")

-        context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len)
+        context_attention_fwd(
+            q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal
+        )
+
+        cu_seq_lens = [0] * (len(seq_lens) + 1)
+        for i, seq_len in enumerate(seq_lens):
+            cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len
+
+        for i in range(len(seq_lens)):
+            start, end = cu_seq_lens[i], cu_seq_lens[i + 1]
+            o_torch = torch.nn.functional.scaled_dot_product_attention(
+                q[start:end].permute(1, 0, 2),
+                k[start:end].permute(1, 0, 2),
+                v[start:end].permute(1, 0, 2),
+                is_causal=is_causal,
+            ).permute(1, 0, 2)
+
+            cos_sim = torch.nn.functional.cosine_similarity(
+                o[start:end].flatten(), o_torch.flatten(), dim=0
+            )
+            self.assertTrue(cos_sim.item() > 1 - (1e-5))
+            self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2))

     def test_context_attention(self):
-        # Here we just to ensure there is no error
-        # TODO: correctnesss test
         head_dim = [128, 96, 80, 13]

         for dim in head_dim:
-            self._test_context_attention_once(dim)
+            for is_causal in [True, False]:
+                self._test_context_attention_once(dim, is_causal)

     def _test_decode_attention_once(self, B, H_Q, H_KV, D):
         dtype = torch.bfloat16
@@ -174,6 +198,12 @@ class TestExtendAttention(unittest.TestCase):
         b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")

+        attn_logits = torch.empty(
+            (H_Q, total_tokens),
+            dtype=dtype,
+            device="cuda",
+        )
+
         decode_attention_fwd(
             q,
             k_buffer,
@@ -183,8 +213,8 @@ class TestExtendAttention(unittest.TestCase):
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
             seq_len,
-            total_tokens,
             sm_scale,
         )
@@ -203,6 +233,80 @@ class TestExtendAttention(unittest.TestCase):
         for B, H_Q, H_KV, D in configs:
             self._test_decode_attention_once(B, H_Q, H_KV, D)

+    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
+        dtype = torch.bfloat16
+        seq_len = 10  # This represents the number of tokens already in the sequence
+        total_tokens = B * seq_len
+        sm_scale = 1.0 / (D**0.5)
+
+        # q represents the new token being generated, one per batch
+        q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
+
+        # k_buffer and v_buffer represent all previous tokens
+        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
+        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
+
+        # o will have the same shape as q
+        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+
+        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
+        b_req_idx = torch.arange(B, device="cuda")
+        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
+        b_seq_len = torch.full((B,), seq_len, device="cuda")
+        attn_logits = torch.empty(
+            (H_Q, total_tokens),
+            dtype=dtype,
+            device="cuda",
+        )
+
+        decode_attention_fwd_normal(
+            q,
+            k_buffer,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            attn_logits,
+            seq_len,
+            sm_scale,
+        )
+
+        decode_attention_fwd_grouped(
+            q,
+            k_buffer,
+            v_buffer,
+            o_grouped,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            attn_logits,
+            seq_len,
+            sm_scale,
+        )
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            o.flatten(), o_grouped.flatten(), dim=0
+        )
+        self.assertTrue(cos_sim.item() > 0.99)
+        self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
+
+    def test_grouped_decode_attention(self):
+        configs = [
+            (2, 16, 1, 64, 64),
+            (2, 64, 1, 13, 13),
+            (2, 128, 1, 80, 80),
+            (2, 128, 2, 512, 512),
+            (2, 128, 1, 576, 512),
+        ]
+
+        for B, H_Q, H_KV, D, D_V in configs:
+            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V)
+

 if __name__ == "__main__":
     unittest.main()
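Since the module keeps its unittest.main() entry point, the updated kernel tests can be run directly on a CUDA machine with

    python3 test/srt/test_triton_attention_kernels.py

or picked up via test/srt/run_suite.py, which now lists the file.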