change / sglang · Commits · 37d83c6e

Qwen2.5-VL eagle3 infer (#8801)

Unverified commit 37d83c6e, authored Sep 08, 2025 by Lzhang-hub, committed via GitHub Sep 07, 2025. Parent: 7802586c.

Showing 9 changed files with 114 additions and 5 deletions (+114, -5):
python/sglang/srt/managers/mm_utils.py                                  +1   -0
python/sglang/srt/model_executor/cuda_graph_runner.py                   +5   -3
python/sglang/srt/model_executor/forward_batch_info.py                  +53  -1
python/sglang/srt/models/llama_eagle3.py                                +13  -0
python/sglang/srt/models/qwen2.py                                       +7   -0
python/sglang/srt/models/qwen2_5_vl.py                                  +24  -1
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py          +5   -0
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py   +5   -0
python/sglang/srt/speculative/eagle_worker.py                           +1   -0
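The commit threads M-RoPE (multimodal rotary position) buffers through the EAGLE3 speculative-decoding path so Qwen2.5-VL can serve as a target model with an EAGLE3 draft head. As a minimal sketch of how this would be exercised, the offline-engine snippet below reuses sglang's existing speculative-decoding arguments; the keyword names mirror the server flags and the draft-model path is a hypothetical placeholder, not something this commit pins down.

```python
# Minimal launch sketch, assuming sglang's offline Engine accepts the
# speculative-decoding server arguments as keyword arguments.
import sglang as sgl

llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-VL-7B-Instruct",
    speculative_algorithm="EAGLE3",
    speculative_draft_model_path="path/to/qwen2.5-vl-eagle3-draft",  # placeholder
    speculative_num_steps=3,
    speculative_eagle_topk=4,
    speculative_num_draft_tokens=8,
)
# Illustrative call; generation settings are assumptions, not from this commit.
print(llm.generate("Describe this image.")["text"])
```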
python/sglang/srt/managers/mm_utils.py

```diff
@@ -629,6 +629,7 @@ def general_mm_embed_routine(
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
+        and not forward_batch.forward_mode.is_target_verify()
         and forward_batch.contains_mm_inputs()
     ):
         mm_inputs_list = [
```
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -317,7 +317,9 @@ class CudaGraphRunner:
             (self.max_num_token,), dtype=self._cache_loc_dtype()
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
         self.tbo_plugin = TboCudaGraphRunnerPlugin()
@@ -532,7 +534,7 @@ class CudaGraphRunner:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :bs]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
@@ -751,7 +753,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
```
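Worth noting why the buffer grows from `(3, max_bs)` to `(3, max_num_token)`: in a speculative decode step each sequence carries several draft tokens, so M-RoPE needs one position triple per token rather than per sequence, and the replay-time slices switch from `:bs` to `:num_tokens` accordingly. A toy sketch with invented sizes:

```python
import torch

# Invented sizes for illustration: each CUDA-graph decode step now covers
# num_tokens_per_bs draft tokens per sequence, not one token per sequence.
max_bs = 8
num_tokens_per_bs = 4
max_num_token = max_bs * num_tokens_per_bs

mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

# Replaying with bs sequences slices per token, matching the diff above:
bs = 2
num_tokens = bs * num_tokens_per_bs
view = mrope_positions[:, :num_tokens]
print(view.shape)  # torch.Size([3, 8])
```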
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -441,7 +441,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

     if model_runner.model_is_mrope:
-        ret._compute_mrope_positions(model_runner, batch)
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret._compute_spec_mrope_positions(model_runner, batch)
+        else:
+            ret._compute_mrope_positions(model_runner, batch)

     # Init lora information
     if model_runner.server_args.enable_lora:
@@ -507,6 +513,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )

+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+        if batch.forward_mode.is_draft_extend():
+            # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+        else:
+            # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
```
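To make the position arithmetic in `_compute_spec_mrope_positions` concrete, here is a self-contained sketch of the target_verify / draft_decode branch with invented numbers: the flat 1-D positions from `spec_info` are shifted by each sequence's `mrope_position_delta` and then broadcast to the three (t, h, w) M-RoPE rows, which coincide for pure-text draft tokens.

```python
import torch

# Two sequences, three draft tokens each (numbers invented for illustration).
batch_size = 2
flat_positions = torch.tensor([100, 101, 102, 50, 51, 52])  # spec_info.positions
deltas = torch.tensor([[7], [0]])  # per-sequence mrope_position_delta

seq_positions = flat_positions.view(batch_size, -1)  # (2, 3)
shifted = seq_positions + deltas                     # broadcast add per sequence
mrope = shifted.flatten().unsqueeze(0).repeat(3, 1)  # (3, 6), rows identical
print(mrope[0])  # tensor([107, 108, 109,  50,  51,  52])
```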
python/sglang/srt/models/llama_eagle3.py

```diff
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config

+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds

+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
```
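The `is_mrope_enabled` probe keys off `mrope_section` inside `rope_scaling`, which is how Qwen2-VL-family checkpoints advertise M-RoPE. A toy check of the same logic (the payload values are illustrative, in the shape Qwen2.5-VL configs use, not read from any specific checkpoint):

```python
# Illustrative rope_scaling payload in the Qwen2.5-VL style.
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

is_mrope_enabled = rope_scaling is not None and "mrope_section" in rope_scaling
print(is_mrope_enabled)  # True -> the draft model then forces rope_type="default"
```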
python/sglang/srt/models/qwen2.py

```diff
@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)
+
@@ -481,6 +484,10 @@ class Qwen2ForCausalLM(nn.Module):
+        aux_hidden_states = None
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states

         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
```
python/sglang/srt/models/qwen2_5_vl.py

```diff
@@ -518,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -588,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )

+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -644,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

 EntryClass = [Qwen2_5_VLForConditionalGeneration]
```
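When no explicit layer ids are given, `set_eagle3_layers_to_capture` taps a low, middle, and high layer of the trunk, the usual EAGLE3 choice of auxiliary hidden states. For a 28-layer trunk (the layer count is cited only for illustration) that works out as:

```python
num_layers = 28  # illustrative trunk depth, e.g. a 7B-scale text decoder
layers_to_capture = [2, num_layers // 2, num_layers - 3]
print(layers_to_capture)  # [2, 14, 25] -> low / middle / high feature taps
```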
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
```
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

```diff
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]
@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
```
python/sglang/srt/speculative/eagle_worker.py

```diff
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
```