Unverified Commit 87a0f7d2 authored by KerwinKai's avatar KerwinKai Committed by GitHub
Browse files

[feat] Support EAGLE3 for Qwen2 (#9216)

parent 839c93bd
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Modify details for the adaptation of Qwen2 model. # Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
import logging import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
) )
else: else:
# ranks other than the last rank will have a placeholder layer # ranks other than the last rank will have a placeholder layer
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
...@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# For EAGLE3 support
self.capture_aux_hidden_states = False
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up input embeddings by delegating to the underlying model."""
    embeddings = self.model.get_input_embedding(input_ids)
    return embeddings
...@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
input_embeds, input_embeds,
pp_proxy_tensors=pp_proxy_tensors, pp_proxy_tensors=pp_proxy_tensors,
) )
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
if not get_embedding: if not get_embedding:
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
) )
else: else:
return self.pooler(hidden_states, forward_batch) return self.pooler(hidden_states, forward_batch)
...@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
def load_kv_cache_scales(self, quantization_param_path: str) -> None: def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path) self.model.load_kv_cache_scales(quantization_param_path)
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Enable auxiliary hidden-state capture for EAGLE3 speculative decoding.

    Only takes effect on the last pipeline-parallel rank, where the captured
    states are consumed alongside the logits. Sets
    ``self.capture_aux_hidden_states`` and ``self.model.layers_to_capture``.

    Args:
        layer_ids: Decoder-layer indices to capture. When ``None``, a default
            trio (one shallow, one middle, one deep layer) is selected from
            ``config.num_hidden_layers``.
    """
    if not self.pp_group.is_last_rank:
        # Aux hidden states are only needed on the rank that produces logits.
        return

    self.capture_aux_hidden_states = True
    if layer_ids is None:
        total = self.config.num_hidden_layers
        # Default EAGLE3 selection: shallow / middle / deep layers.
        self.model.layers_to_capture = [2, total // 2, total - 3]
    else:
        # +1 offset — presumably the model captures the output *after* each
        # requested layer index; confirm against the model's capture logic.
        self.model.layers_to_capture = [layer_id + 1 for layer_id in layer_ids]
EntryClass = Qwen2ForCausalLM EntryClass = Qwen2ForCausalLM
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
# For EAGLE3 support
self.capture_aux_hidden_states = False
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds, input_embeds,
pp_proxy_tensors=pp_proxy_tensors, pp_proxy_tensors=pp_proxy_tensors,
) )
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
) )
else: else:
return hidden_states return hidden_states
...@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
num_groups=None, num_groups=None,
) )
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
    """Turn on auxiliary hidden-state capture used by EAGLE3 drafting.

    A no-op on all but the last pipeline-parallel rank. Records the chosen
    layer indices on ``self.model.layers_to_capture`` and flips
    ``self.capture_aux_hidden_states`` so ``forward`` unpacks them.

    Args:
        layer_ids: Decoder-layer indices to capture; ``None`` picks a
            default shallow/middle/deep triple from the config.
    """
    if not self.pp_group.is_last_rank:
        # Only the logits-producing rank needs the aux states.
        return

    self.capture_aux_hidden_states = True
    if layer_ids is not None:
        # +1 offset — presumably capture happens after each requested layer;
        # verify against the model's layers_to_capture consumer.
        self.model.layers_to_capture = [idx + 1 for idx in layer_ids]
    else:
        n = self.config.num_hidden_layers
        # Default EAGLE3 picks: layer 2, the midpoint, and third-from-last.
        self.model.layers_to_capture = [2, n // 2, n - 3]
EntryClass = Qwen2MoeForCausalLM EntryClass = Qwen2MoeForCausalLM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment