Unverified Commit 4df44c16 authored by TundeAtSN's avatar TundeAtSN Committed by GitHub
Browse files

Enable Eagle3 speculative decoding for Mistral3ForConditionalGeneration to support eagle3 (#33939)


Signed-off-by: default avatarAkintunde Oladipo <akintunde.oladipo@servicenow.com>
Signed-off-by: default avatarTundeAtSN <akintunde.oladipo@servicenow.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 81fe69ca
...@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
...@@ -408,7 +409,7 @@ def init_vision_tower_for_llava( ...@@ -408,7 +409,7 @@ def init_vision_tower_for_llava(
dummy_inputs=Mistral3DummyInputsBuilder, dummy_inputs=Mistral3DummyInputsBuilder,
) )
class Mistral3ForConditionalGeneration( class Mistral3ForConditionalGeneration(
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
...@@ -432,6 +433,13 @@ class Mistral3ForConditionalGeneration( ...@@ -432,6 +433,13 @@ class Mistral3ForConditionalGeneration(
raise ValueError("Only image modality is supported") raise ValueError("Only image modality is supported")
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.get_language_model().model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.get_language_model().model.layers)
return (2, num_layers // 2, num_layers - 3)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment