"vscode:/vscode.git/clone" did not exist on "65788e46edfb60a31782a2bda0ba01f594359785"
Unverified Commit e06f504a authored by WanMok's avatar WanMok Committed by GitHub
Browse files

Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715)

parent 462ae522
......@@ -3,18 +3,18 @@
import argparse
import asyncio
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional
from packaging import version
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
......@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
return prompt
async def check_length(request, prompt):
input_ids = tokenizer(prompt).input_ids
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
if prompt_ids is not None:
input_ids = prompt_ids
else:
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > max_model_len:
......@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
......@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
prompt = request.prompt[0]
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
token_ids, error_check_ret = await check_length(request, prompt)
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
......@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
if use_token_ids:
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
......
......@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
class CompletionRequest(BaseModel):
model: str
prompt: Union[str, List[str]]
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment