Commit 0baa9246, authored by Nick Hill, committed by GitHub

Fix TypicalLogitsWarper tensor OOB indexing edge case (#26579)

* Fix TypicalLogitsWarper tensor OOB indexing edge case

This can be triggered fairly quickly with low precision, e.g. bfloat16 and typical_p = 0.99.

* Shift threshold index by one

* Use explicit named arg for clamp min
parent 06e782da
@@ -492,8 +492,8 @@ class TypicalLogitsWarper(LogitsWarper):
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
         # Remove tokens with cumulative mass above the threshold
-        last_ind = (cumulative_probs < self.mass).sum(dim=1)
-        last_ind[last_ind < 0] = 0
+        last_ind = (cumulative_probs < self.mass).sum(dim=1) - 1
+        last_ind.clamp_(min=0)
         sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
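
The edge case can be illustrated with a minimal plain-Python sketch of the index arithmetic. This is not the library code: the helper names `last_index_old` and `last_index_new` and the sample values are hypothetical, and lists stand in for the tensors. It assumes a row whose rounded cumulative sum never reaches the mass, which is the situation the commit message attributes to low-precision dtypes.

```python
def last_index_old(cumulative_probs, mass):
    """Pre-fix arithmetic: number of entries strictly below the mass."""
    return sum(p < mass for p in cumulative_probs)


def last_index_new(cumulative_probs, mass):
    """Post-fix arithmetic: shift the count by one, then clamp at zero."""
    return max(sum(p < mass for p in cumulative_probs) - 1, 0)


# Hypothetical row: rounding keeps the cumulative sum below the mass,
# so every entry satisfies `p < mass`.
cumulative_probs = [0.5, 0.75, 0.875, 0.98]
mass = 0.99

old = last_index_old(cumulative_probs, mass)  # equals len(row): past the end
new = last_index_new(cumulative_probs, mass)  # last valid position

try:
    cumulative_probs[old]
    old_is_in_bounds = True
except IndexError:
    # Plain-Python analogue of the out-of-bounds gather on the tensor.
    old_is_in_bounds = False

# Opposite extreme: the first entry already exceeds the mass, so the
# shifted count would be -1 and the clamp brings it back to 0.
clamped = last_index_new([0.995, 1.0], mass)
```

In this sketch the old count equals the row length whenever every cumulative probability stays below the mass, so using it as a position reads past the end. The shifted and clamped value always stays within `0 .. len(row) - 1`.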