sglang / Commits / 3efa7981

Unverified commit 3efa7981, authored Sep 12, 2024 by Lianmin Zheng, committed by GitHub on Sep 12, 2024
Parent: 2a71be5e

    Support cuda graph in the triton attention backend (#1401)

Showing 6 changed files with 147 additions and 60 deletions (+147, -60)
python/sglang/srt/layers/attention_backend.py                   +95  -13
python/sglang/srt/layers/flashinfer_utils.py                    +10  -10
python/sglang/srt/layers/triton_attention/decode_attention.py   +19  -25
python/sglang/srt/model_executor/cuda_graph_runner.py           +13   -6
python/sglang/srt/model_executor/model_runner.py                 +0   -6
test/srt/test_serving_throughput.py                             +10   -0
python/sglang/srt/layers/attention_backend.py

@@ -36,14 +36,41 @@ class AttentionBackend(ABC):

     def init_forward_metadata(self, batch: ScheduleBatch, input_metadata: InputMetadata):
-        pass
+        """Init the metadata for a forward pass."""
+        raise NotImplementedError()

-    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
+    def init_cuda_graph_state(self, max_bs: int):
+        """Init the global shared states for cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_capture_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
+        """Init the metadata for a forward pass for capturing a cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_replay_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
+        """Init the metadata for a forward pass for replying a cuda graph."""
+        raise NotImplementedError()
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        raise NotImplementedError()
+
+    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run forward on an attention layer."""
         if input_metadata.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, input_metadata)
         else:
             return self.forward_extend(q, k, v, layer, input_metadata)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         raise NotImplementedError()

     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         raise NotImplementedError()


 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
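The hooks above define a three-step contract for any cuda-graph-capable backend: allocate max-sized shared buffers once (init_cuda_graph_state), point forward_metadata at those static buffers while the graph is captured, and refresh only the buffer contents in place before each replay. A minimal standalone sketch of that contract in plain PyTorch; ToyBackend and graph_start_loc are invented names for illustration, not sglang code:

    import torch


    class ToyBackend:
        """Illustrative only: mirrors the shape of the new cuda-graph hooks."""

        def init_cuda_graph_state(self, max_bs: int):
            # One shared buffer, sized for the largest batch that will ever be captured.
            self.graph_start_loc = torch.zeros(max_bs, dtype=torch.int32)

        def init_forward_metadata_capture_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
            # During capture, kernels must read from the pre-allocated buffer.
            self.forward_metadata = (self.graph_start_loc,)

        def init_forward_metadata_replay_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
            # Before replay, mutate the buffer contents in place; never rebind the tensor.
            self.graph_start_loc.zero_()
            self.graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

        def get_cuda_graph_seq_len_fill_value(self):
            # Padding value for seq_lens entries beyond the real batch size.
            return 1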
@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.cuda_graph_kv_indices.clone(),
         ]

-    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_capture_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
         if self.model_runner.sliding_window_size is None:
             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self.workspace_buffer,
@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = (False, None, decode_wrapper)

-    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_replay_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.cuda_graph_metadata[bs],
         )

+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         if not isinstance(self.prefill_wrapper_paged, list):
             prefill_wrapper_paged = self.prefill_wrapper_paged
@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.triton_attention.decode_attention import (
+            REDUCE_TORCH_TYPE,
             decode_attention_fwd,
         )
         from sglang.srt.layers.triton_attention.extend_attention import (
@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
+        self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
+        self.num_head = model_runner.model_config.num_attention_heads

         self.forward_metadata = None

+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

     def init_forward_metadata(self, batch: ScheduleBatch, input_metadata: InputMetadata):
         """Init auxiliary variables for triton attention backend."""

         if input_metadata.forward_mode.is_decode():
-            max_seq_len = torch.max(input_metadata.seq_lens).item()
             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)

             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.REDUCE_TORCH_TYPE,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(input_metadata.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = max_seq_len = total_num_tokens = None
+            start_loc = attn_logits = max_seq_len = None
             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

-        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (self.num_head, self.cuda_graph_max_total_num_tokens),
+            dtype=self.REDUCE_TORCH_TYPE,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(self, bs: int, req_pool_indices, seq_lens):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
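The decode branch above boils down to an exclusive prefix sum over seq_lens plus a per-head scratch buffer covering every cached token. A self-contained illustration with toy values (not sglang code; float32 stands in for REDUCE_TORCH_TYPE, and no GPU is needed):

    import torch

    seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)    # toy decode batch
    num_head = 4                                              # toy attention head count

    start_loc = torch.zeros_like(seq_lens)
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)        # exclusive prefix sum -> [0, 3, 8]

    total_num_tokens = int(seq_lens.sum())                    # 10 cached tokens in total
    attn_logits = torch.empty((num_head, total_num_tokens))   # stage-1 attention scratch buffer

    print(start_loc.tolist(), tuple(attn_logits.shape))       # [0, 3, 8] (4, 10)

The cuda-graph variants above pre-size the same two buffers at max_bs and max_bs * context_len, so their addresses never change between replays.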
@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.layer_id, input_metadata.out_cache_loc, k, v
         )

-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
         )
         return o

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         input_metadata.token_to_kv_pool.set_kv_buffer(
             layer.layer_id, input_metadata.out_cache_loc, k, v
@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
             input_metadata.req_pool_indices,
             start_loc,
             input_metadata.seq_lens,
+            attn_logits,
             max_seq_len,
-            total_num_tokens,
             layer.scaling,
             layer.logit_cap,
         )
         return o
python/sglang/srt/layers/flashinfer_utils.py

@@ -66,18 +66,18 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)

-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
+        self.decode_wrapper = (
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        )
+        self.prefill_wrapper_ragged = (
+            self.model_runner.attn_backend.prefill_wrapper_ragged
+        )
+        self.prefill_wrapper_paged = (
+            self.model_runner.attn_backend.prefill_wrapper_paged
         )

-        (
-            self.decode_wrapper,
-            self.prefill_wrapper_ragged,
-            self.prefill_wrapper_paged,
-        ) = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
-            self.model_runner.attn_backend.prefill_wrapper_ragged,
-            self.model_runner.attn_backend.prefill_wrapper_paged,
-        )
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )

     def _init_indices_no_sliding_window(self):
python/sglang/srt/layers/triton_attention/decode_attention.py

@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(

 @triton.jit
 def _fwd_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,

@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
     )

     qk = tl.load(
-        Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n),
+        logits + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n),
         mask=start_n + offs_n < cur_batch_seq_len,

@@ -238,7 +238,7 @@ def _decode_att_m_fwd(

 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,

@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]

     num_warps = 1

@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)

     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),

@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(

 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,

@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
     )

     qk = tl.load(
-        Logics + offs_qk,
+        logits + offs_qk,
         mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
         other=float("-inf"),
     )

@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(

 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,

@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]

     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)

     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -584,17 +584,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
     logit_cap=0.0,
-    att_m=None,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]

     if kv_group_num == 1:
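With total_num_tokens and the internal att_m allocation gone, the caller owns the intermediate buffer and can pass the kernel the same pre-allocated tensor on every decoding step, which is what the cuda-graph path relies on. A toy sketch of that ownership change (fake_decode_step is a hypothetical stand-in for decode_attention_fwd, not the real kernel):

    import torch

    num_head, max_total_tokens = 4, 1024
    attn_logits = torch.empty((num_head, max_total_tokens))    # allocated once by the caller


    def fake_decode_step(attn_logits, total_tokens):
        # Stand-in for decode_attention_fwd: it only writes into the buffer it is given.
        attn_logits[:, :total_tokens].fill_(0.0)


    for total_tokens in (7, 8, 9):                  # live token count grows as decoding proceeds
        fake_decode_step(attn_logits, total_tokens)  # no per-step allocation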
@@ -602,7 +596,7 @@ def decode_attention_fwd(
         _decode_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,

@@ -612,7 +606,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,

@@ -625,7 +619,7 @@ def decode_attention_fwd(
         _decode_grouped_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,

@@ -635,7 +629,7 @@ def decode_attention_fwd(
         )
         _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
python/sglang/srt/model_executor/cuda_graph_runner.py

+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team

 Licensed under the Apache License, Version 2.0 (the "License");

@@ -17,13 +19,12 @@ limitations under the License.
 import bisect
 from contextlib import contextmanager
-from typing import Callable
+from typing import TYPE_CHECKING, Callable

 import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,

@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+

 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():

@@ -111,7 +115,7 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )

@@ -121,6 +125,9 @@ class CudaGraphRunner:
         # Attention backend
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )

         # Sampling info
         vocab_size = model_runner.model_config.vocab_size

@@ -176,7 +183,7 @@ class CudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:bs]

         # Attention backend
-        self.model_runner.attn_backend.capture_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs, req_pool_indices, seq_lens
         )

@@ -227,7 +234,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

@@ -239,7 +246,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

         # Attention backend
-        self.model_runner.attn_backend.replay_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs, self.req_pool_indices, self.seq_lens
         )
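The renamed hooks slot into the usual capture/replay pattern: record the kernels once against static tensors, then mutate only the contents of those tensors before each replay. A standalone plain-PyTorch sketch of that pattern (needs a CUDA device; not sglang code):

    import torch

    if torch.cuda.is_available():
        static_x = torch.zeros(8, device="cuda")
        static_y = torch.zeros(8, device="cuda")

        # Warm up on a side stream before capture, as the PyTorch docs recommend.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            torch.mul(static_x, 2, out=static_y)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            torch.mul(static_x, 2, out=static_y)

        # Replay: refresh inputs in place (the role of init_forward_metadata_replay_cuda_graph),
        # then rerun the recorded kernels without relaunch overhead.
        static_x.fill_(3.0)
        g.replay()
        assert torch.equal(static_y, torch.full((8,), 6.0, device="cuda"))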
python/sglang/srt/model_executor/model_runner.py

@@ -445,12 +445,6 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return

-        if self.server_args.attention_backend != "flashinfer":
-            logger.warning(
-                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
-            )
-            return
-
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
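With the flashinfer-only guard removed, a server configured for the triton backend should now build a CudaGraphRunner by default. A hedged configuration sketch mirroring the fields the new test exercises (the import path and the model path are assumptions, not taken from this diff):

    from sglang.srt.server_args import ServerArgs  # assumed import path

    args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        attention_backend="triton",
        chunked_prefill_size=-1,
        # disable_cuda_graph stays at its default, so cuda graphs are captured
        # for the triton backend as well after this change.
    )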
test/srt/test_serving_throughput.py

@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase):
         if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
             assert res["output_throughput"] > 2400

+    def test_default_with_triton_attention_backend(self):
+        res = self.run_test(
+            disable_radix_cache=ServerArgs.disable_radix_cache,
+            attention_backend="triton",
+            chunked_prefill_size=-1,
+        )
+
+        if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
+            assert res["output_throughput"] > 2400
+

 if __name__ == "__main__":
     unittest.main()