Unverified Commit 8781cd6b authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Add Eagle and Eagle3 support to Transformers modeling backend (#30340)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent aa3c250c
...@@ -280,9 +280,20 @@ def test_speculators_model_integration( ...@@ -280,9 +280,20 @@ def test_speculators_model_integration(
@pytest.mark.parametrize( @pytest.mark.parametrize(
["model_setup", "mm_enabled", "enable_chunked_prefill"], ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
[ [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False), (
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"auto",
),
(
("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
False,
False,
"transformers",
),
pytest.param( pytest.param(
( (
"eagle3", "eagle3",
...@@ -292,6 +303,7 @@ def test_speculators_model_integration( ...@@ -292,6 +303,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason="architecture of its eagle3 is LlamaForCausalLMEagle3" reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
), ),
...@@ -305,6 +317,7 @@ def test_speculators_model_integration( ...@@ -305,6 +317,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason="Skipping due to its head_dim not being a a multiple of 32" reason="Skipping due to its head_dim not being a a multiple of 32"
), ),
...@@ -318,6 +331,7 @@ def test_speculators_model_integration( ...@@ -318,6 +331,7 @@ def test_speculators_model_integration(
), ),
False, False,
True, True,
"auto",
marks=large_gpu_mark(min_gb=40), marks=large_gpu_mark(min_gb=40),
), # works on 4x H100 ), # works on 4x H100
( (
...@@ -329,6 +343,7 @@ def test_speculators_model_integration( ...@@ -329,6 +343,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
), ),
pytest.param( pytest.param(
( (
...@@ -339,6 +354,7 @@ def test_speculators_model_integration( ...@@ -339,6 +354,7 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
pytest.param( pytest.param(
...@@ -350,6 +366,7 @@ def test_speculators_model_integration( ...@@ -350,6 +366,7 @@ def test_speculators_model_integration(
), ),
True, True,
True, True,
"auto",
marks=large_gpu_mark(min_gb=80), marks=large_gpu_mark(min_gb=80),
), # works on 4x H100 ), # works on 4x H100
( (
...@@ -361,10 +378,12 @@ def test_speculators_model_integration( ...@@ -361,10 +378,12 @@ def test_speculators_model_integration(
), ),
False, False,
False, False,
"auto",
), ),
], ],
ids=[ ids=[
"qwen3_eagle3", "qwen3_eagle3",
"qwen3_eagle3-transformers",
"qwen3_vl_eagle3", "qwen3_vl_eagle3",
"qwen2_5_vl_eagle3", "qwen2_5_vl_eagle3",
"llama3_eagle", "llama3_eagle",
...@@ -381,6 +400,7 @@ def test_eagle_correctness( ...@@ -381,6 +400,7 @@ def test_eagle_correctness(
model_setup: tuple[str, str, str, int], model_setup: tuple[str, str, str, int],
mm_enabled: bool, mm_enabled: bool,
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
model_impl: str,
attn_backend: str, attn_backend: str,
): ):
if attn_backend == "TREE_ATTN": if attn_backend == "TREE_ATTN":
...@@ -389,6 +409,17 @@ def test_eagle_correctness( ...@@ -389,6 +409,17 @@ def test_eagle_correctness(
"TREE_ATTN is flaky in the test disable for now until it can be " "TREE_ATTN is flaky in the test disable for now until it can be "
"resolved (see https://github.com/vllm-project/vllm/issues/22922)" "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
) )
if model_impl == "transformers":
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("5.0.0.dev")
if installed < required:
pytest.skip(
"Eagle3 with the Transformers modeling backend requires "
f"transformers>={required}, but got {installed}"
)
# Generate test prompts inside the function instead of using fixture # Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled) test_prompts = get_test_prompts(mm_enabled)
...@@ -448,6 +479,7 @@ def test_eagle_correctness( ...@@ -448,6 +479,7 @@ def test_eagle_correctness(
max_model_len=max_model_len, max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
) )
spec_outputs = spec_llm.chat(test_prompts, sampling_config) spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0 matches = 0
......
...@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices ...@@ -36,6 +36,8 @@ from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
...@@ -92,7 +94,15 @@ def vllm_flash_attention_forward( ...@@ -92,7 +94,15 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): class Base(
nn.Module,
VllmModel,
SupportsQuant,
SupportsLoRA,
SupportsPP,
SupportsEagle,
SupportsEagle3,
):
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
...@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -131,17 +141,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
self.pp_group = get_pp_group() self.pp_group = get_pp_group()
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
# Weights to skip in `self.load_weights` # Attrs for weight loading (see self.load_weights)
self.skip_prefixes: list[str] = [] self.skip_prefixes: list[str] = []
"""Skip loading weights whose qualname starts with these prefixes.""" """Skip loading weights whose qualname starts with these prefixes."""
self.skip_substrs: list[str] = [] self.skip_substrs: list[str] = []
"""Skip loading weights whose qualname contains these substrings.""" """Skip loading weights whose qualname contains these substrings."""
self.ignore_unexpected_prefixes: list[str] = [] self.ignore_unexpected_prefixes: list[str] = []
"""Ignore unexpected weights whose qualname starts with these prefixes. """Ignore unexpected weights whose qualname starts with these prefixes."""
"""
self.ignore_unexpected_suffixes: list[str] = [] self.ignore_unexpected_suffixes: list[str] = []
"""Ignore unexpected weights whose qualname ends with these suffixes.""" """Ignore unexpected weights whose qualname ends with these suffixes."""
# Attrs for Eagle3 (see self.set_aux_hidden_state_layers)
self._target_class: type[nn.Module] = nn.Module
"""Target class for Eagle3 aux hidden state recording."""
self._layer_names: dict[int, str] = {}
"""Mapping from layer index to layer name for Eagle3."""
self._output_aux_hidden_states_kwargs: dict[str, bool] = {}
"""Kwargs to pass to model forward for Eagle3 aux hidden states."""
if self.quant_config: if self.quant_config:
quant_method_name = self.quant_config.get_name() quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods. # Check for unsupported quantization methods.
...@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -278,6 +295,15 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
for child_name, child_module in module.named_children(): for child_name, child_module in module.named_children():
new_module = child_module new_module = child_module
qual_name = maybe_prefix(prefix, child_name) qual_name = maybe_prefix(prefix, child_name)
# Populate Eagle3 attrs
if (
isinstance(module, nn.ModuleList)
and len(module) == self.text_config.num_hidden_layers
):
self._target_class = type(child_module)
layer_name = qual_name.removeprefix("model.")
self._layer_names[int(child_name)] = layer_name
# Replace modules as needed
if isinstance(child_module, nn.Linear): if isinstance(child_module, nn.Linear):
generator = (p for p in tp_plan if re.match(p, qual_name)) generator = (p for p in tp_plan if re.match(p, qual_name))
pattern = next(generator, None) pattern = next(generator, None)
...@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -425,19 +451,26 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
else: else:
position_ids = positions[None, ...] position_ids = positions[None, ...]
hidden_states = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=False, use_cache=False,
position_ids=position_ids, position_ids=position_ids,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False, return_dict=False,
**self._output_aux_hidden_states_kwargs,
**kwargs, **kwargs,
)[0][0, ...] # we remove batch dimension for now )
# We must remove the batch dimension from these outputs
hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
if not self.pp_group.is_last_rank: if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
if self._output_aux_hidden_states_kwargs and len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def load_weights( def load_weights(
...@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP): ...@@ -462,3 +495,24 @@ class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
f"Transformers modeling backend requires transformers>={required} " f"Transformers modeling backend requires transformers>={required} "
f"for {feature}, but got {installed}" f"for {feature}, but got {installed}"
) )
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.check_version("5.0.0.dev0", "Eagle3 support")
from transformers.utils.generic import OutputRecorder
# The default value in PreTrainedModel is None
if self.model._can_record_outputs is None:
self.model._can_record_outputs = {}
target_class = self._target_class
for layer in layers:
# layer - 1 because we want the input to the layer
layer_name = self._layer_names[layer - 1]
layer_key = f"aux_hidden_state_{layer}"
aux_hidden_state_i = OutputRecorder(target_class, layer_name=layer_name)
self.model._can_record_outputs[layer_key] = aux_hidden_state_i
self._output_aux_hidden_states_kwargs[f"output_{layer_key}"] = True
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = self.text_config.num_hidden_layers
return (2, num_layers // 2, num_layers - 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment