Unverified Commit aa84e43c authored by Rémi Delacourt's avatar Rémi Delacourt Committed by GitHub
Browse files

[Pixtral] Enable Pixtral language model support Eagle3 (#37182)


Signed-off-by: default avatarremi <remi@mistral.ai>
parent 5e806bcf
...@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
supports_eagle3,
) )
from .module_mapping import MultiModelKeys from .module_mapping import MultiModelKeys
from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
...@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) ...@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
dummy_inputs=PixtralDummyInputsBuilder, dummy_inputs=PixtralDummyInputsBuilder,
) )
class PixtralForConditionalGeneration( class PixtralForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP
): ):
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
...@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration( ...@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration(
) -> torch.Tensor | None: ) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def _require_language_model_eagle3(self) -> None:
if not supports_eagle3(self.language_model):
raise RuntimeError(
f"EAGLE-3 speculative decoding requires the language model to "
f"support EAGLE-3, but {type(self.language_model).__name__} does not."
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self._require_language_model_eagle3()
self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
self._require_language_model_eagle3()
return self.language_model.get_eagle3_aux_hidden_state_layers()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith(("vision_encoder", "vision_tower")) return weight[0].startswith(("vision_encoder", "vision_tower"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment