Unverified Commit bd92089d authored by Iris's avatar Iris Committed by GitHub
Browse files

feature: support eagle3 for HunyuanVL & Hunyuan (#33035)


Signed-off-by: default avataririsliu10 <601012173@qq.com>
Signed-off-by: default avatarIris <38269816+irisliu10@users.noreply.github.com>
parent a6760f15
...@@ -675,7 +675,14 @@ class SpeculativeConfig: ...@@ -675,7 +675,14 @@ class SpeculativeConfig:
f"{self.disable_by_batch_size=}" f"{self.disable_by_batch_size=}"
) )
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"] eagle3_target_supported = [
"llama",
"qwen",
"minicpm",
"gpt_oss",
"hunyuan_vl",
"hunyuan_v1_dense",
]
if ( if (
self.method == "eagle3" self.method == "eagle3"
and self.target_model_config and self.target_model_config
......
...@@ -66,7 +66,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -66,7 +66,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
...@@ -630,6 +630,7 @@ class HunYuanModel(nn.Module): ...@@ -630,6 +630,7 @@ class HunYuanModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -654,9 +655,13 @@ class HunYuanModel(nn.Module): ...@@ -654,9 +655,13 @@ class HunYuanModel(nn.Module):
cla_factor = _get_cla_factor(self.config) cla_factor = _get_cla_factor(self.config)
prev_kv_states = None prev_kv_states = None
aux_hidden_states = []
for i, layer in enumerate( for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual, kv_states = layer( hidden_states, residual, kv_states = layer(
positions, positions,
hidden_states, hidden_states,
...@@ -675,6 +680,9 @@ class HunYuanModel(nn.Module): ...@@ -675,6 +680,9 @@ class HunYuanModel(nn.Module):
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def _split_qkv_weight(self, qkv: torch.Tensor): def _split_qkv_weight(self, qkv: torch.Tensor):
...@@ -897,7 +905,7 @@ class HunYuanModel(nn.Module): ...@@ -897,7 +905,7 @@ class HunYuanModel(nn.Module):
return loaded_params return loaded_params
class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -936,6 +944,13 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP): ...@@ -936,6 +944,13 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
......
...@@ -83,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape ...@@ -83,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
...@@ -780,6 +781,7 @@ class HunYuanVLForConditionalGeneration( ...@@ -780,6 +781,7 @@ class HunYuanVLForConditionalGeneration(
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
SupportsXDRoPE, SupportsXDRoPE,
SupportsEagle3,
): ):
# To ensure correct weight loading and mapping. # To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
...@@ -966,6 +968,13 @@ class HunYuanVLForConditionalGeneration( ...@@ -966,6 +968,13 @@ class HunYuanVLForConditionalGeneration(
multimodal_embeddings += tuple(image_embeddings) multimodal_embeddings += tuple(image_embeddings)
return multimodal_embeddings return multimodal_embeddings
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
......
...@@ -115,6 +115,8 @@ class SpecDecodeBaseProposer: ...@@ -115,6 +115,8 @@ class SpecDecodeBaseProposer:
# Use draft model's M-RoPE setting, not target model's # Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal # Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope self.uses_mrope = self.draft_model_config.uses_mrope
self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
if self.uses_mrope: if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy # NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work # position on purpose to make it non-contiguous so that it can work
...@@ -129,6 +131,12 @@ class SpecDecodeBaseProposer: ...@@ -129,6 +131,12 @@ class SpecDecodeBaseProposer:
self.mrope_positions = torch.zeros( self.mrope_positions = torch.zeros(
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
) )
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions = torch.zeros(
(self.uses_xdrope_dim, self.max_num_tokens + 1),
dtype=torch.int64,
device=device,
)
else: else:
# RoPE need (max_num_tokens,) # RoPE need (max_num_tokens,)
self.positions = torch.zeros( self.positions = torch.zeros(
...@@ -221,11 +229,15 @@ class SpecDecodeBaseProposer: ...@@ -221,11 +229,15 @@ class SpecDecodeBaseProposer:
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions[:, :num_tokens] return self.mrope_positions[:, :num_tokens]
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
return self.xdrope_positions[:, :num_tokens]
return self.positions[:num_tokens] return self.positions[:num_tokens]
def _set_positions(self, num_tokens: int, positions: torch.Tensor): def _set_positions(self, num_tokens: int, positions: torch.Tensor):
if self.uses_mrope: if self.uses_mrope:
self.mrope_positions[:, :num_tokens] = positions self.mrope_positions[:, :num_tokens] = positions
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions[:, :num_tokens] = positions
else: else:
# Convert M-RoPE positions if target model uses M-RoPE # Convert M-RoPE positions if target model uses M-RoPE
# but draft doesn't, For text inputs, all M-RoPE # but draft doesn't, For text inputs, all M-RoPE
...@@ -623,6 +635,8 @@ class SpecDecodeBaseProposer: ...@@ -623,6 +635,8 @@ class SpecDecodeBaseProposer:
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0]
self._set_positions(num_tokens, target_positions) self._set_positions(num_tokens, target_positions)
return num_tokens, last_token_indices, cad return num_tokens, last_token_indices, cad
...@@ -1126,6 +1140,7 @@ class SpecDecodeBaseProposer: ...@@ -1126,6 +1140,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration", "Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration", "GlmOcrForConditionalGeneration",
]: ]:
self.model.config.image_token_index = target_model.config.image_token_id self.model.config.image_token_index = target_model.config.image_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment