Unverified Commit c03b6e42 authored by Clara Meister's avatar Clara Meister Committed by GitHub
Browse files

value check for typical sampling (#16165)



* value check for typical sampling

* value check for typical sampling

* change from float to int comparison
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fdc2e643
...@@ -240,6 +240,9 @@ class TopKLogitsWarper(LogitsWarper): ...@@ -240,6 +240,9 @@ class TopKLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(LogitsWarper): class TypicalLogitsWarper(LogitsWarper):
def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
mass = float(mass)
if not (mass > 0 and mass < 1):
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
self.filter_value = filter_value self.filter_value = filter_value
self.mass = mass self.mass = mass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment