Unverified Commit 7cd2ebb0 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Bugfix] Fix `compute_logits` in Jamba (#6093)

parent f1c78138
......@@ -876,7 +876,7 @@ class JambaForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment