Unverified commit ac19871c, authored by Joao Gante, committed by GitHub
Browse files

Generate: `min_tokens_to_keep` has to be `>= 1` (#24453)

parent 5f3efdf7
...@@ -129,8 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): ...@@ -129,8 +129,8 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0): if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}") raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.top_p = top_p self.top_p = top_p
self.filter_value = filter_value self.filter_value = filter_value
......
...@@ -255,8 +255,8 @@ class TopPLogitsWarper(LogitsWarper): ...@@ -255,8 +255,8 @@ class TopPLogitsWarper(LogitsWarper):
top_p = float(top_p) top_p = float(top_p)
if top_p < 0 or top_p > 1.0: if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0): if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}") raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.top_p = top_p self.top_p = top_p
self.filter_value = filter_value self.filter_value = filter_value
...@@ -323,6 +323,8 @@ class TypicalLogitsWarper(LogitsWarper): ...@@ -323,6 +323,8 @@ class TypicalLogitsWarper(LogitsWarper):
mass = float(mass) mass = float(mass)
if not (mass > 0 and mass < 1): if not (mass > 0 and mass < 1):
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}") raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.filter_value = filter_value self.filter_value = filter_value
self.mass = mass self.mass = mass
...@@ -344,8 +346,6 @@ class TypicalLogitsWarper(LogitsWarper): ...@@ -344,8 +346,6 @@ class TypicalLogitsWarper(LogitsWarper):
last_ind = (cumulative_probs < self.mass).sum(dim=1) last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0 last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
......
...@@ -160,8 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper): ...@@ -160,8 +160,8 @@ class TFTopPLogitsWarper(TFLogitsWarper):
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 0): if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(f"`min_tokens_to_keep` has to be a non-negative integer, but is {min_tokens_to_keep}") raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
self.top_p = top_p self.top_p = top_p
self.filter_value = filter_value self.filter_value = filter_value
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment