Unverified Commit c0795de2 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): do not warp prefill logits (#116)

parent 1a2d6825
......@@ -75,6 +75,10 @@ class NextTokenChooser:
def __call__(self, input_ids, scores):
# Warp logits
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores)
# Compute logprobs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment