Unverified commit 0c55cbcf, authored by ehuaa and committed by GitHub
Browse files

[BugFix] add verify logit_bias to avoid crash because of IndexError (#7749)

parent c46e069d
...@@ -604,7 +604,7 @@ class TokenizerManager: ...@@ -604,7 +604,7 @@ class TokenizerManager:
sampling_kwargs = obj.sampling_params sampling_kwargs = obj.sampling_params
sampling_params = SamplingParams(**sampling_kwargs) sampling_params = SamplingParams(**sampling_kwargs)
sampling_params.normalize(self.tokenizer) sampling_params.normalize(self.tokenizer)
sampling_params.verify() sampling_params.verify(self.model_config.vocab_size)
# Build return object # Build return object
if isinstance(obj, GenerateReqInput): if isinstance(obj, GenerateReqInput):
......
...@@ -89,7 +89,7 @@ class SamplingParams: ...@@ -89,7 +89,7 @@ class SamplingParams:
if self.top_k == -1: if self.top_k == -1:
self.top_k = TOP_K_ALL # whole vocabulary self.top_k = TOP_K_ALL # whole vocabulary
def verify(self): def verify(self, vocab_size):
if self.temperature < 0.0: if self.temperature < 0.0:
raise ValueError( raise ValueError(
f"temperature must be non-negative, got {self.temperature}." f"temperature must be non-negative, got {self.temperature}."
...@@ -131,6 +131,13 @@ class SamplingParams: ...@@ -131,6 +131,13 @@ class SamplingParams:
f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got " f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}." f"{self.min_new_tokens}."
) )
if self.logit_bias is not None:
for token_id in self.logit_bias:
if not 0 <= int(token_id) < vocab_size:
raise ValueError(
f"logit_bias must has keys in [0, {vocab_size - 1}], got "
f"{token_id}."
)
grammars = [ grammars = [
self.json_schema, self.json_schema,
self.regex, self.regex,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment