sglang / Commits / fec185ce

Commit fec185ce (unverified)
Authored Sep 11, 2024 by Lianmin Zheng; committed via GitHub on Sep 11, 2024
Parent: c03cece4

Refactor attention backend (#1381)

Showing 16 changed files with 568 additions and 564 deletions (+568, -564)
python/sglang/srt/layers/attention_backend.py                   +383    -0
python/sglang/srt/layers/flashinfer_utils.py                     +35   -37
python/sglang/srt/layers/radix_attention.py                       +7  -168
python/sglang/srt/layers/triton_attention/decode_attention.py     +3    -4
python/sglang/srt/layers/triton_attention/extend_attention.py    +12   -19
python/sglang/srt/managers/schedule_batch.py                      +1    -1
python/sglang/srt/model_executor/cuda_graph_runner.py            +46  -108
python/sglang/srt/model_executor/forward_batch_info.py           +26  -108
python/sglang/srt/model_executor/model_runner.py                 +28   -97
python/sglang/srt/sampling/sampling_batch_info.py                 +3    -5
python/sglang/srt/server.py                                       +5    -5
python/sglang/srt/server_args.py                                 +13    -2
test/srt/test_create_kvindices.py                                 +1    -1
test/srt/test_moe_serving_throughput.py                           +2    -1
test/srt/test_serving_throughput.py                               +2    -1
test/srt/test_triton_attention_kernels.py                         +1    -7
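At a glance, the refactor pulls the FlashInfer- and Triton-specific code paths out of RadixAttention, InputMetadata, ModelRunner, and CudaGraphRunner and concentrates them in a new pluggable AttentionBackend object that ModelRunner owns and InputMetadata carries. The dispatch path, condensed from the diffs below for orientation (the surrounding classes are elided; nothing here is new behavior):

class AttentionBackend:
    def forward(self, q, k, v, layer, input_metadata):
        # Route on the forward mode: decode vs. extend (prefill with cached prefix).
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)

# RadixAttention.forward() now ends with a single call into the backend:
#     return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)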
python/sglang/srt/layers/attention_backend.py (new file, 0 → 100644)

from __future__ import annotations

"""
Support different attention backends.
Now there are two backends: FlashInfer and Triton.
FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class AttentionBackend(ABC):
    """The base class of attention backends"""

    @abstractmethod
    def init_forward_metadata(self, batch: ScheduleBatch, input_metadata: InputMetadata):
        pass

    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
        if input_metadata.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, input_metadata)
        else:
            return self.forward_extend(q, k, v, layer, input_metadata)


class FlashInferAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.model_runner = model_runner

        if not _grouped_size_compiled_for_decode_kernels(
            model_runner.model_config.num_attention_heads // model_runner.tp_size,
            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
        ):
            self.decode_use_tensor_cores = True
        else:
            self.decode_use_tensor_cores = False

        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device="cuda",
        )

        if model_runner.sliding_window_size is None:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_tensor_cores=self.decode_use_tensor_cores,
            )
        else:
            # Two wrappers: one for sliding window attention and one for full attention.
            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
            self.prefill_wrapper_ragged = None
            self.prefill_wrapper_paged = []
            self.decode_wrapper = []
            for _ in range(2):
                self.prefill_wrapper_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
                self.decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_tensor_cores=self.decode_use_tensor_cores,
                    )
                )

        self.forward_metadata = None
        self.cuda_graph_metadata = {}

    def init_forward_metadata(self, batch: ScheduleBatch, input_metadata: InputMetadata):
        if input_metadata.forward_mode.is_decode():
            prefix_lens = None
            use_ragged = False
            total_num_tokens = None
        else:
            prefix_lens = input_metadata.extend_prefix_lens

            # Some heuristics to check whether to use ragged forward
            use_ragged = False
            if (
                int(torch.sum(input_metadata.seq_lens)) > 4096
                and self.model_runner.sliding_window_size is None
            ):
                use_ragged = True

            total_num_tokens = torch.sum(input_metadata.seq_lens).item()

        update_flashinfer_indices(
            input_metadata.forward_mode,
            self.model_runner,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            prefix_lens,
            use_ragged=use_ragged,
        )

        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_kv_indices = torch.zeros(
            (max_bs * self.model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.cuda_graph_kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device="cuda"
        )

        if self.model_runner.sliding_window_size is not None:
            self.cuda_graph_kv_indptr = [
                self.cuda_graph_kv_indptr,
                self.cuda_graph_kv_indptr.clone(),
            ]
            self.cuda_graph_kv_indices = [
                self.cuda_graph_kv_indices,
                self.cuda_graph_kv_indices.clone(),
            ]

    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        if self.model_runner.sliding_window_size is None:
            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                use_cuda_graph=True,
                use_tensor_cores=self.decode_use_tensor_cores,
                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
            )
        else:
            decode_wrapper = []
            for i in range(2):
                decode_wrapper.append(
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.workspace_buffer,
                        "NHD",
                        use_cuda_graph=True,
                        use_tensor_cores=self.decode_use_tensor_cores,
                        paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                        paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
                            :bs
                        ],
                    )
                )

        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            decode_wrapper,
        )

        self.cuda_graph_metadata[bs] = decode_wrapper
        self.forward_metadata = (False, None, decode_wrapper)

    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
        update_flashinfer_indices(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices[:bs],
            seq_lens[:bs],
            None,
            self.cuda_graph_metadata[bs],
        )

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if not isinstance(self.prefill_wrapper_paged, list):
            prefill_wrapper_paged = self.prefill_wrapper_paged
        else:
            if layer.sliding_window_size != -1:
                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
            else:
                prefill_wrapper_paged = self.prefill_wrapper_paged[1]

        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if not use_ragged:
            if k is not None:
                assert v is not None
                input_metadata.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, input_metadata.out_cache_loc, k, v
                )
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=True,
                sm_scale=layer.scaling,
                window_left=layer.sliding_window_size,
                logits_soft_cap=layer.logit_cap,
            )
        else:
            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=layer.logit_cap,
            )

            if input_metadata.extend_no_prefix:
                o = o1
            else:
                o2, s2 = prefill_wrapper_paged.forward_return_lse(
                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                    causal=False,
                    sm_scale=layer.scaling,
                    logits_soft_cap=layer.logit_cap,
                )

                o, _ = merge_state(o1, s1, o2, s2)

            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

            if total_num_tokens >= global_config.layer_sync_threshold:
                torch.cuda.synchronize()

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata

        if isinstance(decode_wrapper, list):
            if layer.sliding_window_size != -1:
                decode_wrapper = decode_wrapper[0]
            else:
                decode_wrapper = decode_wrapper[1]

        if k is not None:
            assert v is not None
            input_metadata.token_to_kv_pool.set_kv_buffer(
                layer.layer_id, input_metadata.out_cache_loc, k, v
            )

        o = decode_wrapper.forward(
            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
            sm_scale=layer.scaling,
            logits_soft_cap=layer.logit_cap,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)


class TritonAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.triton_attention.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.triton_attention.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.forward_metadata = None

    def init_forward_metadata(self, batch: ScheduleBatch, input_metadata: InputMetadata):
        """Init auxiliary variables for triton attention backend."""

        if input_metadata.forward_mode.is_decode():
            max_seq_len = torch.max(input_metadata.seq_lens).item()
            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
            max_extend_len = None
        else:
            start_loc = max_seq_len = total_num_tokens = None
            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens

    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            input_metadata.seq_lens,
            input_metadata.extend_seq_lens,
            input_metadata.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata

        input_metadata.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, input_metadata.out_cache_loc, k, v
        )

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            input_metadata.req_to_token_pool.req_to_token,
            input_metadata.req_pool_indices,
            start_loc,
            input_metadata.seq_lens,
            max_seq_len,
            total_num_tokens,
            layer.scaling,
            layer.logit_cap,
        )
        return o
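Because the interface above is all a backend has to satisfy, adding a third implementation is mostly a matter of subclassing. A minimal sketch of a hypothetical custom backend (the class name and the empty method bodies are illustrative, not part of this commit):

from sglang.srt.layers.attention_backend import AttentionBackend


class MyAttnBackend(AttentionBackend):
    """Hypothetical example backend, shown only to illustrate the interface."""

    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner
        self.forward_metadata = None

    def init_forward_metadata(self, batch, input_metadata):
        # Precompute whatever per-batch indices/lengths the kernels need.
        self.forward_metadata = None

    def forward_extend(self, q, k, v, layer, input_metadata):
        # Write k/v into the KV pool, then run the prefill/extend kernel.
        raise NotImplementedError

    def forward_decode(self, q, k, v, layer, input_metadata):
        # Run the single-token decode kernel against the paged KV cache.
        raise NotImplementedError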
python/sglang/srt/layers/flashinfer_utils.py

@@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton(
     page_kernel_lens_ptr,
     kv_indptr,
     kv_start_idx,
-    max_context_len,
     kv_indices_ptr,
+    max_context_len: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)

@@ -47,15 +47,15 @@ class FlashinferUpdater:
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        flashinfer_decode_wrapper=None,
-        flashinfer_use_ragged=False,
+        decode_wrapper=None,
+        use_ragged=False,
     ):
         self.forward_mode = forward_mode
         self.model_runner = model_runner
         self.req_pool_indices = req_pool_indices
         self.seq_lens = seq_lens
         self.prefix_lens = prefix_lens
-        self.flashinfer_use_ragged = flashinfer_use_ragged
+        self.use_ragged = use_ragged

         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size

@@ -71,20 +71,17 @@ class FlashinferUpdater:
         )

         (
-            self.flashinfer_decode_wrapper,
-            self.flashinfer_prefill_wrapper_ragged,
-            self.flashinfer_prefill_wrapper_paged,
+            self.decode_wrapper,
+            self.prefill_wrapper_ragged,
+            self.prefill_wrapper_paged,
         ) = (
-            flashinfer_decode_wrapper,
-            self.model_runner.flashinfer_prefill_wrapper_ragged,
-            self.model_runner.flashinfer_prefill_wrapper_paged,
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
+            self.model_runner.attn_backend.prefill_wrapper_ragged,
+            self.model_runner.attn_backend.prefill_wrapper_paged,
         )

-        # CUDA graph uses different flashinfer_decode_wrapper
-        if self.flashinfer_decode_wrapper is None:
-            self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper
-
-    def _init_indices_no_window(self):
-        if self.flashinfer_use_ragged:
+    def _init_indices_no_sliding_window(self):
+        if self.use_ragged:
             paged_kernel_lens = self.prefix_lens
         else:
             paged_kernel_lens = self.seq_lens

@@ -103,13 +100,13 @@ class FlashinferUpdater:
             paged_kernel_lens,
             self.kv_indptr,
             None,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
             self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
         )

-    def _init_indices_window(self, wrapper_id):
-        # window attention use paged only
+    def _init_indices_sliding_window(self, wrapper_id):
         if wrapper_id == 0:
+            # window attention use paged only
             if self.forward_mode.is_decode():
                 paged_kernel_lens = torch.minimum(
                     self.seq_lens,

@@ -123,6 +120,7 @@ class FlashinferUpdater:
                     - self.prefix_lens,
                 )
             else:
+                # full attention
                 paged_kernel_lens = self.seq_lens
             kv_start_idx = self.seq_lens - paged_kernel_lens

@@ -139,8 +137,8 @@ class FlashinferUpdater:
             paged_kernel_lens,
             self.kv_indptr,
             kv_start_idx,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
             self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
         )

     def _update_decode_indices(self, decode_wrapper):

@@ -164,7 +162,7 @@ class FlashinferUpdater:
         )
         qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)

-        if self.flashinfer_use_ragged:
+        if self.use_ragged:
             ragged_wrapper.end_forward()
             ragged_wrapper.begin_forward(
                 qo_indptr,

@@ -187,28 +185,28 @@ class FlashinferUpdater:
             1,
         )

-    def update_indices_no_window(self):
-        self._init_indices_no_window()
+    def update_indices_no_sliding_window(self):
+        self._init_indices_no_sliding_window()

         if self.forward_mode.is_decode():
-            self._update_decode_indices(self.flashinfer_decode_wrapper)
+            self._update_decode_indices(self.decode_wrapper)
         else:
             self._update_extend_indices(
-                self.flashinfer_prefill_wrapper_ragged,
-                self.flashinfer_prefill_wrapper_paged,
+                self.prefill_wrapper_ragged,
+                self.prefill_wrapper_paged,
             )

-    def update_indices_window(self):
-        assert self.flashinfer_use_ragged is False
+    def update_indices_sliding_window(self):
+        assert self.use_ragged is False

         for wrapper_id in range(2):
-            self._init_indices_window(wrapper_id)
+            self._init_indices_sliding_window(wrapper_id)
             if self.forward_mode.is_decode():
-                self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
+                self._update_decode_indices(self.decode_wrapper[wrapper_id])
             else:
                 self._update_extend_indices(
                     None,
-                    self.flashinfer_prefill_wrapper_paged[wrapper_id],
+                    self.prefill_wrapper_paged[wrapper_id],
                 )

@@ -218,20 +216,20 @@ def update_flashinfer_indices(
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper=None,
-    flashinfer_use_ragged=False,
+    decode_wrapper=None,
+    use_ragged=False,
 ):
-    flashinfer_updater = FlashinferUpdater(
+    updater = FlashinferUpdater(
         forward_mode,
         model_runner,
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        flashinfer_decode_wrapper,
-        flashinfer_use_ragged,
+        decode_wrapper,
+        use_ragged,
     )

     if model_runner.sliding_window_size is None:
-        flashinfer_updater.update_indices_no_window()
+        updater.update_indices_no_sliding_window()
     else:
-        flashinfer_updater.update_indices_window()
+        updater.update_indices_sliding_window()
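For orientation, kv_indptr/kv_indices form a CSR-style view over the paged KV cache: kv_indptr holds cumulative per-request KV lengths and kv_indices concatenates each request's token slots taken from req_to_token. A toy illustration in plain PyTorch (no Triton; the example data is made up):

import torch

# req_to_token maps (request slot, position) -> token location in the KV pool.
req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]])
req_pool_indices = torch.tensor([0, 1])   # which rows of req_to_token are active
paged_kernel_lens = torch.tensor([3, 2])  # KV length of each active request

kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.cat(
    [
        req_to_token[r, :l]
        for r, l in zip(req_pool_indices.tolist(), paged_kernel_lens.tolist())
    ]
)

print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [10, 11, 12, 20, 21]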
python/sglang/srt/layers/radix_attention.py

@@ -15,25 +15,14 @@ limitations under the License.
 """Radix attention."""

 from typing import Optional

 import torch
-from flashinfer.cascade import merge_state
 from torch import nn

-from sglang.global_config import global_config
-from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd
-from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.model_executor.model_runner import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class RadixAttention(nn.Module):
-    """
-    The attention layer implementation.
-    Now it has two backends: FlashInfer and Triton.
-    FlashInfer is faster and Triton is easier to customize.
-    It supports two operators: extend (i.e. prefill with cached prefix) and decode.
-    """

     def __init__(

@@ -43,8 +32,8 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: Optional[int] = None,
-        logit_cap: int = -1,
+        sliding_window_size: int = -1,
+        logit_cap: float = 0.0,
         v_head_dim: int = -1,
     ):
         super().__init__()

@@ -56,164 +45,14 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
-
-        # Choose backend
-        if (
-            global_server_args_dict["attention_backend"] == "flashinfer"
-            and self.qk_head_dim == self.v_head_dim
-        ):
-            self.extend_forward = self.extend_forward_flashinfer
-            self.decode_forward = self.decode_forward_flashinfer
-        elif global_server_args_dict["attention_backend"] == "triton":
-            self.extend_forward = self.extend_forward_triton
-            self.decode_forward = self.decode_forward_triton
-        else:
-            raise ValueError(
-                f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
-            )
-
-    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-        self.store_kv_cache(k, v, input_metadata)
-        extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            k.contiguous(),
-            v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_prefix_lens,
-            input_metadata.extend_start_loc,
-            input_metadata.extend_seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.triton_max_extend_len,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-        return o
-
-    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-        self.store_kv_cache(k, v, input_metadata)
-        decode_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.total_num_tokens,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-        return o
-
-    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_paged = prefill_wrapper_paged[0]
-        else:
-            if isinstance(prefill_wrapper_paged, list):
-                prefill_wrapper_paged = prefill_wrapper_paged[1]
-
-        if not input_metadata.flashinfer_use_ragged:
-            if k is not None:
-                assert v is not None
-                self.store_kv_cache(k, v, input_metadata)
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
-            )
-        else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-            )
-
-            if input_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                    causal=False,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            self.store_kv_cache(k, v, input_metadata)
-
-            if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-                torch.cuda.synchronize()
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
-
-    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        decode_wrapper = input_metadata.flashinfer_decode_wrapper
-        if self.sliding_window_size != -1:
-            decode_wrapper = decode_wrapper[0]
-        else:
-            if isinstance(decode_wrapper, list):
-                decode_wrapper = decode_wrapper[1]
-
-        if k is not None:
-            assert v is not None
-            self.store_kv_cache(k, v, input_metadata)
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
+        self.logit_cap = logit_cap
+        self.sliding_window_size = sliding_window_size or -1

     def forward(self, q, k, v, input_metadata: InputMetadata):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        if input_metadata.forward_mode.is_extend():
-            return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode.is_decode():
-            return self.decode_forward(q, k, v, input_metadata)
-
-    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
-        )
+        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
python/sglang/srt/layers/triton_attention/decode_attention.py

@@ -15,6 +15,7 @@ limitations under the License.
 """
 Memory-efficient attention for decoding.
+It supports page size = 1.
 """

 # Adapted from

@@ -197,7 +198,6 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
-    # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
-    # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]

     if Lk == 576:

@@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        Lv=Lv,
         num_warps=num_warps,
         num_stages=1,
+        Lv=Lv,
     )

@@ -588,7 +587,7 @@ def decode_attention_fwd(
     max_len_in_batch,
     total_num_tokens,
     sm_scale,
-    logit_cap=-1,
+    logit_cap=0.0,
     att_m=None,
 ):
     if att_m is None:
python/sglang/srt/layers/triton_attention/extend_attention.py

@@ -61,14 +61,14 @@ def _fwd_kernel(
     stride_buf_vbs,
     stride_buf_vh,
     stride_req_to_tokens_b,
-    logit_cap: tl.constexpr,
-    Lq: tl.constexpr,
-    Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DPE: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -111,7 +111,7 @@ def _fwd_kernel(
         )
         qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

-    # stage1: compute scores with prefix
+    # stage 1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)

     acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)

@@ -174,7 +174,7 @@ def _fwd_kernel(
         e_max = n_e_max

-    # stage2: compute the trianlge part
+    # stage 2: compute the trianlge part
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):

@@ -255,26 +255,22 @@ def extend_attention_fwd(
     v_buffer,
     req_to_tokens,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
-    b_seq_len_prefix,
-    b_start_loc_extend,
     b_seq_len_extend,
-    max_len_in_batch,
+    b_start_loc_extend,
     max_len_extend,
     sm_scale=None,
-    logit_cap=-1,
+    logit_cap=0.0,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    Lq, Lk, Lv, Lo = (
+    Lq, Lk, Lv = (
         q_extend.shape[-1],
         k_extend.shape[-1],
         v_extend.shape[-1],
-        o_extend.shape[-1],
     )

     if Lq == 576:

@@ -303,7 +299,7 @@ def extend_attention_fwd(
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

@@ -338,27 +334,24 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
-        num_warps=num_warps,
-        num_stages=num_stages,
+        logit_cap=logit_cap,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
     )


 def redundant_attention(
     q_extend,
     k_extend,
     v_extend,
     o_extend,
     k_buffer,
     v_buffer,
     req_to_tokens,
     b_req_idx,
     b_start_loc,
     b_seq_len,
python/sglang/srt/managers/schedule_batch.py

@@ -368,7 +368,7 @@ class ScheduleBatch:
         )

     def batch_size(self):
-        return len(self.reqs) if self.reqs is not None else 0
+        return len(self.reqs) if self.reqs else 0

     def is_empty(self):
         return len(self.reqs) == 0
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Run the model with cuda graph."""
+"""Run the model with cuda graph and torch.compile."""

 import bisect
 from contextlib import contextmanager
-from typing import Callable, List
+from typing import Callable

 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

@@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 def patch_model(
     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
+    """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:

@@ -86,23 +85,28 @@ def set_torch_compile_config():
 class CudaGraphRunner:
-    def __init__(
-        self,
-        model_runner: "ModelRunner",
-        max_batch_size_to_capture: int,
-        use_torch_compile: bool,
-        disable_padding: bool,
-    ):
+    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
+
+    def __init__(self, model_runner: "ModelRunner"):
+        # Parse args
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
-        self.disable_padding = disable_padding
+        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+
+        # Batch sizes to capture
+        if self.model_runner.server_args.disable_cuda_graph_padding:
+            self.capture_bs = list(range(1, 32)) + [64, 128]
+        else:
+            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []

         # Common inputs
-        self.max_bs = max_batch_size_to_capture
+        self.max_bs = max(self.capture_bs)
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"

@@ -115,56 +119,39 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )

-        # FlashInfer inputs
-        self.flashinfer_kv_indptr = torch.zeros(
-            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
-            dtype=torch.int32,
-            device="cuda",
-        )
-        self.flashinfer_kv_last_page_len = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        if model_runner.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-        else:
-            self.flashinfer_workspace_buffer = (
-                self.model_runner.flashinfer_workspace_buffer
-            )
-            self.flashinfer_kv_indptr = [
-                self.flashinfer_kv_indptr,
-                self.flashinfer_kv_indptr.clone(),
-            ]
-            self.flashinfer_kv_indices = [
-                self.flashinfer_kv_indices,
-                self.flashinfer_kv_indices.clone(),
-            ]
+        # Attention backend
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)

-        # Sampling inputs
+        # Sampling info
         vocab_size = model_runner.model_config.vocab_size
         self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)

-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
-
-        if use_torch_compile:
+        if self.use_torch_compile:
             set_torch_compile_config()

+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
+            )
+
     def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self, batch_size_list: List[int]):
-        self.batch_size_list = batch_size_list
+    def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in batch_size_list:
+            for bs in self.capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,

@@ -172,14 +159,10 @@ class CudaGraphRunner:
                 ) as forward:
                     (
                         graph,
-                        input_buffers,
                         output_buffers,
-                        flashinfer_handler,
                     ) = self.capture_one_batch_size(bs, forward)
                     self.graphs[bs] = graph
-                    self.input_buffers[bs] = input_buffers
                     self.output_buffers[bs] = output_buffers
-                    self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()

@@ -192,48 +175,9 @@ class CudaGraphRunner:
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]

-        # FlashInfer inputs
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
-            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.model_runner.sliding_window_size is None:
-            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=use_tensor_cores,
-                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.flashinfer_kv_indices,
-                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-            )
-        else:
-            flashinfer_decode_wrapper = []
-            for i in range(2):
-                flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=use_tensor_cores,
-                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
-                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
-                            :bs
-                        ],
-                    )
-                )
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            flashinfer_decode_wrapper,
+        # Attention backend
+        self.model_runner.attn_backend.capture_cuda_graph_init(
+            bs, req_pool_indices, seq_lens
         )

         # Run and capture

@@ -246,13 +190,12 @@ class CudaGraphRunner:
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
-                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )

             return forward(input_ids, input_metadata.positions, input_metadata)

         for _ in range(2):

@@ -274,15 +217,15 @@ class CudaGraphRunner:
         self.model_runner.tp_group.barrier()

         self.graph_memory_pool = graph.pool()
-        return graph, None, out, flashinfer_decode_wrapper
+        return graph, out

     def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)

         # Pad
-        index = bisect.bisect_left(self.batch_size_list, raw_bs)
-        bs = self.batch_size_list[index]
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.zero_()
             self.position_ids_offsets.fill_(1)

@@ -295,14 +238,9 @@ class CudaGraphRunner:
         self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

-        # FlashInfer inputs
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            self.req_pool_indices[:bs],
-            self.seq_lens[:bs],
-            None,
-            self.flashinfer_handlers[bs],
+        # Attention backend
+        self.model_runner.attn_backend.replay_cuda_graph_init(
+            bs, self.req_pool_indices, self.seq_lens
         )

         # Sampling inputs
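The new capture()/replay() flow follows the standard CUDA-graph recipe: warm up, capture one graph per batch size into static buffers, then copy fresh inputs into those buffers and replay. A standalone sketch of that recipe in plain PyTorch (the linear model and shapes are stand-ins, not sglang code; requires a CUDA device):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture (required by CUDA graphs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture a single forward pass into a graph that reads/writes static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: refill the captured input buffer in place, then replay the graph.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out.shape)  # torch.Size([8, 16])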
python/sglang/srt/model_executor/forward_batch_info.py

@@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List
 import numpy as np
 import torch

-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -66,12 +65,11 @@ class InputMetadata:
     seq_lens: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
+    attn_backend: AttentionBackend

     # Output location of the KV cache
     out_cache_loc: torch.Tensor

-    total_num_tokens: int = None
-
     # Position information
     positions: torch.Tensor = None

@@ -93,18 +91,6 @@ class InputMetadata:
     image_offsets: List[List[int]] = None
     modalities: List[List[str]] = None

-    # Trition attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]

@@ -154,32 +140,27 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)

     def compute_extend_infos(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
-            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
-        else:
-            extend_lens_cpu = [
-                len(r.fill_ids) - batch.prefix_lens_cpu[i]
-                for i, r in enumerate(batch.reqs)
-            ]
-            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
-            self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            self.extend_start_loc = torch.zeros_like(self.seq_lens)
-            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
-            self.extend_seq_lens_cpu = extend_lens_cpu
-            self.logprob_start_lens_cpu = [
-                (
-                    min(
-                        req.logprob_start_len - batch.prefix_lens_cpu[i],
-                        extend_lens_cpu[i] - 1,
-                    )
-                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
-                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
-                )
-                for i, req in enumerate(batch.reqs)
-            ]
+        extend_lens_cpu = [
+            len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs)
+        ]
+        self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+        self.extend_start_loc = torch.zeros_like(self.seq_lens)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+        self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
+        self.extend_seq_lens_cpu = extend_lens_cpu
+        self.logprob_start_lens_cpu = [
+            (
+                min(
+                    req.logprob_start_len - batch.prefix_lens_cpu[i],
+                    extend_lens_cpu[i] - 1,
+                )
+                if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+            )
+            for i, req in enumerate(batch.reqs)
+        ]

     @classmethod
     def from_schedule_batch(

@@ -195,6 +176,7 @@ class InputMetadata:
             seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,

@@ -202,76 +184,12 @@ class InputMetadata:
         ret.sampling_info.update_penalties()
         ret.sampling_info.update_regex_vocab_mask(batch)
         ret.compute_positions(batch)
-        ret.compute_extend_infos(batch)
-
-        fm = batch.forward_mode
-        if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
-            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
-
-        if not fm.is_decode():
+        if not batch.forward_mode.is_decode():
             ret.init_multimuldal_info(batch)
+            ret.compute_extend_infos(batch)

-        if model_runner.server_args.attention_backend == "triton":
-            ret.init_triton_args(batch)
-
-        flashinfer_use_ragged = False
-        if model_runner.server_args.attention_backend == "flashinfer":
-            if (
-                not fm.is_decode()
-                and int(torch.sum(ret.seq_lens)) > 4096
-                and model_runner.sliding_window_size is None
-            ):
-                flashinfer_use_ragged = True
-            ret.init_flashinfer_handlers(
-                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
-            )
+        model_runner.attn_backend.init_forward_metadata(batch, ret)

         return ret

-    def init_triton_args(self, batch: ScheduleBatch):
-        """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
-        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
-
-        if self.forward_mode.is_decode():
-            self.triton_max_extend_len = None
-        else:
-            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
-            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
-
-    def init_flashinfer_handlers(
-        self,
-        model_runner,
-        prefix_lens_cpu,
-        flashinfer_use_ragged,
-    ):
-        if self.forward_mode.is_decode():
-            prefix_lens = None
-        else:
-            prefix_lens = self.extend_prefix_lens
-
-        update_flashinfer_indices(
-            self.forward_mode,
-            model_runner,
-            self.req_pool_indices,
-            self.seq_lens,
-            prefix_lens,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        (
-            self.flashinfer_prefill_wrapper_ragged,
-            self.flashinfer_prefill_wrapper_paged,
-            self.flashinfer_decode_wrapper,
-            self.flashinfer_use_ragged,
-        ) = (
-            model_runner.flashinfer_prefill_wrapper_ragged,
-            model_runner.flashinfer_prefill_wrapper_paged,
-            model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged,
-        )
python/sglang/srt/model_executor/model_runner.py

@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (

@@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

-from sglang.global_config import global_config
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict

@@ -69,6 +63,8 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,

@@ -100,6 +96,7 @@ class ModelRunner:
             }
         )

+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."

@@ -107,6 +104,7 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95

+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(

@@ -115,7 +113,7 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flashinfer()
+        self.init_attention_backend()
         self.init_cuda_graphs()

     def init_torch_distributed(self):

@@ -397,9 +395,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,

@@ -422,106 +417,42 @@ class ModelRunner:
         c = a @ b
         return c

-    def init_flashinfer(self):
-        """Init flashinfer attention kernel wrappers."""
-        if self.server_args.attention_backend != "flashinfer":
-            assert (
-                self.sliding_window_size is None
-            ), "turn on flashinfer to support window attention"
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = None
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
-            )
-        else:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            self.attn_backend = TritonAttnBackend(self)
+        else:
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
+            )

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+        self.cuda_graph_runner = None

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if (
-            self.server_args.disable_cuda_graph
-            or self.server_args.attention_backend != "flashinfer"
-        ):
-            self.cuda_graph_runner = None
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph:
+            return
+
+        if self.server_args.attention_backend != "flashinfer":
+            logger.warning(
+                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
+            )
             return

         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose\n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)

     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
python/sglang/srt/sampling/sampling_batch_info.py

@@ -143,18 +143,16 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        bs, reqs = batch.batch_size(), batch.reqs
-        device = "cuda"
-        has_regex = any(req.regex_fsm is not None for req in reqs)
+        has_regex = any(req.regex_fsm is not None for req in batch.reqs)

         # Reset the vocab mask
         self.vocab_mask = None

         if has_regex:
             self.vocab_mask = torch.zeros(
-                bs, self.vocab_size, dtype=torch.bool, device=device
+                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(reqs):
+            for i, req in enumerate(batch.reqs):
                 if req.regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
python/sglang/srt/server.py

@@ -335,23 +335,19 @@ def launch_server(
         return

     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
         start_controller_process = start_controller_process_single
     else:
         start_controller_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_controller_process,
         args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()

+    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(

@@ -362,6 +358,10 @@ def launch_server(
     )
     proc_detoken.start()

+    tokenizer_manager = TokenizerManager(server_args, port_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+
     # Wait for the model to finish loading
     controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
python/sglang/srt/server_args.py

@@ -83,8 +83,8 @@ class ServerArgs:
     json_model_override_args: str = "{}"

     # Optimization/debug options
-    attention_backend: str = "flashinfer"
-    sampling_backend: str = "flashinfer"
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False

@@ -148,6 +148,17 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"

+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
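The backend-selection logic added above reduces to a small decision chain. A standalone paraphrase (SimpleNamespace stands in for ServerArgs here; this mirrors, but is not, the actual implementation):

from types import SimpleNamespace

args = SimpleNamespace(attention_backend=None, sampling_backend=None, enable_mla=False)

if args.enable_mla:
    args.attention_backend = "triton"      # MLA models force the Triton backend
if args.attention_backend is None:
    args.attention_backend = "flashinfer"  # default attention kernels
if args.sampling_backend is None:
    args.sampling_backend = "flashinfer"   # default sampling kernels

print(args.attention_backend, args.sampling_backend)  # flashinfer flashinfer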
test/srt/test_create_kvindices.py

@@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase):
             paged_kernel_lens,
             kv_indptr,
             None,
-            req_to_token.size(1),
             kv_indices_triton,
+            req_to_token.size(1),
         )

         # Check
test/srt/test_moe_serving_throughput.py

@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
         other_args = []
         if disable_radix_cache:
             other_args.append("--disable-radix-cache")
-        other_args.extend(["--attention-backend", attention_backend])
+        if attention_backend:
+            other_args.extend(["--attention-backend", attention_backend])
         other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
         other_args.extend(["--tensor-parallel-size", "2"])
test/srt/test_serving_throughput.py

@@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase):
         other_args = []
         if disable_radix_cache:
             other_args.append("--disable-radix-cache")
-        other_args.extend(["--attention-backend", attention_backend])
+        if attention_backend:
+            other_args.extend(["--attention-backend", attention_backend])
         other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
         model = DEFAULT_MODEL_NAME_FOR_TEST
test/srt/test_triton_attention_kernels.py

@@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase):
             v_buffer,
             req_to_tokens,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
-            b_seq_len_prefix,
-            b_start_loc_extend,
             b_seq_len_extend,
-            max_len_in_batch,
+            b_start_loc_extend,
             max_len_extend,
         )

         redundant_attention(
             q_extend,
             k_extend,
             v_extend,
             o_redundant,
             k_buffer,
             v_buffer,
             req_to_tokens,
             b_req_idx,
             b_start_loc,
             b_seq_len,