norm / vllm · Commits · e06f504a

Unverified commit e06f504a, authored Aug 11, 2023 by WanMok, committed by GitHub on Aug 11, 2023
Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715)
parent 462ae522

Showing 2 changed files with 45 additions and 16 deletions (+45, -16):

    vllm/entrypoints/openai/api_server.py    +43  -15
    vllm/entrypoints/openai/protocol.py       +2   -1
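For illustration, here is a minimal client-side sketch of what this commit enables: the `prompt` field of a completion request may now carry pre-tokenized input instead of text. The server URL, model name, and token IDs below are placeholders for illustration, not values taken from this commit.

```python
import requests

API_URL = "http://localhost:8000/v1/completions"  # assumed local vLLM OpenAI-compatible server
MODEL = "facebook/opt-125m"                        # placeholder model name

# Previously, `prompt` had to be a string (or a list of strings).
text_payload = {"model": MODEL, "prompt": "Hello, my name is", "max_tokens": 16}

# With this change, `prompt` may also be a list of token IDs (or a list of
# token-ID lists). The IDs here are arbitrary examples, not real vocabulary IDs.
token_payload = {"model": MODEL, "prompt": [2, 31414, 6, 127], "max_tokens": 16}

for payload in (text_payload, token_payload):
    resp = requests.post(API_URL, json=payload, timeout=30)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["text"])
```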
vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from http import HTTPStatus
+from packaging import version
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union

 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
-from packaging import version

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt):
-    input_ids = tokenizer(prompt).input_ids
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
+    if prompt_ids is not None:
+        input_ids = prompt_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)

     if token_num + request.max_tokens > max_model_len:
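The reworked check_length above accepts either a raw prompt string or pre-computed token IDs, but never both, and only tokenizes when no IDs were supplied. Below is a self-contained sketch of that either/or pattern; the stub tokenizer, the default max_model_len, and the string error value are stand-ins for the module's globals and JSONResponse, not the actual server objects. The single XOR-style assert is equivalent to the double-negation assert in the patch.

```python
from typing import List, Optional, Tuple


def fake_tokenize(text: str) -> List[int]:
    # Stand-in for the server's Hugging Face tokenizer: one "token" per word.
    return list(range(len(text.split())))


def check_length_sketch(
    max_tokens: int,
    prompt: Optional[str] = None,
    prompt_ids: Optional[List[int]] = None,
    max_model_len: int = 2048,
) -> Tuple[List[int], Optional[str]]:
    # Exactly one of `prompt` / `prompt_ids` must be supplied.
    assert (prompt is None) != (prompt_ids is None), \
        "Either prompt or prompt_ids should be provided."
    input_ids = prompt_ids if prompt_ids is not None else fake_tokenize(prompt)
    if len(input_ids) + max_tokens > max_model_len:
        return input_ids, "maximum context length exceeded"
    return input_ids, None


# With token IDs supplied, no tokenization happens at all.
ids, err = check_length_sketch(max_tokens=16, prompt_ids=[1, 2, 3])
print(ids, err)  # [1, 2, 3] None
```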
@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    token_ids, error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, (str, list)):
+            # TODO: handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
-    token_ids, error_check_ret = await check_length(request, prompt)
+
+    if use_token_ids:
+        _, error_check_ret = await check_length(request, prompt_ids=prompt)
+    else:
+        token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
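The new branching above decides how to interpret `request.prompt` by inspecting its first element. The helper below is a hypothetical, standalone restatement of that decision table (it is not a function in this commit); it shows which payload shapes are treated as token IDs and which are still handled as text.

```python
from typing import List, Tuple, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]


def classify_prompt(prompt: Prompt) -> Tuple[bool, Union[str, List[int]]]:
    """Return (use_token_ids, normalized_prompt), mirroring create_completion."""
    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")
        first = prompt[0]
        if isinstance(first, int):
            # A flat list of ints is a single pre-tokenized prompt.
            return True, prompt
        if len(prompt) > 1:
            # Batched strings / token arrays are still rejected at this point.
            raise ValueError("multiple prompts in a batch is not currently supported")
        # ["text"] -> plain string; [[1, 2, 3]] -> single token array.
        return not isinstance(first, str), prompt[0]
    return False, prompt


print(classify_prompt("Hello"))       # (False, 'Hello')
print(classify_prompt(["Hello"]))     # (False, 'Hello')
print(classify_prompt([1, 2, 3]))     # (True, [1, 2, 3])
print(classify_prompt([[1, 2, 3]]))   # (True, [1, 2, 3])
```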
@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id,
-                                       token_ids)
+    if use_token_ids:
+        result_generator = engine.generate(None,
+                                           sampling_params,
+                                           request_id,
+                                           prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id,
+                                           token_ids)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
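When token IDs arrive, the server passes `prompt=None` and hands the IDs to `engine.generate` via `prompt_token_ids`, so its own tokenizer is never invoked for that request. A sketch of the matching client-side workflow, tokenizing locally with Hugging Face transformers, is shown below; the model name and server URL are assumptions for illustration and must match whatever model the server is actually serving.

```python
import requests
from transformers import AutoTokenizer

MODEL = "facebook/opt-125m"                       # placeholder; use the served model
API_URL = "http://localhost:8000/v1/completions"  # assumed local vLLM server

# Tokenize on the client so the server can skip its own tokenization pass.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
prompt_ids = tokenizer("Hello, my name is").input_ids

resp = requests.post(
    API_URL,
    json={"model": MODEL, "prompt": prompt_ids, "max_tokens": 16},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```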
vllm/entrypoints/openai/protocol.py
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):

 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
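The widened `prompt` annotation is what lets FastAPI/Pydantic accept all four payload shapes at the request boundary. Below is a minimal sketch of that validation with a stripped-down model (not the real CompletionRequest; the other fields and defaults are omitted for brevity).

```python
from typing import List, Union

from pydantic import BaseModel


class MiniCompletionRequest(BaseModel):
    # Same union as the patched CompletionRequest.prompt, other fields omitted.
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]


# All four shapes named in the diff's comment now validate cleanly.
for payload in (
    {"model": "m", "prompt": "Hello"},             # a string
    {"model": "m", "prompt": ["Hello", "World"]},  # an array of strings
    {"model": "m", "prompt": [1, 2, 3]},           # an array of tokens
    {"model": "m", "prompt": [[1, 2], [3, 4]]},    # an array of token arrays
):
    req = MiniCompletionRequest(**payload)
    print(type(req.prompt).__name__, req.prompt)
```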