Unverified Commit 886d3449 authored by lukec, committed by GitHub

support llama4 eagle3 (#6985)


Co-authored-by: shuaills <shishuaiuoe@gmail.com>
Co-authored-by: Shenggui Li <somerlee.9@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 637bfee4
......@@ -306,7 +306,26 @@ class ModelRunner:
# auxiliary hidden capture mode. TODO: expose this to server args?
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
self.model.set_eagle3_layers_to_capture()
# load draft config
draft_model_config = ModelConfig.from_server_args(
server_args,
model_path=(server_args.speculative_draft_model_path),
is_draft_model=True,
)
try:
# get the aux layer from draft model config
eagle_config = getattr(
draft_model_config.hf_config, "eagle_config", None
)
eagle_aux_hidden_state_layer_ids = eagle_config[
"eagle_aux_hidden_state_layer_ids"
]
except:
# if there is no aux layer, set to None
eagle_aux_hidden_state_layer_ids = None
self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)
def model_specific_adjustment(self):
server_args = self.server_args
......
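The hunk above loads the draft model's config and asks it which target-model layers should be captured as auxiliary hidden states. A minimal sketch of that lookup, assuming the draft checkpoint's hf_config carries an eagle_config dict shaped like the nvidia Eagle3 checkpoints; the helper name and the example ids are illustrative, not part of the commit:

from typing import List, Optional


def resolve_aux_layer_ids(hf_config) -> Optional[List[int]]:
    # hf_config.eagle_config is expected to look like
    # {"eagle_aux_hidden_state_layer_ids": [1, 23, 44]} when present
    eagle_config = getattr(hf_config, "eagle_config", None)
    try:
        return eagle_config["eagle_aux_hidden_state_layer_ids"]
    except (TypeError, KeyError):
        # no eagle_config, or no aux-layer list: let the model fall back
        # to its built-in default capture layers
        return None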
......@@ -124,6 +124,9 @@ def _get_quantization_config(
quant_config = get_quant_config(
model_config, load_config, packed_modules_mapping
)
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if quant_config is None:
return None
major, minor = get_device_capability()
if major is not None and minor is not None:
......
......@@ -209,6 +209,17 @@ def get_quant_config(
config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if config["quantization"]["quant_algo"] is None:
if (
model_config.hf_config.architectures[0]
!= "LlamaForCausalLMEagle3"
):
raise ValueError(
f"Invalid quant_config, quantization method: {model_config.quantization},"
f"hf architectures: {model_config.hf_config.architectures[0]}. "
)
return None
if "FP4" in config["quantization"]["quant_algo"]:
return ModelOptFp4Config.from_config(config)
else:
......
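Both quantization hunks exist for the same reason: the nvidia/Llama-4-Maverick-17B-128E-Eagle3 draft checkpoint ships a ModelOpt quantization config whose quant_algo is null, so the loader must treat it as "no quantization" rather than fail. A rough sketch of that decision flow; the function name and config shape are illustrative assumptions, not the library's API:

from typing import Any, Dict, Optional


def resolve_modelopt_algo(config: Dict[str, Any], architecture: str) -> Optional[str]:
    algo = config["quantization"]["quant_algo"]
    if algo is None:
        # only the Eagle3 draft architecture may legitimately carry a null
        # quant_algo; anything else is treated as a malformed quant config
        if architecture != "LlamaForCausalLMEagle3":
            raise ValueError(f"Invalid quant_config for architecture {architecture}")
        return None  # caller then skips quantization entirely
    return algo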
......@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
def set_eagle3_layers_to_capture(self):
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
if not self.pp_group.is_last_rank:
return
if layer_ids is None:
self.capture_aux_hidden_states = True
num_layers = self.config.num_hidden_layers
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
else:
self.capture_aux_hidden_states = True
# we add 1 here because, in SGLang, the i-th layer takes the output
# of the (i-1)-th layer as its aux hidden state
self.model.layers_to_capture = [val + 1 for val in layer_ids]
class Phi3ForCausalLM(LlamaForCausalLM):
......
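The new layer_ids argument above changes which hidden states the target model records for the EAGLE3 draft. A small sketch of the selection logic, using an assumed 48-layer target and an assumed id list purely for the example numbers:

from typing import List, Optional


def select_capture_layers(num_layers: int, layer_ids: Optional[List[int]]) -> List[int]:
    if layer_ids is None:
        # default EAGLE3 heuristic: one early, one middle, one late layer
        return [2, num_layers // 2, num_layers - 3]
    # the draft config names layers whose output it wants; SGLang records the
    # aux hidden state at layer i as the output of layer i - 1, hence the +1
    return [i + 1 for i in layer_ids]


assert select_capture_layers(48, None) == [2, 24, 45]
assert select_capture_layers(48, [1, 23, 44]) == [2, 24, 45]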
......@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
class LlamaDecoderLayer(LlamaDecoderLayer):
......@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
prefix=add_prefix("qkv_proj", prefix),
)
if config.model_type == "llama4_text":
inter_size = config.intermediate_size_mlp
else:
inter_size = config.intermediate_size
self.mlp = LlamaMLP(
config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
)
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
......@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
)
self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
if hasattr(config, "target_hidden_size"):
self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
self.hidden_size_in = config.target_hidden_size
else:
self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
self.hidden_size_in = config.hidden_size
self.fc = torch.nn.Linear(
self.hidden_size_in * 3,
config.hidden_size,
bias=getattr(config, "bias", False),
)
self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
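The fc layer above projects the three auxiliary hidden states captured from the target model down to the draft model's own hidden size; target_hidden_size is used when the two models have different widths. A toy shape check, with the sizes chosen only for illustration:

import torch

target_hidden_size = 5120  # width of the target model's captured states (example)
draft_hidden_size = 4096   # width of the Eagle3 draft model (example)

fc = torch.nn.Linear(target_hidden_size * 3, draft_hidden_size, bias=False)

# three captured aux hidden states for a batch of 1 and 8 tokens
aux = [torch.randn(1, 8, target_hidden_size) for _ in range(3)]
out = fc(torch.cat(aux, dim=-1))
assert out.shape == (1, 8, draft_hidden_size)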
......@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = True
self.hot_token_id = None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
params_dict = dict(self.named_parameters())
# Define the parameter mapping for stacked parameters
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for name, loaded_weight in weights:
if "d2t" in name:
# d2t stores diffs between draft id and target id
self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
new_name = f"model.{name}"
super().load_weights([(new_name, loaded_weight)])
elif "lm_head" in name:
super().load_weights([(name, loaded_weight)])
continue
if "t2d" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param_name = f"model.{name}" if name not in params_dict else name
if param_name in params_dict:
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, shard_id)
break
else:
# Handle regular parameters
param_name = name if name in params_dict else f"model.{name}"
if param_name in params_dict:
param = params_dict[param_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
def get_hot_token_id(self):
return self.hot_token_id
......
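The rewritten load_weights above does two things: it fuses per-projection checkpoint weights into SGLang's packed qkv_proj / gate_up_proj parameters via stacked_params_mapping, and it turns the checkpoint's d2t tensor into absolute hot token ids. A tiny sketch of the d2t arithmetic, with made-up numbers:

import torch

# d2t[i] is the offset from draft-vocab index i to the matching target-vocab id
d2t = torch.tensor([0, 5, 17, 40])
hot_token_id = d2t + torch.arange(d2t.shape[0])
assert hot_token_id.tolist() == [0, 6, 19, 43]  # absolute target-vocab ids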
......@@ -223,5 +223,34 @@ class Llama4ForConditionalGeneration(nn.Module):
)
weight_loader(param, loaded_weight)
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
self.language_model.set_eagle3_layers_to_capture(layer_ids)
def get_embed_and_head(self):
# For EAGLE3, we delegate to the language model which should have this method
# If the language model doesn't have lm_head (like EAGLE3), we return None for head
embed = self.language_model.get_embed()
if hasattr(self.language_model, "get_embed_and_head"):
return self.language_model.get_embed_and_head()
elif hasattr(self.language_model, "lm_head"):
return embed, self.language_model.lm_head.weight
else:
# For EAGLE3, head might not be needed
return embed, None
def set_embed_and_head(self, embed, head):
if hasattr(self.language_model, "set_embed_and_head"):
return self.language_model.set_embed_and_head(embed, head)
else:
# For EAGLE3, only set embed
return self.language_model.set_embed(embed)
def get_embed(self):
return self.language_model.get_embed()
def set_embed(self, embed):
return self.language_model.set_embed(embed)
EntryClass = Llama4ForConditionalGeneration
......@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
self.draft_model_runner.model.set_embed(embed)
# grab hot token ids
self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
if self.draft_model_runner.model.hot_token_id is not None:
self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
embed.device
)
else:
if self.hot_token_id is not None:
head = head.clone()
......
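The worker-side change above makes the hot token ids optional: a draft checkpoint that ships no d2t tensor leaves model.hot_token_id as None, so the worker only moves the tensor to the embedding device when it actually exists instead of calling .to() on a missing value. A stub-level sketch of the guard; the stub class is an assumption for illustration:

import torch


class _DraftModelStub:
    # would be a LongTensor when the checkpoint provides d2t
    hot_token_id = None


draft_model = _DraftModelStub()
embed_device = torch.device("cpu")

hot_token_id = (
    draft_model.hot_token_id.to(embed_device)
    if draft_model.hot_token_id is not None
    else None
)
assert hot_token_id is None  # no hot-token remapping for this draft checkpoint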