Unverified commit 1964c325, authored by Ximingwang-09, committed by GitHub

[feat] Support EAGLE3 for Qwen (#7745)


Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Co-authored-by: zyksir <zyksir@outlook.com>
parent af564774
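
EAGLE3 speculative decoding trains its draft head on hidden states taken from a few intermediate layers of the target model, so this change threads an `aux_hidden_states` list through the Qwen2 and Qwen2-MoE backbones and exposes a `set_eagle3_layers_to_capture` hook on the Qwen3 and Qwen3-MoE wrappers. As orientation before the diff, here is a minimal sketch (not part of the patch) of the default layer-selection heuristic that hook applies when no explicit `layer_ids` are passed; the 28-layer depth is an assumed, illustrative value.

```python
# Illustrative only: the default capture heuristic applied by
# set_eagle3_layers_to_capture(layer_ids=None) in the diff below.
num_hidden_layers = 28  # assumed depth, purely for illustration

layers_to_capture = [
    2,                       # an early layer
    num_hidden_layers // 2,  # a middle layer
    num_hidden_layers - 3,   # a late layer, short of the final one
]
print(layers_to_capture)  # [2, 14, 25]
```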
@@ -293,6 +293,9 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
             return self.get_input_embeddings()(input_ids) * self.config.scale_emb
@@ -321,7 +324,12 @@ class Qwen2Model(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -342,7 +350,11 @@ class Qwen2Model(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
...
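
With the hunks above applied, `Qwen2Model.forward` (and `Qwen2MoeModel.forward` in the next file) returns either a bare `hidden_states` tensor or a `(hidden_states, aux_hidden_states)` tuple, depending on whether any indices are registered in `layers_to_capture`. Callers therefore need to normalize the result; the helper below is a hypothetical sketch of that caller-side handling, not code from this PR.

```python
from typing import List, Optional, Tuple, Union

import torch


def unpack_backbone_output(
    out: Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
    """Normalize the backbone's conditional return value.

    With EAGLE3 capture enabled the backbone returns
    (hidden_states, aux_hidden_states); otherwise it returns hidden_states alone.
    """
    if isinstance(out, tuple):
        hidden_states, aux_hidden_states = out
        return hidden_states, aux_hidden_states
    return out, None
```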
@@ -440,6 +440,9 @@ class Qwen2MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        # For EAGLE3 support
+        self.layers_to_capture = []
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -459,6 +462,7 @@ class Qwen2MoeModel(nn.Module):
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         if forward_batch.can_run_tbo:
             hidden_states, residual = model_forward_maybe_tbo(
                 layers=self.layers,
@@ -471,6 +475,12 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(
+                        hidden_states + residual
+                        if residual is not None
+                        else hidden_states
+                    )
                 with get_global_expert_distribution_recorder().with_current_layer(i):
                     layer = self.layers[i]
                     hidden_states, residual = layer(
@@ -489,7 +499,11 @@ class Qwen2MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
 
 class Qwen2MoeForCausalLM(nn.Module):
...
@@ -2,7 +2,7 @@
 import logging
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -325,6 +325,9 @@ class Qwen3ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -346,10 +349,18 @@ class Qwen3ForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -447,5 +458,20 @@ class Qwen3ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
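
Note the off-by-one handling in `set_eagle3_layers_to_capture`: the capture check in the backbone loop runs before layer `i` executes, so it records the input of layer `i`, which is the output of layer `i - 1`. Registering `val + 1` for a user-supplied id `val` therefore captures the output of layer `val`. A small, self-contained illustration with made-up layer ids:

```python
# Made-up layer ids, purely to show the +1 mapping from the method above.
layer_ids = [1, 13, 25]
layers_to_capture = [val + 1 for val in layer_ids]
print(layers_to_capture)  # [2, 14, 26] -> captures the outputs of layers 1, 13 and 25
```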
@@ -18,7 +18,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -717,6 +717,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -735,9 +736,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             pp_proxy_tensors=pp_proxy_tensors,
        )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -750,6 +755,24 @@ class Qwen3MoeForCausalLM(nn.Module):
     def end_layer(self):
         return self.model.end_layer
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
...
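
The captured states are meant to be consumed by an EAGLE3 draft model, which typically concatenates the low/mid/high features along the hidden dimension and projects them back to the model width before the draft decoder runs. The snippet below is a standalone illustration of that fusion step with made-up shapes and a hypothetical projection layer; none of it comes from this PR.

```python
import torch

# Stand-in shapes: 5 tokens, hidden size 64, three captured layers.
num_tokens, hidden_size = 5, 64
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]

# EAGLE3-style fusion: concatenate the captured features along the hidden
# dimension, then project back to the model width (illustrative layer).
draft_input = torch.cat(aux_hidden_states, dim=-1)         # [5, 192]
draft_fc = torch.nn.Linear(3 * hidden_size, hidden_size)   # hypothetical projection
draft_hidden = draft_fc(draft_input)                        # [5, 64]
```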