Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
90532b76
Unverified
Commit
90532b76
authored
Mar 18, 2025
by
Baizhou Zhang
Committed by
GitHub
Mar 18, 2025
Browse files
[Fix] Fix raw_bs bug when using flashinfer mla and eagle (#4557)
parent
c0e9a36c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
0 deletions
+11
-0
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+11
-0
No files found.
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
90532b76
...
@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -52,6 +52,9 @@ class EAGLEDraftCudaGraphRunner:
self
.
seq_len_fill_value
=
self
.
model_runner
.
draft_attn_backend
.
attn_backends
[
self
.
seq_len_fill_value
=
self
.
model_runner
.
draft_attn_backend
.
attn_backends
[
0
0
].
get_cuda_graph_seq_len_fill_value
()
].
get_cuda_graph_seq_len_fill_value
()
self
.
seq_lens_cpu
=
torch
.
full
(
(
self
.
max_bs
,),
self
.
seq_len_fill_value
,
dtype
=
torch
.
int32
)
if
self
.
enable_torch_compile
:
if
self
.
enable_torch_compile
:
set_torch_compile_config
()
set_torch_compile_config
()
...
@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -210,6 +213,12 @@ class EAGLEDraftCudaGraphRunner:
forward_batch
.
req_pool_indices
=
self
.
req_pool_indices
[:
bs
]
forward_batch
.
req_pool_indices
=
self
.
req_pool_indices
[:
bs
]
forward_batch
.
positions
=
self
.
positions
[:
num_tokens
]
forward_batch
.
positions
=
self
.
positions
[:
num_tokens
]
# Special handle for seq_len_cpu used when flashinfer mla is used
if
(
forward_batch
.
decode_seq_lens_cpu
is
not
None
)
and
(
bs
!=
raw_bs
):
self
.
seq_lens_cpu
.
fill_
(
1
)
self
.
seq_lens_cpu
[:
raw_bs
].
copy_
(
forward_batch
.
decode_seq_lens_cpu
)
forward_batch
.
decode_seq_lens_cpu
=
self
.
seq_lens_cpu
[:
bs
]
self
.
model_runner
.
draft_attn_backend
.
init_forward_metadata_replay_cuda_graph
(
self
.
model_runner
.
draft_attn_backend
.
init_forward_metadata_replay_cuda_graph
(
forward_batch
,
bs
forward_batch
,
bs
)
)
...
@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
...
@@ -224,5 +233,7 @@ class EAGLEDraftCudaGraphRunner:
forward_batch
.
positions
=
self
.
positions
[:
raw_num_token
]
forward_batch
.
positions
=
self
.
positions
[:
raw_num_token
]
forward_batch
.
seq_lens
=
self
.
seq_lens
[:
raw_bs
]
forward_batch
.
seq_lens
=
self
.
seq_lens
[:
raw_bs
]
forward_batch
.
req_pool_indices
=
self
.
req_pool_indices
[:
raw_bs
]
forward_batch
.
req_pool_indices
=
self
.
req_pool_indices
[:
raw_bs
]
if
forward_batch
.
decode_seq_lens_cpu
is
not
None
:
forward_batch
.
decode_seq_lens_cpu
=
self
.
seq_lens_cpu
[:
raw_bs
]
return
out
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment