Unverified Commit e06f504a authored by WanMok, committed by GitHub

Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715)

parent 462ae522
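In practical terms, this change lets a client send a pre-tokenized prompt (a flat list of token IDs) to the completion endpoint instead of a string. The snippet below is an illustrative client call, not part of the commit: the payload fields follow the CompletionRequest schema edited in this diff, while the server address, model name, and token IDs are placeholder assumptions.

import requests

# Hypothetical token IDs produced by the model's own tokenizer.
token_ids = [1, 3087, 362, 29901]

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed default server address
    json={
        "model": "facebook/opt-125m",  # placeholder model name
        "prompt": token_ids,           # array of tokens, accepted after this change
        "max_tokens": 16,
    },
)
print(resp.json())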
@@ -3,18 +3,18 @@
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
+from http import HTTPStatus
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt):
-    input_ids = tokenizer(prompt).input_ids
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
+    if prompt_ids is not None:
+        input_ids = prompt_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
     if token_num + request.max_tokens > max_model_len:
@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    token_ids, error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
 
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, (str, list)):
+            # TODO: handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
-    token_ids, error_check_ret = await check_length(request, prompt)
+
+    if use_token_ids:
+        _, error_check_ret = await check_length(request, prompt_ids=prompt)
+    else:
+        token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id,
-                                       token_ids)
+    if use_token_ids:
+        result_generator = engine.generate(None,
+                                           sampling_params,
+                                           request_id,
+                                           prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id,
+                                           token_ids)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
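For reference, the branching that create_completion now performs over the four accepted prompt shapes can be summarized as a small standalone function. This is a simplified sketch written for this page, with plain exceptions standing in for create_error_response; classify_prompt is a hypothetical helper and does not exist in vLLM.

from typing import List, Tuple, Union

Prompt = Union[List[int], List[List[int]], str, List[str]]


def classify_prompt(prompt: Prompt) -> Tuple[bool, Union[str, List[int]]]:
    """Mirror the (use_token_ids, prompt) decision made in create_completion."""
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        first_element = prompt[0]
        if isinstance(first_element, int):
            # A flat list of ints is a single pre-tokenized prompt.
            return True, prompt
        # A list of strings or a list of token lists: only one element
        # is supported by this change.
        if len(prompt) > 1:
            raise ValueError(
                "multiple prompts in a batch is not currently supported")
        return not isinstance(first_element, str), prompt[0]
    # A plain string keeps the old behaviour and is tokenized server-side.
    return False, prompt


print(classify_prompt("Hello"))      # (False, 'Hello')
print(classify_prompt([1, 2, 3]))    # (True, [1, 2, 3])
print(classify_prompt([[1, 2, 3]]))  # (True, [1, 2, 3])
print(classify_prompt(["Hello"]))    # (False, 'Hello')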