Unverified commit 9c6aeba3, authored by Nick Doiron, committed by GitHub
Browse files

Document and validate typical_p in generation (#19128)

* Document and validate typical_p in generation
parent de359c45
......@@ -236,6 +236,19 @@ class TopKLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
Args:
mass (`float`, *optional*, defaults to 0.9):
Value of typical_p, strictly between 0 and 1 (both bounds exclusive).
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
mass = float(mass)
if not (mass > 0 and mass < 1):
......
......@@ -1486,6 +1486,9 @@ class GenerationMixin:
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if typical_p is not None:
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment