Unverified Commit e4b26aa1 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

fix(server): avoid errors for very small top_p values (#544)

See https://github.com/huggingface/transformers/pull/24111

I didn't add validation to the `__init__` method since it's not done for
other values/warpers.
parent 2a101207
...@@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper): ...@@ -189,9 +189,8 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = probs <= self.top_p_opposite sorted_indices_to_remove = probs <= self.top_p_opposite
if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep
# Keep at least min_tokens_to_keep sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment