"vscode:/vscode.git/clone" did not exist on "3c4cebf751a6d2ff9ada2f8234bab17ba7283e09"
Unverified commit 583507d1, authored by Benjamin Chislett, committed by GitHub
Browse files

[Spec Decode] Make EAGLE3 draft token ID mapping optional (#18488)


Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent e44d8ce8
......@@ -214,6 +214,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.draft_id_to_target_id is None:
return logits
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
targets = base + self.draft_id_to_target_id
logits_new = logits.new_full((
......@@ -246,4 +249,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
name = "model." + name
model_weights[name] = loaded_weight
return loader.load_weights(model_weights.items())
loaded_weights = loader.load_weights(model_weights.items())
if 'd2t' not in loaded_weights:
self.draft_id_to_target_id = None
return loaded_weights
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment