Unverified Commit b28298f2 authored by saumya-saran's avatar saumya-saran Committed by GitHub
Browse files

[Bugfix] Validate SamplingParam n is an int (#8548)

parent 2940afa0
...@@ -273,9 +273,14 @@ class SamplingParams( ...@@ -273,9 +273,14 @@ class SamplingParams(
self._all_stop_token_ids = set(self.stop_token_ids) self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of "
f"type {type(self.n)}")
if self.n < 1: if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.") raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int) if not isinstance(self.best_of, int):
raise ValueError(f'best_of must be an int, but is of '
f'type {type(self.best_of)}')
if self.best_of < self.n: if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, " raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.") f"got n={self.n} and best_of={self.best_of}.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment