sglang · Commit c64290dc (Unverified)
Authored Jun 16, 2025 by Lianmin Zheng; committed by GitHub, Jun 16, 2025
Parent: 8e2363dc

Use seq_len_fill_value in the cuda graph runners (#7233)

Showing 7 changed files with 19 additions and 19 deletions.
python/sglang/srt/layers/attention/flashattention_backend.py          +1 −1
python/sglang/srt/layers/attention/flashinfer_backend.py              +1 −1
python/sglang/srt/layers/attention/flashinfer_mla_backend.py          +1 −1
python/sglang/srt/model_executor/cuda_graph_runner.py                 +3 −3
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py        +3 −4
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +6 −5
python/sglang/srt/speculative/eagle_worker.py                         +4 −4
python/sglang/srt/layers/attention/flashattention_backend.py

```diff
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1

     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
```
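For context, a minimal sketch of how the runner side consumes this value; the backend class and tensor names below are simplified stand-ins for illustration, not the actual sglang API:

```python
import torch

class FakeAttentionBackend:
    """Stand-in for an sglang attention backend (names simplified)."""

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # After this commit, the FlashAttention/FlashInfer backends return 1,
        # so padded (dummy) batch slots look like length-1 sequences rather
        # than length-0 ones.
        return 1

backend = FakeAttentionBackend()
fill = backend.get_cuda_graph_seq_len_fill_value()

capture_bs, raw_bs = 8, 5  # illustrative: 5 real requests, graph captured at batch size 8
seq_lens = torch.empty(capture_bs, dtype=torch.int32)
seq_lens.fill_(fill)  # pad every slot first, as the runners now do
seq_lens[:raw_bs] = torch.tensor([17, 3, 9, 24, 11], dtype=torch.int32)
print(seq_lens)  # tensor([17,  3,  9, 24, 11,  1,  1,  1], dtype=torch.int32)
```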
python/sglang/srt/layers/attention/flashinfer_backend.py

```diff
@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
```
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

```diff
@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")

     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1

     def forward_extend(
         self,
```
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -612,7 +612,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()

         # Common inputs
@@ -624,7 +624,7 @@ class CudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if pp_proxy_tensors:
@@ -652,7 +652,7 @@ class CudaGraphRunner:
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens,
             forward_batch.forward_mode,
             forward_batch.spec_info,
```
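The third hunk is the key consistency fix: once every padded slot holds seq_len_fill_value tokens, the padded batch's token total must grow by (bs - raw_bs) * seq_len_fill_value, not by a hardcoded (bs - raw_bs). A worked example with illustrative numbers only:

```python
raw_bs = 5                 # real requests in this step
bs = 8                     # nearest captured batch size >= raw_bs
seq_len_fill_value = 1     # value the attention backend now reports

real_seq_lens = [17, 3, 9, 24, 11]
seq_lens_sum = sum(real_seq_lens)  # 64, real tokens only

# The padded sum now matches the padded seq_lens tensor exactly:
padded_sum = seq_lens_sum + (bs - raw_bs) * seq_len_fill_value  # 64 + 3 * 1 = 67
assert padded_sum == sum(real_seq_lens + [seq_len_fill_value] * (bs - raw_bs))

# The old hardcoded `+ (bs - raw_bs)` was only correct when the fill value
# happened to be 1; a backend filling with 0 would leave the sum inconsistent
# with the padded seq_lens tensor.
```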
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -187,9 +187,8 @@ class EAGLEDraftCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()

         num_tokens = bs * self.num_tokens_per_bs
@@ -211,15 +210,15 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]

-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )

+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
         # Replay
         self.graphs[bs].replay()
```
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -207,9 +207,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)

         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
@@ -223,18 +223,19 @@ class EAGLEDraftExtendCudaGraphRunner:
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if bs != raw_bs:
+            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
-            forward_batch.spec_info.positions = None

         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,
```
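On padding, the replay path now points spec_info.positions at a slice of the runner's preallocated buffer instead of setting it to None. A self-contained sketch of that slicing; the SpecInfo stand-in and the buffer shapes are assumptions for illustration, not the real EAGLE types:

```python
import torch

class SpecInfo:
    """Hypothetical stand-in for the EAGLE spec_info object."""
    positions: torch.Tensor | None = None
    accept_length: torch.Tensor | None = None

bs, raw_bs, num_tokens_per_bs = 4, 3, 2  # illustrative sizes
num_tokens = bs * num_tokens_per_bs

# Preallocated, padded buffers owned by the graph runner:
positions_buf = torch.zeros(num_tokens, dtype=torch.int64)
accept_length_buf = torch.ones(bs, dtype=torch.int32)

spec_info = SpecInfo()
if bs != raw_bs:
    # The slices share storage with the buffers the graph was captured on,
    # so replay sees tensors of the captured (padded) shapes.
    spec_info.positions = positions_buf[:num_tokens]
    spec_info.accept_length = accept_length_buf[:bs]
```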
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -166,6 +166,10 @@ class EAGLEWorker(TpModelWorker):

     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +242,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
```
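The eagle_worker.py hunks are a pure refactor: both attributes now get safe defaults once, before the backend branch, and the redundant per-branch assignments go away. A reduced sketch of the pattern (backend setup elided; not the full method):

```python
class EAGLEWorkerSketch:
    """Reduced illustration of the refactor; real backend setup elided."""

    def init_attention_backend(self, attention_backend: str) -> None:
        # Defaults hoisted above the branch (the added lines in the diff);
        # branches that previously re-assigned these same values no longer
        # need to touch them.
        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        if attention_backend in ("flashinfer", "fa3", "flashmla"):
            ...  # backend-specific setup, overriding the defaults where needed
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {attention_backend}"
            )
```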