"vscode:/vscode.git/clone" did not exist on "762be26a8ee0de15638fa21a59d85efedacec847"
Unverified commit 781d0562, authored by Bryan Lu and committed by GitHub
Browse files

[Feature] Enhance EAGLE Architecture with Proper RMS Norms (#14990)


Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 5aefd6ac
...@@ -800,10 +800,18 @@ class ModelConfig: ...@@ -800,10 +800,18 @@ class ModelConfig:
@property @property
def is_deepseek_mla(self) -> bool: def is_deepseek_mla(self) -> bool:
return (hasattr(self.hf_text_config, "model_type")) \ if not hasattr(self.hf_text_config, "model_type"):
and (self.hf_text_config.model_type in \ return False
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'))\ elif self.hf_text_config.model_type in \
and (self.hf_text_config.kv_lora_rank is not None) ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
# underlying architecture
return self.hf_text_config.model.model_type in \
('deepseek_v2', 'deepseek_v3') \
and self.hf_text_config.kv_lora_rank is not None
return False
def get_head_size(self) -> int: def get_head_size(self) -> int:
# TODO remove hard code # TODO remove hard code
......
...@@ -7,6 +7,7 @@ import torch.nn as nn ...@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -59,7 +60,15 @@ class EAGLE(nn.Module): ...@@ -59,7 +60,15 @@ class EAGLE(nn.Module):
truncated_vocab_size < vocab_size. To use this technique, one has to find truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute.
4. We allow an enhanced EAGLE architecture that resembles the DeepSeek MTP
module in its use of additional RMS norms. The original EAGLE
architecture 1) skips the pre-attention norm in its first transformer
block, and 2) skips the final output norm, both of which we found to be
suboptimal. We also add support for separate norms applied to the token
embeddings and to the hidden states before projection, as in DeepSeek
MTP, which we found to improve performance as well.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
...@@ -81,10 +90,23 @@ class EAGLE(nn.Module): ...@@ -81,10 +90,23 @@ class EAGLE(nn.Module):
# While weights and biases are generally not needed, # While weights and biases are generally not needed,
# they are retained here to support certain unit tests # they are retained here to support certain unit tests
# (e.g., spec_decode/e2e/test_eagle_correctness.py). # (e.g., spec_decode/e2e/test_eagle_correctness.py).
if not hasattr(self.config.model,
"skip_prenorm") or self.config.model.skip_prenorm:
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm( self.model.model.layers[0].input_layernorm = DummyInputLayerNorm(
weight=self.model.model.layers[0].input_layernorm.weight) weight=self.model.model.layers[0].input_layernorm.weight)
if not hasattr(
self.config.model,
"skip_output_norm") or self.config.model.skip_output_norm:
self.model.model.norm = DummyOutputNorm() self.model.model.norm = DummyOutputNorm()
self.add_para_norm = False
if hasattr(self.config.model,
"add_para_norm") and self.config.model.add_para_norm:
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.add_para_norm = True
self.orig_vocab_size = config.vocab_size self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size
...@@ -128,8 +150,17 @@ class EAGLE(nn.Module): ...@@ -128,8 +150,17 @@ class EAGLE(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.fc( if self.add_para_norm:
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) inputs_embeds = torch.cat([
self.enorm(inputs_embeds),
self.hnorm(previous_hidden_states)
],
dim=-1)
else:
inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states],
dim=-1)
inputs_embeds = self.fc(inputs_embeds)
inputs_embeds[positions == 0] = 0 # masking inputs at position=0 inputs_embeds[positions == 0] = 0 # masking inputs at position=0
...@@ -190,6 +221,14 @@ class EAGLE(nn.Module): ...@@ -190,6 +221,14 @@ class EAGLE(nn.Module):
else: else:
logger.warning_once("Found bias in the loaded weights but " logger.warning_once("Found bias in the loaded weights but "
"the model config doesn't have bias.") "the model config doesn't have bias.")
elif name.startswith("enorm.weight"):
weight_loader = getattr(self.enorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.enorm.weight, loaded_weight)
elif name.startswith("hnorm.weight"):
weight_loader = getattr(self.hnorm.weight, "weight_loader",
default_weight_loader)
weight_loader(self.hnorm.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith( elif name.startswith("model.lm_head.") or name.startswith(
"model.model."): "model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight model_weights[name.split("model.", 1)[-1]] = loaded_weight
......
...@@ -5,6 +5,8 @@ from typing import Optional, Union ...@@ -5,6 +5,8 @@ from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
class EAGLEConfig(PretrainedConfig): class EAGLEConfig(PretrainedConfig):
model_type = "eagle" model_type = "eagle"
...@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig): ...@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
truncated_vocab_size: Optional[int] = None, truncated_vocab_size: Optional[int] = None,
**kwargs): **kwargs):
model_config = None if model is None else (AutoConfig.for_model( model_config: Union[PretrainedConfig, DeepseekV2Config, None]
**model) if isinstance(model, dict) else model) if isinstance(model, dict):
archs = model.get("architectures", [])
target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
if any(target_arch in archs for target_arch in target_archs):
# AutoConfig does not support DeepSeek MoE models yet
model_config = DeepseekV2Config(**model)
else:
model_config = AutoConfig.for_model(**model)
else:
model_config = model
for k, v in kwargs.items(): for k, v in kwargs.items():
if k != "architectures" and k != "model_type" and hasattr( if k != "architectures" and k != "model_type" and hasattr(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment