"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "54bb49a50264270e97ade94082691859668f99ee"
Unverified commit 886d3449, authored by lukec, committed by GitHub

support llama4 eagle3 (#6985)


Co-authored-by: shuaills <shishuaiuoe@gmail.com>
Co-authored-by: Shenggui Li <somerlee.9@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 637bfee4
@@ -306,7 +306,26 @@ class ModelRunner:
         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
-            self.model.set_eagle3_layers_to_capture()
+            # load draft config
+            draft_model_config = ModelConfig.from_server_args(
+                server_args,
+                model_path=(server_args.speculative_draft_model_path),
+                is_draft_model=True,
+            )
+
+            try:
+                # get the aux layer ids from the draft model config
+                eagle_config = getattr(
+                    draft_model_config.hf_config, "eagle_config", None
+                )
+                eagle_aux_hidden_state_layer_ids = eagle_config[
+                    "eagle_aux_hidden_state_layer_ids"
+                ]
+            except:
+                # if there is no aux layer config, fall back to None
+                eagle_aux_hidden_state_layer_ids = None
+
+            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

     def model_specific_adjustment(self):
         server_args = self.server_args
...
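Note: the aux-layer lookup above tolerates draft checkpoints that carry no eagle_config at all. A minimal sketch of the intended fallback, using a made-up config object (FakeDraftHFConfig and the example layer ids are illustrative, not taken from the real checkpoint):

# Illustrative only: how the aux-layer ids are expected to resolve.
class FakeDraftHFConfig:
    # hypothetical shape of the eagle_config entry on a Llama4 EAGLE3 draft config
    eagle_config = {"eagle_aux_hidden_state_layer_ids": [1, 23, 44]}

def resolve_aux_layer_ids(hf_config):
    eagle_config = getattr(hf_config, "eagle_config", None)
    try:
        return eagle_config["eagle_aux_hidden_state_layer_ids"]
    except (TypeError, KeyError):
        # no aux-layer info in the draft config: return None so the target
        # model keeps its default capture layers
        return None

print(resolve_aux_layer_ids(FakeDraftHFConfig()))  # [1, 23, 44]
print(resolve_aux_layer_ids(object()))             # None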
@@ -124,6 +124,9 @@ def _get_quantization_config(
         quant_config = get_quant_config(
             model_config, load_config, packed_modules_mapping
         )
+        # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+        if quant_config is None:
+            return None
         major, minor = get_device_capability()
         if major is not None and minor is not None:
...
@@ -209,6 +209,17 @@ def get_quant_config(
             config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
+            # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+            if config["quantization"]["quant_algo"] is None:
+                if (
+                    model_config.hf_config.architectures[0]
+                    != "LlamaForCausalLMEagle3"
+                ):
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization}, "
+                        f"hf architectures: {model_config.hf_config.architectures[0]}."
+                    )
+                return None
             if "FP4" in config["quantization"]["quant_algo"]:
                 return ModelOptFp4Config.from_config(config)
             else:
...
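The two quantization changes above handle the same corner case: the nvidia/Llama-4-Maverick-17B-128E-Eagle3 draft checkpoint appears to ship a modelopt-produced quantization config whose quant_algo is null, so the loader treats the draft as unquantized instead of failing on the FP4 check. A rough sketch of that decision with a made-up config dict (the exact JSON layout is an assumption, and the backend strings are placeholders):

# Sketch: how a null quant_algo from a modelopt config is expected to be handled.
example_quant_config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": None},  # assumed layout for the Eagle3 draft
}

def pick_quant_backend(config, architecture):
    if config["producer"]["name"] == "modelopt":
        if config["quantization"]["quant_algo"] is None:
            if architecture != "LlamaForCausalLMEagle3":
                raise ValueError("null quant_algo is only tolerated for the Eagle3 draft")
            return None  # load the draft unquantized
        if "FP4" in config["quantization"]["quant_algo"]:
            return "modelopt_fp4"
        return "modelopt_fp8"
    return None

print(pick_quant_backend(example_quant_config, "LlamaForCausalLMEagle3"))  # None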
@@ -697,13 +697,19 @@ class LlamaForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

-    def set_eagle3_layers_to_capture(self):
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if not self.pp_group.is_last_rank:
             return

-        self.capture_aux_hidden_states = True
-        num_layers = self.config.num_hidden_layers
-        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # add 1 here because in sglang, the i-th layer takes the output of
+            # the (i-1)-th layer as its aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]


 class Phi3ForCausalLM(LlamaForCausalLM):
...
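The off-by-one in the capture list comes from where sglang hooks the aux hidden states: the i-th decoder layer records the output of layer i-1, so draft-config layer ids are shifted up by one. A small worked example (the layer ids and layer count are illustrative):

# Illustrative mapping from draft-config layer ids to sglang capture points.
layer_ids = [1, 23, 44]                      # hypothetical ids from eagle_config
layers_to_capture = [i + 1 for i in layer_ids]
print(layers_to_capture)                     # [2, 24, 45]

# Default when no ids are given (e.g. num_hidden_layers = 48):
num_layers = 48
print([2, num_layers // 2, num_layers - 3])  # [2, 24, 45]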
@@ -35,7 +35,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP


 class LlamaDecoderLayer(LlamaDecoderLayer):
@@ -59,6 +60,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             prefix=add_prefix("qkv_proj", prefix),
         )

+        if config.model_type == "llama4_text":
+            inter_size = config.intermediate_size_mlp
+        else:
+            inter_size = config.intermediate_size
+
+        self.mlp = LlamaMLP(
+            config.hidden_size, inter_size, config.hidden_act, quant_config, prefix
+        )
+
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
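The MLP width lookup differs between config families: a plain Llama config exposes intermediate_size, while a llama4_text config appears to name its dense-MLP width intermediate_size_mlp. A tiny sketch of the selection with made-up sizes:

# Sketch: pick the dense-MLP width depending on the config family.
from types import SimpleNamespace

llama_cfg = SimpleNamespace(model_type="llama", intermediate_size=14336)
llama4_cfg = SimpleNamespace(model_type="llama4_text", intermediate_size_mlp=16384)

def mlp_width(config):
    if config.model_type == "llama4_text":
        return config.intermediate_size_mlp
    return config.intermediate_size

print(mlp_width(llama_cfg), mlp_width(llama4_cfg))  # 14336 16384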
@@ -105,11 +115,19 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
         if hasattr(config, "target_hidden_size"):
-            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.target_hidden_size
         else:
-            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+            self.hidden_size_in = config.hidden_size
+
+        self.fc = torch.nn.Linear(
+            self.hidden_size_in * 3,
+            config.hidden_size,
+            bias=getattr(config, "bias", False),
+        )
+
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
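The fc projection exists because EAGLE3 feeds the draft layer three auxiliary hidden states captured from the target model, concatenated along the feature dimension; when the draft and target widths differ, target_hidden_size describes the incoming features. A shape-only sketch with placeholder sizes:

# Shape sketch: three captured target hidden states -> one draft hidden state.
import torch

target_hidden_size = 5120   # placeholder width of the target model
draft_hidden_size = 4096    # placeholder width of the draft model
fc = torch.nn.Linear(target_hidden_size * 3, draft_hidden_size, bias=False)

aux_hidden_states = torch.randn(2, 7, target_hidden_size * 3)  # [batch, seq, 3 * target]
print(fc(aux_hidden_states).shape)                             # torch.Size([2, 7, 4096])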
@@ -179,18 +197,50 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = True
+        self.hot_token_id = None

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        # Define the parameter mapping for stacked parameters
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
         for name, loaded_weight in weights:
             if "d2t" in name:
                 # d2t stores diffs between draft id and target id
                 self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
+                continue

-            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
-                new_name = f"model.{name}"
-                super().load_weights([(new_name, loaded_weight)])
-            elif "lm_head" in name:
-                super().load_weights([(name, loaded_weight)])
+            if "t2d" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param_name = f"model.{name}" if name not in params_dict else name
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Handle regular parameters
+                param_name = name if name in params_dict else f"model.{name}"
+                if param_name in params_dict:
+                    param = params_dict[param_name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

     def get_hot_token_id(self):
         return self.hot_token_id
...
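The new load_weights replaces the old "prepend model. and defer to the parent" path with the usual sglang stacked-parameter remapping, so split q/k/v and gate/up checkpoint tensors load directly into the fused modules. A name-only sketch of the remapping (the checkpoint keys are illustrative):

# Illustrative: how checkpoint names map onto fused parameters.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

def remap(name):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(remap("midlayer.self_attn.k_proj.weight"))  # ('midlayer.self_attn.qkv_proj.weight', 'k')
print(remap("midlayer.mlp.up_proj.weight"))       # ('midlayer.mlp.gate_up_proj.weight', 1)
print(remap("norm.weight"))                       # ('norm.weight', None)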
@@ -223,5 +223,34 @@ class Llama4ForConditionalGeneration(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
+            self.language_model.set_eagle3_layers_to_capture(layer_ids)
+
+    def get_embed_and_head(self):
+        # For EAGLE3, we delegate to the language model, which should have this method.
+        # If the language model doesn't have lm_head (like EAGLE3), we return None for the head.
+        embed = self.language_model.get_embed()
+        if hasattr(self.language_model, "get_embed_and_head"):
+            return self.language_model.get_embed_and_head()
+        elif hasattr(self.language_model, "lm_head"):
+            return embed, self.language_model.lm_head.weight
+        else:
+            # For EAGLE3, the head might not be needed
+            return embed, None
+
+    def set_embed_and_head(self, embed, head):
+        if hasattr(self.language_model, "set_embed_and_head"):
+            return self.language_model.set_embed_and_head(embed, head)
+        else:
+            # For EAGLE3, only set the embedding
+            return self.language_model.set_embed(embed)
+
+    def get_embed(self):
+        return self.language_model.get_embed()
+
+    def set_embed(self, embed):
+        return self.language_model.set_embed(embed)
+

 EntryClass = Llama4ForConditionalGeneration
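These wrappers let the speculative worker treat the multimodal Llama4ForConditionalGeneration like a plain language model: the EAGLE path only needs the embedding, the lm_head weight, and the capture hook, and everything is forwarded to language_model behind hasattr guards. A rough sketch of the expected call pattern from the worker side (function and variable names are placeholders, not the actual worker API):

# Sketch: how a worker might share weights between target and draft models.
def share_embed_and_head(target_model, draft_model):
    # target_model could be Llama4ForConditionalGeneration; it delegates internally.
    embed, head = target_model.get_embed_and_head()
    if head is not None:
        draft_model.set_embed_and_head(embed, head)
    else:
        # EAGLE3 drafts may reuse only the embedding
        draft_model.set_embed(embed)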
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)

             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
...
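The guard on hot_token_id matters because not every EAGLE3 draft ships a d2t table (hence the new hot_token_id = None default above). When it does, d2t stores per-row offsets from draft vocab ids to target vocab ids, so adding arange recovers the absolute target ids. A small worked example with made-up offsets:

# Illustrative: recovering target-vocab "hot token" ids from a d2t offset table.
import torch

d2t = torch.tensor([0, 4, 9, 120])               # hypothetical offsets
hot_token_id = d2t + torch.arange(d2t.shape[0])  # draft id i -> target id i + d2t[i]
print(hot_token_id)                              # tensor([  0,   5,  11, 123])

# Drafts without a d2t table leave hot_token_id as None, which is why the
# worker now checks before calling .to(embed.device).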