"vllm/vscode:/vscode.git/clone" did not exist on "55ddae82d7992c8ce0923adbf9a41158cf6fccb2"
Unverified Commit 9d07a3d6 authored by Rahul Tuli's avatar Rahul Tuli Committed by GitHub
Browse files

Add: Eagle3 support for Qwen3.5 (#36658)


Signed-off-by: default avatarRahul-Tuli <rtuli@redhat.com>
parent 646b8554
...@@ -75,6 +75,7 @@ from .interfaces import ( ...@@ -75,6 +75,7 @@ from .interfaces import (
IsHybrid, IsHybrid,
MixtureOfExperts, MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
_require_is_multimodal, _require_is_multimodal,
...@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel): ...@@ -353,6 +354,8 @@ class Qwen3_5Model(Qwen3NextModel):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def load_fused_expert_weights( def load_fused_expert_weights(
self, self,
name: str, name: str,
...@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel): ...@@ -536,6 +539,7 @@ class Qwen3_5Model(Qwen3NextModel):
class Qwen3_5ForCausalLMBase( class Qwen3_5ForCausalLMBase(
nn.Module, nn.Module,
HasInnerState, HasInnerState,
SupportsEagle3,
SupportsLoRA, SupportsLoRA,
SupportsPP, SupportsPP,
): ):
...@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase( ...@@ -592,6 +596,13 @@ class Qwen3_5ForCausalLMBase(
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids) return self.model.embed_input_ids(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module): ...@@ -1148,6 +1148,8 @@ class Qwen3NextModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int, ...] = ()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module): ...@@ -1157,7 +1159,7 @@ class Qwen3NextModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module): ...@@ -1169,7 +1171,15 @@ class Qwen3NextModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer): aux_hidden_states = []
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer( hidden_states, residual = layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module): ...@@ -1181,6 +1191,8 @@ class Qwen3NextModel(nn.Module):
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if aux_hidden_states:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment