Unverified commit f95e6617 authored by Lianmin Zheng, committed by GitHub

Fix max_tokens for OpenAI chat completion API (#766)

parent de854fb5
@@ -98,17 +98,21 @@ class ModelTpServer:
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)

         # Print info
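For illustration only, a minimal standalone sketch of how the two limits set up above interact. The concrete numbers (KV cache capacity, request-pool size, context length) are invented and stand in for self.max_total_num_tokens, self.model_runner.req_to_token_pool.size, and self.model_config.context_len:

# Hypothetical values standing in for the server's real attributes.
max_total_num_tokens = 200_000      # assumed KV cache pool capacity (tokens)
req_to_token_pool_size = 2048       # assumed number of request slots
context_len = 8192                  # assumed model context window
max_running_requests_arg = None     # i.e. server_args.max_running_requests unset

# Concurrency cap: half the token pool by default, never more than the
# available request slots minus one.
max_running_requests = min(
    (
        max_total_num_tokens // 2
        if max_running_requests_arg is None
        else max_running_requests_arg
    ),
    req_to_token_pool_size - 1,
)

# Per-request input cap: bounded by both the context window and the pool.
max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)

print(max_running_requests)  # 2047
print(max_req_input_len)     # 8191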
@@ -295,18 +299,16 @@ class ModelTpServer:
                 )

         # Truncate prompts that are too long
-        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.origin_input_ids),
-            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+            req.sampling_params.max_new_tokens or 1 << 30,
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        if req.sampling_params.max_new_tokens < 0:
-            req.origin_input_ids = req.origin_input_ids[
-                : self.max_total_num_tokens - 128
-            ]
-            logger.error("Request longer than memory pool size, truncated!!!")
         self.forward_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
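As a rough usage sketch (not part of the commit), this is how the clamp above behaves for a request that omits max_new_tokens; the prompt length and the input cap are made-up numbers:

max_req_input_len = 8191                 # assumed, matching the sketch above
origin_input_ids = list(range(1000))     # assumed 1000-token prompt
requested_max_new_tokens = None          # client sent no max_tokens

# None falls through to 1 << 30, so the effective limit becomes the
# remaining token budget rather than a hard-coded small default.
max_new_tokens = min(
    requested_max_new_tokens or 1 << 30,
    max_req_input_len - 1 - len(origin_input_ids),
)
print(max_new_tokens)  # 7190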
@@ -152,7 +152,7 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
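With the default changed from 16 to None, a client that omits max_tokens no longer gets responses silently capped at 16 tokens; the server-side clamp in the ModelTpServer hunk above decides the budget instead. A hedged sketch of such a request against a locally running OpenAI-compatible endpoint (the host, port, and model name are assumptions):

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",   # assumed local endpoint
    json={
        "model": "default",                          # assumed model name
        "messages": [
            {"role": "user", "content": "Explain KV caching in one paragraph."}
        ],
        # No "max_tokens" field: the server derives the generation limit.
    },
)
print(resp.json()["choices"][0]["message"]["content"])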
@@ -65,10 +65,11 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
-        if self.max_new_tokens < 0:
-            raise ValueError(
-                f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
-            )
+        if self.max_new_tokens is not None:
+            if self.max_new_tokens < 0:
+                raise ValueError(
+                    f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
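For reference, a standalone sketch that mirrors the relaxed validation above: None is now accepted (the server resolves the actual limit later), while explicit negative values are still rejected. The helper name is hypothetical:

from typing import Optional

def verify_max_new_tokens(max_new_tokens: Optional[int]) -> None:
    # Mirrors the check in the hunk above: skip validation when unset.
    if max_new_tokens is not None:
        if max_new_tokens < 0:
            raise ValueError(
                f"max_new_tokens must be at least 0, got {max_new_tokens}."
            )

verify_max_new_tokens(None)   # ok: limit is derived by the server later
verify_max_new_tokens(256)    # ok
# verify_max_new_tokens(-1)   # would raise ValueError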