Unverified Commit be10092e authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: PT's `top_p` enforces `min_tokens_to_keep` when it is `1` (#24111)

parent 03585f37
......@@ -129,6 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value
......
......@@ -255,6 +255,8 @@ class TopPLogitsWarper(LogitsWarper):
top_p = float(top_p)
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value
......@@ -266,7 +268,6 @@ class TopPLogitsWarper(LogitsWarper):
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
......
......@@ -160,6 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}")
self.top_p = top_p
self.filter_value = filter_value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment