sglang commit 87a0f7d2 (Unverified)
Authored Aug 30, 2025 by KerwinKai; committed by GitHub on Aug 29, 2025

[feat] Support EAGLE3 for Qwen2 (#9216)
Parent: 839c93bd

Showing 2 changed files with 48 additions and 5 deletions (+48, -5):

python/sglang/srt/models/qwen2.py (+26, -3)
python/sglang/srt/models/qwen2_moe.py (+22, -2)
python/sglang/srt/models/qwen2.py
@@ -16,7 +16,7 @@
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
-
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False

     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )

+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
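
When capture_aux_hidden_states is set, the underlying Qwen2Model returns a (hidden_states, aux_hidden_states) pair instead of a bare tensor, which the unpack above relies on; the extra states are then passed through to the logits processor so an EAGLE3 draft head can consume them. Below is a minimal sketch of that capture convention with a toy decoder; everything except the layers_to_capture attribute and the pair-return shape is invented for illustration and is not the real Qwen2Model.

import torch
from torch import nn

class TinyDecoder(nn.Module):
    # Toy stand-in for Qwen2Model; only the capture plumbing matters here.
    def __init__(self, num_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers)
        )
        self.layers_to_capture = []  # filled in by set_eagle3_layers_to_capture

    def forward(self, hidden_states: torch.Tensor):
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                aux_hidden_states.append(hidden_states)  # capture the layer input
            hidden_states = layer(hidden_states)
        if not aux_hidden_states:
            return hidden_states
        # Return a pair so the caller can unpack exactly as the diff does.
        return hidden_states, aux_hidden_states

decoder = TinyDecoder()
decoder.layers_to_capture = [2, 4, 5]
hidden_states = decoder(torch.randn(2, 16))
hidden_states, aux_hidden_states = hidden_states  # same unpack as in the hunk
print(hidden_states.shape, len(aux_hidden_states))  # torch.Size([2, 16]) 3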
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 EntryClass = Qwen2ForCausalLM
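
The default branch of set_eagle3_layers_to_capture picks one early, one middle, and one late layer, the usual EAGLE3 recipe for handing the draft head low-, mid-, and high-level features. As a worked example, assuming a 28-layer stack (the num_hidden_layers of Qwen2-7B), the defaults come out to [2, 14, 25]; the hypothetical helper below just restates that arithmetic.

def default_eagle3_capture_layers(num_layers: int) -> list:
    # Mirrors the default in the diff: one early, one middle, one late layer.
    return [2, num_layers // 2, num_layers - 3]

print(default_eagle3_capture_layers(28))  # [2, 14, 25]
print(default_eagle3_capture_layers(24))  # [2, 12, 21]

Note also that explicitly supplied layer_ids are stored shifted by one (val + 1), presumably because capture happens on a layer's input, so a caller's 0-based index lands on the output of that layer.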
python/sglang/srt/models/qwen2_moe.py
@@ -17,7 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False

     @torch.no_grad()
     def forward(
@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
             num_groups=None,
         )

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 EntryClass = Qwen2MoeForCausalLM
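
On the consumer side, EAGLE3-style draft models typically concatenate the captured per-layer states and project them back down to the model width before drafting. Nothing below comes from this commit; it is a hedged sketch of that downstream step with toy dimensions and an invented fuse projection.

import torch
from torch import nn

hidden_size, num_captured = 16, 3  # toy sizes, not Qwen2's real dims

# Hypothetical fusion layer; EAGLE3 draft heads commonly map
# num_captured * hidden_size back down to hidden_size.
fuse = nn.Linear(num_captured * hidden_size, hidden_size)

# aux_hidden_states as a target model might hand them over: one tensor
# per captured layer, here for a batch of 2 token positions.
aux_hidden_states = [torch.randn(2, hidden_size) for _ in range(num_captured)]
draft_input = fuse(torch.cat(aux_hidden_states, dim=-1))
print(draft_input.shape)  # torch.Size([2, 16])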