Unverified Commit 5a4b4b37 authored by Rahul Tuli's avatar Rahul Tuli Committed by GitHub
Browse files

Add: `SupportsEagle3` interface for explicit EAGLE3 support (#22642)


Signed-off-by: default avatarRahul Tuli <rtuli@redhat.com>
parent e5d3d63c
......@@ -3,12 +3,20 @@
import pytest
import torch
from vllm.model_executor.models.interfaces import supports_eagle3
@pytest.mark.parametrize(
"model_path",
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
def test_llama(vllm_runner, example_prompts, model_path):
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
......@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
@pytest.mark.parametrize(
"model_path",
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
def test_qwen(vllm_runner, example_prompts, model_path):
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
# Set environment variable for V1 engine serialization
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
eagle3_supported = vllm_model.apply_model(supports_eagle3)
assert eagle3_supported
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens=20)
print(vllm_outputs)
......
......@@ -823,3 +823,56 @@ def supports_v0_only(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
return getattr(model, "supports_v0_only", False)
@runtime_checkable
class SupportsEagle3(Protocol):
"""The interface required for models that support
EAGLE3 speculative decoding."""
supports_eagle3: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports EAGLE3
speculative decoding.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""
Set which layers should output auxiliary
hidden states for EAGLE3.
Args:
layers: Tuple of layer indices that should output auxiliary
hidden states.
"""
...
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""
Get the layer indices that should output auxiliary hidden states
for EAGLE3.
Returns:
Tuple of layer indices for auxiliary hidden state outputs.
"""
...
@overload
def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]:
...
@overload
def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]:
...
def supports_eagle3(
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
return isinstance(model, SupportsEagle3)
......@@ -49,7 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
......@@ -463,7 +463,7 @@ class LlamaModel(nn.Module):
return loaded_params
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
......
......@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
......@@ -261,7 +261,7 @@ class Qwen3Model(Qwen2Model):
decoder_layer_type=Qwen3DecoderLayer)
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
supports_eagle3,
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
......@@ -1981,8 +1982,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
if supports_eagle3(self.model):
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested")
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info("Model loading took %.4f GiB and %.6f seconds",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment