Unverified commit 1bc86a3d, authored by Benjamin Chislett, committed by GitHub
Browse files

[Bugfix] Fix EAGLE3 broken logits (#18909)


Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
parent bbfa0c61
@@ -215,6 +215,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.draft_id_to_target_id is None:
+            assert logits.shape[1] == self.config.vocab_size, \
+                "Expected logits to have shape " \
+                f"(*, {self.config.vocab_size}), but got {logits.shape}"
             return logits
         base = torch.arange(self.config.draft_vocab_size, device=logits.device)
@@ -234,24 +237,22 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         return self.model.fc(hidden_states)
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=None,
-        )
         model_weights = {}
+        includes_draft_id_mapping = False
         for name, loaded_weight in weights:
             if "t2d" in name:
                 continue
             if "d2t" in name:
                 name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
             elif "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
-        loaded_weights = loader.load_weights(model_weights.items())
-
-        if 'd2t' not in loaded_weights:
-            self.draft_id_to_target_id = None
-
-        return loaded_weights
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=["draft_id_to_target_id"] \
+                if not includes_draft_id_mapping else None,
+        )
+        loader.load_weights(model_weights.items())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment