Unverified Commit 87985077 authored by Shubhra Pandit's avatar Shubhra Pandit Committed by GitHub
Browse files

[Speculative Decoding] Add `norm_before_fc` for gpt-oss draft models (#36545)


Signed-off-by: default avatarShubhra Pandit <shubhra.pandit@gmail.com>
Co-authored-by: default avatarBenjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent a79c1c2c
...@@ -150,6 +150,7 @@ class LlamaModel(nn.Module): ...@@ -150,6 +150,7 @@ class LlamaModel(nn.Module):
self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"] self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
else: else:
self.use_aux_hidden_state = True self.use_aux_hidden_state = True
self.norm_before_fc = getattr(self.config, "norm_before_fc", False)
current_vllm_config = get_current_vllm_config() current_vllm_config = get_current_vllm_config()
...@@ -175,6 +176,13 @@ class LlamaModel(nn.Module): ...@@ -175,6 +176,13 @@ class LlamaModel(nn.Module):
fc_input_size = self.config.target_hidden_size * 3 fc_input_size = self.config.target_hidden_size * 3
else: else:
fc_input_size = self.config.hidden_size * 3 fc_input_size = self.config.hidden_size * 3
if self.norm_before_fc:
self.input_norm = RMSNorm(
fc_input_size,
eps=self.config.rms_norm_eps,
)
else:
self.input_norm = None
self.fc = ReplicatedLinear( self.fc = ReplicatedLinear(
input_size=fc_input_size, input_size=fc_input_size,
output_size=self.config.hidden_size, output_size=self.config.hidden_size,
...@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -357,6 +365,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
if not self.model.use_aux_hidden_state: if not self.model.use_aux_hidden_state:
return hidden_states return hidden_states
# combine multiple auxiliary hidden states returned by eagle3 # combine multiple auxiliary hidden states returned by eagle3
if self.model.norm_before_fc:
hidden_states = self.model.input_norm(hidden_states)
return self.model.fc(hidden_states) return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
...@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -403,6 +414,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
skip_substrs.append("embed_tokens") skip_substrs.append("embed_tokens")
if not self.model.use_aux_hidden_state: if not self.model.use_aux_hidden_state:
skip_substrs.append("fc.") skip_substrs.append("fc.")
if not self.model.norm_before_fc:
skip_substrs.append("input_norm.")
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=None, skip_prefixes=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment