Unverified commit 5b8c30d6, authored by Shubhra Pandit and committed by GitHub
Browse files

[Spec Decode, BugFix] Propagate norm_before_fc from Eagle3 speculator (#38111)


Signed-off-by: Shubhra Pandit <shubhra.pandit@gmail.com>
parent d39b8daf
...@@ -22,6 +22,7 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None: ...@@ -22,6 +22,7 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
- draft_vocab_size: Size of the draft model's vocabulary - draft_vocab_size: Size of the draft model's vocabulary
- target_hidden_size: Hidden size of the target model - target_hidden_size: Hidden size of the target model
- norm_before_residual: Whether to apply norm before residual connection - norm_before_residual: Whether to apply norm before residual connection
- norm_before_fc: Whether to apply RMSNorm before the fc projection
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
model to use as auxiliary inputs for the Eagle3 drafter. These layers model to use as auxiliary inputs for the Eagle3 drafter. These layers
provide intermediate hidden states that help the drafter make better provide intermediate hidden states that help the drafter make better
...@@ -34,6 +35,7 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None: ...@@ -34,6 +35,7 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
pre_trained_config["norm_before_residual"] = config_dict.get( pre_trained_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True "norm_before_residual", True
) )
pre_trained_config["norm_before_fc"] = config_dict.get("norm_before_fc", False)
pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"] pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"): if config_dict.get("eagle_aux_hidden_state_layer_ids"):
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment