Unverified Commit 0c55cbcf authored by ehuaa's avatar ehuaa Committed by GitHub
Browse files

[BugFix] add verify logit_bias to avoid crash because of IndexError (#7749)

parent c46e069d
......@@ -604,7 +604,7 @@ class TokenizerManager:
sampling_kwargs = obj.sampling_params
sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
sampling_params.verify(self.model_config.vocab_size)
# Build return object
if isinstance(obj, GenerateReqInput):
......
......@@ -89,7 +89,7 @@ class SamplingParams:
if self.top_k == -1:
self.top_k = TOP_K_ALL # whole vocabulary
def verify(self):
def verify(self, vocab_size):
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}."
......@@ -131,6 +131,13 @@ class SamplingParams:
f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}."
)
if self.logit_bias is not None:
for token_id in self.logit_bias:
if not 0 <= int(token_id) < vocab_size:
raise ValueError(
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
f"{token_id}."
)
grammars = [
self.json_schema,
self.regex,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment