sglang · Commits · 7ba5ad57

Commit 7ba5ad57 (unverified)
Authored Aug 09, 2025 by DarkSharpness; committed by GitHub on Aug 10, 2025

[Fix] Fix flashinfer cpu <-> gpu synchronization (#8340)
parent 19bc77f0

Showing 3 changed files with 65 additions and 24 deletions (+65 -24):

  python/sglang/srt/layers/attention/flashinfer_backend.py   +52 -13
  python/sglang/srt/managers/schedule_batch.py                +10 -10
  python/sglang/srt/model_executor/cuda_graph_runner.py        +3  -1
python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
 
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
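This and the two hunks below call `seq_lens.cpu()` inline, with a comment conceding the overhead. The trade-off: these are CUDA-graph capture paths, which execute once per batch size at startup, so a single blocking device-to-host copy is amortized away, while the replay paths receive a precomputed `seq_lens_cpu`. A sketch of the split (illustrative function names, not sglang API):

    import torch

    def metadata_for_capture(seq_lens_gpu: torch.Tensor) -> torch.Tensor:
        # Capture runs once per (batch size, graph) pair: the synchronous
        # copy here is paid at startup, not on the decode hot path.
        return seq_lens_gpu.cpu()

    def metadata_for_replay(seq_lens_cpu: torch.Tensor) -> torch.Tensor:
        # Replay reuses the CPU tensor the scheduler already produced; no
        # device round trip per step.
        return seq_lens_cpu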
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):
 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )
 
     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
-                )
-                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens,
+                    max=self.sliding_window_size + 1,
+                )
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None
 
             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
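Two changes in `update_sliding_window`: the old `torch.minimum(seq_lens, torch.tensor(...))` becomes `torch.clamp(seq_lens, max=...)`, resolving the TODO without allocating a scalar tensor, and the window-clamped total is summed on the CPU copy when one is available, since `.sum().item()` on the CUDA tensor would synchronize. A quick equivalence check in plain PyTorch:

    import torch

    sliding_window_size = 1024
    seq_lens = torch.randint(1, 8192, (64,))

    old_form = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
    new_form = torch.clamp(seq_lens, max=sliding_window_size + 1)
    assert torch.equal(old_form, new_form)

    # With a CPU mirror of seq_lens, the same total is available without a
    # device sync when seq_lens itself lives on the GPU.
    assert int(old_form.sum().item()) == int(new_form.sum().item())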
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu,
             )
 
     def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )
 
+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
                 non_blocking=True,
             )
 
+        if locally_override:
+            global_override_indptr_cpu = None
+
 
 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
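The decode updater's `call_begin_forward` (previous two hunks) publishes the CPU indptr through the global just before `wrapper.begin_forward(...)` and clears it afterwards, guarded by `locally_override` so it never clobbers an override set further up the stack. The same set-then-restore discipline in isolation (a sketch; the `try/finally` is my addition for exception safety, where the diff restores unconditionally after the call):

    import torch
    from typing import Callable, Optional

    global_override_indptr_cpu: Optional[torch.Tensor] = None

    def with_indptr_override(
        kv_indptr: torch.Tensor, seq_lens_cpu: torch.Tensor, fn: Callable[[], None]
    ) -> None:
        # Publish a host-side indptr for the duration of fn(), then restore.
        global global_override_indptr_cpu
        locally_override = global_override_indptr_cpu is None
        if locally_override:
            # indptr = [0, cumsum(seq_lens)], built entirely on the host,
            # mirroring the lines added in the diff above.
            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
            global_override_indptr_cpu[0] = 0
            global_override_indptr_cpu[1:] = torch.cumsum(seq_lens_cpu, dim=0)
        try:
            fn()
        finally:
            if locally_override:
                global_override_indptr_cpu = None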
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
             encoder_lens=None,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=None,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

python/sglang/srt/managers/schedule_batch.py
@@ -1714,16 +1714,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1714,16 +1714,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
attention_backend_str
=
global_server_args_dict
[
"prefill_attention_backend"
]
attention_backend_str
=
global_server_args_dict
[
"prefill_attention_backend"
]
# Create seq_lens_cpu when needed
# Create seq_lens_cpu when needed
if
(
if
(
attention_backend_str
==
"fa3"
attention_backend_str
or
(
in
[
global_server_args_dict
[
"use_mla_backend"
]
"fa3"
,
and
attention_backend_str
==
"flashinfer"
"flashinfer"
,
)
"flashmla"
,
or
attention_backend_str
==
"f
las
h
mla"
"cut
las
s_
mla"
,
or
attention_backend_str
==
"cutlass_mla"
"ascend"
,
or
attention_backend_str
==
"ascend"
"trtllm_mha"
,
or
attention_backend_str
==
"trtllm_mha"
"aiter"
,
or
attention_backend_str
==
"aiter"
]
or
global_server_args_dict
[
"enable_two_batch_overlap"
]
or
global_server_args_dict
[
"enable_two_batch_overlap"
]
):
):
seq_lens_cpu
=
(
seq_lens_cpu
=
(
...
...
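The chain of `==` comparisons collapses into one membership test, and in the process "flashinfer" joins the list unconditionally: the scheduler now materializes `seq_lens_cpu` for flashinfer whether or not `use_mla_backend` is set, which is what lets the backend changes above skip the GPU round trip. The rewritten predicate, distilled into a standalone sketch (the set and flag names mirror the diff; the free-function form is mine):

    # Backends whose attention metadata wants a host-resident seq_lens copy.
    NEEDS_SEQ_LENS_CPU = {
        "fa3", "flashinfer", "flashmla", "cutlass_mla", "ascend", "trtllm_mha", "aiter",
    }

    def needs_seq_lens_cpu(attention_backend_str: str, enable_two_batch_overlap: bool) -> bool:
        # One membership test instead of seven or-chained equality checks.
        return attention_backend_str in NEEDS_SEQ_LENS_CPU or enable_two_batch_overlap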

python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -729,10 +729,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
 
+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -766,7 +768,7 @@ class CudaGraphRunner:
...
@@ -766,7 +768,7 @@ class CudaGraphRunner:
self
.
encoder_lens
[:
bs
]
if
self
.
is_encoder_decoder
else
None
,
self
.
encoder_lens
[:
bs
]
if
self
.
is_encoder_decoder
else
None
,
self
.
capture_forward_mode
,
self
.
capture_forward_mode
,
forward_batch
.
spec_info
,
forward_batch
.
spec_info
,
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
bs
]
,
seq_lens_cpu
=
seq_lens_cpu
,
)
)
# Store fields
# Store fields
...
...
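This pair of hunks fixes the correctness bug on the replay path: the runner used to hand the backend `self.seq_lens_cpu[:bs]` unconditionally, even when the incoming `forward_batch` carried no CPU lengths, so flashinfer could plan against a stale buffer. The local now stays `None` unless the batch actually supplied them. The guard, distilled (hypothetical free-function form of the same logic):

    import torch
    from typing import Optional

    def pick_seq_lens_cpu(
        persistent_buf: torch.Tensor,            # CPU buffer sized for the captured bs
        batch_seq_lens_cpu: Optional[torch.Tensor],
        bs: int,
        raw_bs: int,
        fill_value: int,
    ) -> Optional[torch.Tensor]:
        if batch_seq_lens_cpu is None:
            return None  # old code still returned persistent_buf[:bs] here: stale data
        if bs != raw_bs:
            persistent_buf.fill_(fill_value)     # pad the entries beyond raw_bs
        persistent_buf[:raw_bs].copy_(batch_seq_lens_cpu)
        return persistent_buf[:bs]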