Unverified Commit 2892b9bb authored by Wang Ran (汪然)'s avatar Wang Ran (汪然) Committed by GitHub
Browse files

bugfix: Update sampling_params.py (#4413)

parent 470b4740
...@@ -77,7 +77,7 @@ class SamplingParams: ...@@ -77,7 +77,7 @@ class SamplingParams:
self.custom_params = custom_params self.custom_params = custom_params
# Process some special cases # Process some special cases
if self.temperature < _SAMPLING_EPS: if 0 <= self.temperature < _SAMPLING_EPS:
# top_k = 1 means greedy sampling # top_k = 1 means greedy sampling
self.temperature = 1.0 self.temperature = 1.0
self.top_k = 1 self.top_k = 1
...@@ -93,9 +93,9 @@ class SamplingParams: ...@@ -93,9 +93,9 @@ class SamplingParams:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
if not 0.0 <= self.min_p <= 1.0: if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.top_k < -1 or self.top_k == 0: if self.top_k < 1 or self.top_k == -1:
raise ValueError( raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." f"top_k must be -1 (disable) or at least 1, got {self.top_k}."
) )
if not -2.0 <= self.frequency_penalty <= 2.0: if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError( raise ValueError(
...@@ -108,12 +108,12 @@ class SamplingParams: ...@@ -108,12 +108,12 @@ class SamplingParams:
) )
if not 0.0 <= self.repetition_penalty <= 2.0: if not 0.0 <= self.repetition_penalty <= 2.0:
raise ValueError( raise ValueError(
"repetition_penalty must be in (0, 2], got " "repetition_penalty must be in [0, 2], got "
f"{self.repetition_penalty}." f"{self.repetition_penalty}."
) )
if not 0 <= self.min_new_tokens: if not 0 <= self.min_new_tokens:
raise ValueError( raise ValueError(
f"min_new_tokens must be in (0, max_new_tokens], got " f"min_new_tokens must be in [0, max_new_tokens], got "
f"{self.min_new_tokens}." f"{self.min_new_tokens}."
) )
if self.max_new_tokens is not None: if self.max_new_tokens is not None:
...@@ -123,7 +123,7 @@ class SamplingParams: ...@@ -123,7 +123,7 @@ class SamplingParams:
) )
if not self.min_new_tokens <= self.max_new_tokens: if not self.min_new_tokens <= self.max_new_tokens:
raise ValueError( raise ValueError(
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}." f"{self.min_new_tokens}."
) )
grammars = [ grammars = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment