"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4a210c9fc67661e48a0146a6833381bfd0a4ea07"
Unverified Commit be10092e authored by Joao Gante, committed by GitHub

Generate: PT's `top_p` enforces `min_tokens_to_keep` when it is `1` (#24111)

parent 03585f37
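
Previously, the PyTorch implementation only applied the keep-at-least mask behind an `if self.min_tokens_to_keep > 1:` guard, so the default `min_tokens_to_keep=1` was never actually enforced and an extreme setting such as `top_p=0.0` could mask out every candidate token. This commit removes that guard and adds input validation for `min_tokens_to_keep` to the PyTorch, TensorFlow, and Flax warpers.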
@@ -129,6 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
     def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
             raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
         self.top_p = top_p
         self.filter_value = filter_value
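
The new check is identical in all three frameworks and simply rejects non-positive values at construction time. A minimal sketch of triggering it (assuming a transformers build that includes this commit and a working Flax install):

    from transformers import FlaxTopPLogitsWarper

    try:
        # min_tokens_to_keep=0 is rejected by the validation added above
        FlaxTopPLogitsWarper(top_p=0.9, min_tokens_to_keep=0)
    except ValueError as err:
        print(err)  # `min_tokens_to_keep` has to be a positive integer, but is 0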
@@ -255,6 +255,8 @@ class TopPLogitsWarper(LogitsWarper):
         top_p = float(top_p)
         if top_p < 0 or top_p > 1.0:
             raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
         self.top_p = top_p
         self.filter_value = filter_value
@@ -266,9 +268,8 @@ class TopPLogitsWarper(LogitsWarper):
         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
         sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-        if self.min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
         # scatter sorted tensors to original indexing
         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
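
The behavior change is easiest to see at the extreme `top_p=0.0`: before this commit, every token was masked because the `min_tokens_to_keep > 1` guard skipped the rescue for the default value of `1`. A minimal sketch, assuming a transformers version that includes this patch:

    import torch
    from transformers import TopPLogitsWarper

    warper = TopPLogitsWarper(top_p=0.0)  # min_tokens_to_keep defaults to 1
    scores = torch.log(torch.tensor([[0.1, 0.2, 0.7]]))
    warped = warper(torch.tensor([[0]]), scores)
    # Only the highest-probability token (index 2) keeps a finite score;
    # before this commit every entry came back as -inf.
    print(warped)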
@@ -160,6 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
     def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
         if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
             raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
         self.top_p = top_p
         self.filter_value = filter_value