Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa84e43c
Unverified
Commit
aa84e43c
authored
Mar 20, 2026
by
Rémi Delacourt
Committed by
GitHub
Mar 20, 2026
Browse files
[Pixtral] Enable Pixtral language model support Eagle3 (#37182)
Signed-off-by:
remi
<
remi@mistral.ai
>
parent
5e806bcf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
1 deletion
+18
-1
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+18
-1
No files found.
vllm/model_executor/models/pixtral.py
View file @
aa84e43c
...
@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
...
@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.interfaces
import
(
from
.interfaces
import
(
MultiModalEmbeddings
,
MultiModalEmbeddings
,
SupportsEagle3
,
SupportsLoRA
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
,
SupportsPP
,
supports_eagle3
,
)
)
from
.module_mapping
import
MultiModelKeys
from
.module_mapping
import
MultiModelKeys
from
.utils
import
StageMissingLayer
,
init_vllm_registered_model
,
maybe_prefix
from
.utils
import
StageMissingLayer
,
init_vllm_registered_model
,
maybe_prefix
...
@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
...
@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
dummy_inputs
=
PixtralDummyInputsBuilder
,
dummy_inputs
=
PixtralDummyInputsBuilder
,
)
)
class
PixtralForConditionalGeneration
(
class
PixtralForConditionalGeneration
(
nn
.
Module
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
nn
.
Module
,
SupportsLoRA
,
SupportsEagle3
,
SupportsMultiModal
,
SupportsPP
):
):
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
...
@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration(
...
@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration(
)
->
torch
.
Tensor
|
None
:
)
->
torch
.
Tensor
|
None
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
_require_language_model_eagle3
(
self
)
->
None
:
if
not
supports_eagle3
(
self
.
language_model
):
raise
RuntimeError
(
f
"EAGLE-3 speculative decoding requires the language model to "
f
"support EAGLE-3, but
{
type
(
self
.
language_model
).
__name__
}
does not."
)
def
set_aux_hidden_state_layers
(
self
,
layers
:
tuple
[
int
,
...])
->
None
:
self
.
_require_language_model_eagle3
()
self
.
language_model
.
set_aux_hidden_state_layers
(
layers
)
def
get_eagle3_aux_hidden_state_layers
(
self
)
->
tuple
[
int
,
...]:
self
.
_require_language_model_eagle3
()
return
self
.
language_model
.
get_eagle3_aux_hidden_state_layers
()
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
is_vision_encoder_weights
(
weight
:
tuple
[
str
,
torch
.
Tensor
]):
def
is_vision_encoder_weights
(
weight
:
tuple
[
str
,
torch
.
Tensor
]):
return
weight
[
0
].
startswith
((
"vision_encoder"
,
"vision_tower"
))
return
weight
[
0
].
startswith
((
"vision_encoder"
,
"vision_tower"
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment