Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d07a3d6
Unverified
Commit
9d07a3d6
authored
Mar 11, 2026
by
Rahul Tuli
Committed by
GitHub
Mar 11, 2026
Browse files
Add: Eagle3 support for Qwen3.5 (#36658)
Signed-off-by:
Rahul-Tuli
<
rtuli@redhat.com
>
parent
646b8554
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
2 deletions
+25
-2
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+11
-0
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+14
-2
No files found.
vllm/model_executor/models/qwen3_5.py
View file @
9d07a3d6
...
@@ -75,6 +75,7 @@ from .interfaces import (
...
@@ -75,6 +75,7 @@ from .interfaces import (
IsHybrid
,
IsHybrid
,
MixtureOfExperts
,
MixtureOfExperts
,
MultiModalEmbeddings
,
MultiModalEmbeddings
,
SupportsEagle3
,
SupportsLoRA
,
SupportsLoRA
,
SupportsPP
,
SupportsPP
,
_require_is_multimodal
,
_require_is_multimodal
,
...
@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel):
else
:
else
:
self
.
norm
=
PPMissingLayer
()
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
()
def
load_fused_expert_weights
(
def
load_fused_expert_weights
(
self
,
self
,
name
:
str
,
name
:
str
,
...
@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel):
...
@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel):
class
Qwen3_5ForCausalLMBase
(
class
Qwen3_5ForCausalLMBase
(
nn
.
Module
,
nn
.
Module
,
HasInnerState
,
HasInnerState
,
SupportsEagle3
,
SupportsLoRA
,
SupportsLoRA
,
SupportsPP
,
SupportsPP
,
):
):
...
@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase(
...
@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase(
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
return
self
.
model
.
embed_input_ids
(
input_ids
)
def
set_aux_hidden_state_layers
(
self
,
layers
:
tuple
[
int
,
...])
->
None
:
self
.
model
.
aux_hidden_state_layers
=
layers
def
get_eagle3_aux_hidden_state_layers
(
self
)
->
tuple
[
int
,
...]:
num_layers
=
len
(
self
.
model
.
layers
)
return
(
2
,
num_layers
//
2
,
num_layers
-
3
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/qwen3_next.py
View file @
9d07a3d6
...
@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module):
...
@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module):
else
:
else
:
self
.
norm
=
PPMissingLayer
()
self
.
norm
=
PPMissingLayer
()
self
.
aux_hidden_state_layers
:
tuple
[
int
,
...]
=
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module):
...
@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
IntermediateTensors
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
:
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
...
@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module):
...
@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
):
aux_hidden_states
=
[]
for
layer_idx
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
),
start
=
self
.
start_layer
,
):
if
layer_idx
in
self
.
aux_hidden_state_layers
:
aux_hidden_states
.
append
(
hidden_states
+
residual
if
residual
is
not
None
else
hidden_states
)
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module):
...
@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module):
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
aux_hidden_states
:
return
hidden_states
,
aux_hidden_states
return
hidden_states
return
hidden_states
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
def
get_expert_mapping
(
self
)
->
list
[
tuple
[
str
,
str
,
int
,
str
]]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment