Unverified Commit 27ff1871 authored by OlivierDehaene's avatar OlivierDehaene
Browse files

hotfix: fix flashllama

parent 03c9388b
......@@ -692,7 +692,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if not self.logits_scaled:
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment