change / sglang · Commits · 7dc66fcb

Optimize Triton decoding kernel for long context (#2394)

Unverified commit 7dc66fcb, authored Dec 08, 2024 by Ke Bao, committed via GitHub on Dec 08, 2024.
Parent: 1f09e84b
Showing 4 changed files with 328 additions and 360 deletions (+328 / -360):
python/sglang/srt/layers/attention/triton_backend.py                 +13   -8
python/sglang/srt/layers/attention/triton_ops/decode_attention.py    +287  -342
python/sglang/srt/server_args.py                                     +7    -0
test/srt/test_triton_attention_kernels.py                            +21   -10

python/sglang/srt/layers/attention/triton_backend.py  (view file @ 7dc66fcb)
@@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             self.reduce_dtype = torch.float16
 
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+
         self.forward_metadata = None
 
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
@@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
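
Note on the new buffer: attn_logits is no longer a flat (num_head, total_num_tokens) score matrix in reduce_dtype. Its new shape points to a split-KV layout in which every (request, head, KV split) gets a partial output of length v_head_dim plus one trailing slot, presumably for that split's log-sum-exp. A minimal sketch of how such a buffer could be laid out and sliced; sizes and the slicing convention are illustrative assumptions, not taken from the kernel code:

import torch

# Illustrative sizes only; the real values come from the model and server args.
batch_size, num_head, num_kv_splits, v_head_dim = 4, 32, 8, 128

# One row per (request, head, KV split): v_head_dim partial-output values
# plus a trailing slot assumed to hold the split's log-sum-exp statistic.
attn_logits = torch.empty(
    (batch_size, num_head, num_kv_splits, v_head_dim + 1), dtype=torch.float32
)

partial_out = attn_logits[..., :v_head_dim]  # per-split weighted value sums
partial_lse = attn_logits[..., v_head_dim]   # per-split log-sum-exp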
@@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )
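
A side effect worth noting: the CUDA-graph buffer is now sized by max_bs, num_head, num_kv_splits, and v_head_dim instead of by cuda_graph_max_total_num_tokens, so its footprint no longer grows with context length. A back-of-the-envelope comparison with illustrative numbers (none of these sizes come from the diff; the old buffer used reduce_dtype, fp16 by default):

# Hypothetical sizes for illustration only.
max_bs, num_head, v_head_dim, num_kv_splits = 160, 32, 128, 8
max_total_num_tokens = 160 * 32768  # the old buffer scaled with cached tokens

old_elems = num_head * max_total_num_tokens                       # (num_head, total_tokens)
new_elems = max_bs * num_head * num_kv_splits * (v_head_dim + 1)  # (bs, head, splits, Dv + 1)

print(f"old fp16 buffer: {old_elems * 2 / 2**20:.1f} MiB")  # ~320 MiB
print(f"new fp32 buffer: {new_elems * 4 / 2**20:.1f} MiB")  # ~20 MiB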
@@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.seq_lens,
             attn_logits,
             max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )

python/sglang/srt/layers/attention/triton_ops/decode_attention.py  (view file @ 7dc66fcb)

This diff is collapsed in the rendered view (+287 additions, -342 deletions).
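
The collapsed file carries the bulk of the change: the decode kernels are reworked around the num_kv_splits parameter and the (batch, head, num_kv_splits, v_head_dim + 1) buffer above, which matches a flash-decoding style split-KV scheme: stage one attends to each KV chunk independently and stores a partial output plus a log-sum-exp per chunk, and stage two merges the chunks. Below is a rough, single-head PyTorch sketch of that scheme for intuition only; it is not the Triton implementation in this file:

import torch

def split_kv_decode(q, k, v, num_kv_splits, sm_scale):
    """Decode attention for one request and one head.

    q: (D,), k: (seq_len, D), v: (seq_len, D_V).
    """
    d_v = v.shape[-1]
    partial_out = torch.zeros(num_kv_splits, d_v)
    partial_lse = torch.full((num_kv_splits,), float("-inf"))

    # Stage 1: attend to each KV chunk independently (parallel in the kernel).
    for s, (ks, vs) in enumerate(zip(k.chunk(num_kv_splits), v.chunk(num_kv_splits))):
        scores = (ks @ q) * sm_scale                    # (chunk_len,)
        partial_out[s] = torch.softmax(scores, dim=0) @ vs
        partial_lse[s] = torch.logsumexp(scores, dim=0)

    # Stage 2: merge the chunks, weighting each by its share of the total mass.
    weights = torch.softmax(partial_lse, dim=0)         # exp(lse_s) / sum_s exp(lse_s)
    return weights @ partial_out                        # (D_V,)

# Sanity check against single-pass attention.
D, L = 64, 1024
q, k, v = torch.randn(D), torch.randn(L, D), torch.randn(L, D)
ref = torch.softmax((k @ q) * D**-0.5, dim=0) @ v
out = split_kv_decode(q, k, v, num_kv_splits=8, sm_scale=D**-0.5)
assert torch.allclose(out, ref, atol=1e-4)

The per-(request, head, split) rows written in stage one are consistent with the resized attn_logits buffer; the actual kernels additionally handle paged KV indexing via req_to_token and grouped query heads.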

python/sglang/srt/server_args.py  (view file @ 7dc66fcb)
@@ -141,6 +141,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
@@ -753,6 +754,12 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--triton-attention-num-kv-splits",
+            type=int,
+            default=ServerArgs.triton_attention_num_kv_splits,
+            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
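
The new flag exposes the split count for tuning on long-context workloads. A hypothetical launch command (the model path and the value 16 are placeholders; only the flag name and its default of 8 come from this diff, and selecting the backend via --attention-backend triton is assumed from the existing server options):

python -m sglang.launch_server --model-path <model> --attention-backend triton --triton-attention-num-kv-splits 16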

test/srt/test_triton_attention_kernels.py  (view file @ 7dc66fcb)
@@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D + 1),
+            dtype=torch.float32,
             device="cuda",
         )
@@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase):
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
@@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase):
     def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V):
         dtype = torch.bfloat16
-        seq_len = 10  # This represents the number of tokens already in the sequence
+        seq_len = 128  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
+        num_kv_splits = 8
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase):
         v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda")
 
         # o will have the same shape as q
-        o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
-        o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
+        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
+        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
@@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase):
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
-            (H_Q, total_tokens),
-            dtype=dtype,
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
             device="cuda",
         )
@@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
+        attn_logits1 = torch.empty(
+            (B, H_Q, num_kv_splits, D_V + 1),
+            dtype=torch.float32,
+            device="cuda",
+        )
+
         decode_attention_fwd_grouped(
             q,
             k_buffer,
@@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
-            attn_logits,
+            attn_logits1,
             seq_len,
+            num_kv_splits,
             sm_scale,
         )
 
         cos_sim = torch.nn.functional.cosine_similarity(
             o.flatten(), o_grouped.flatten(), dim=0
         )
+        print(cos_sim.item())
         self.assertTrue(cos_sim.item() > 0.99)
         self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2))
 
     def test_grouped_decode_attention(self):
         configs = [
+            (2, 16, 16, 64, 64),
             (2, 16, 1, 64, 64),
             (2, 64, 1, 13, 13),
             (2, 128, 1, 80, 80),
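
For orientation, the tensors set up in these tests describe a paged KV cache: k_buffer and v_buffer hold all cached tokens in one flat pool, req_to_token maps each (request, position) to a pool slot, and b_req_idx / b_seq_len select a request's rows. A naive PyTorch reference in that spirit is sketched below; it illustrates what the kernels are expected to compute (with H_Q // H_KV query heads sharing each KV head), and is not a helper defined in the test file:

import torch

def naive_grouped_decode_attention(
    q, k_buffer, v_buffer, req_to_token, b_req_idx, b_seq_len, sm_scale
):
    """q: (B, H_Q, D); k_buffer: (total_tokens, H_KV, D); v_buffer: (total_tokens, H_KV, D_V)."""
    B, H_Q, _ = q.shape
    H_KV, D_V = k_buffer.shape[1], v_buffer.shape[-1]
    group = H_Q // H_KV  # query heads per KV head (GQA); 1:1 when H_Q == H_KV
    o = torch.zeros(B, H_Q, D_V, dtype=q.dtype, device=q.device)

    for b in range(B):
        seq_len = int(b_seq_len[b])
        slots = req_to_token[b_req_idx[b], :seq_len]   # pool slots of this request
        k, v = k_buffer[slots], v_buffer[slots]        # (seq_len, H_KV, D) / (seq_len, H_KV, D_V)
        for h in range(H_Q):
            kv_h = h // group
            scores = (k[:, kv_h].float() @ q[b, h].float()) * sm_scale
            probs = torch.softmax(scores, dim=0)
            o[b, h] = (probs @ v[:, kv_h].float()).to(q.dtype)
    return o

The test itself compares the two Triton kernels against each other (cosine similarity > 0.99 and allclose with atol=3e-2) rather than against such a reference.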