Unverified commit e06f504a
Authored Aug 11, 2023 by WanMok; committed by GitHub on Aug 11, 2023
Parent: 462ae522

Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715)
Showing 2 changed files with 45 additions and 16 deletions (+45 -16):

  vllm/entrypoints/openai/api_server.py   +43 -15
  vllm/entrypoints/openai/protocol.py     +2  -1
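As a quick illustration of what this change enables, the sketch below sends a pre-tokenized prompt (an array of token ids) to the completion endpoint of a locally running OpenAI-compatible server. The URL, model name, and token ids are assumptions for illustration only; they are not part of this commit.

import requests

# Hypothetical local vLLM OpenAI-compatible server; adjust URL and model to your setup.
API_URL = "http://localhost:8000/v1/completions"

payload = {
    "model": "facebook/opt-125m",   # example model name
    "prompt": [2, 100, 318, 257],   # token ids instead of a string
    "max_tokens": 16,
}

resp = requests.post(API_URL, json=payload)
print(resp.json())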
vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
+from http import HTTPStatus
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union

 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt):
-    input_ids = tokenizer(prompt).input_ids
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
+    if prompt_ids is not None:
+        input_ids = prompt_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)

     if token_num + request.max_tokens > max_model_len:
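For reference, the assertion added above is an exclusive-or check: callers must pass exactly one of prompt or prompt_ids. A minimal standalone sketch of the same contract (the function name and the fallback tokenization below are illustrative, not taken from the diff):

from typing import List, Optional


def resolve_input_ids(prompt: Optional[str] = None,
                      prompt_ids: Optional[List[int]] = None) -> List[int]:
    # Exactly one of prompt / prompt_ids must be given.
    assert (prompt is None) != (prompt_ids is None), \
        "Either prompt or prompt_ids should be provided."
    if prompt_ids is not None:
        return prompt_ids
    # Stand-in for a real tokenizer call such as tokenizer(prompt).input_ids.
    return [ord(ch) for ch in prompt]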
@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
             "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    token_ids, error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, (str, list)):
+            # TODO: handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt

-    token_ids, error_check_ret = await check_length(request, prompt)
+    if use_token_ids:
+        _, error_check_ret = await check_length(request, prompt_ids=prompt)
+    else:
+        token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id,
-                                       token_ids)
+    if use_token_ids:
+        result_generator = engine.generate(None,
+                                           sampling_params,
+                                           request_id,
+                                           prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id,
+                                           token_ids)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
vllm/entrypoints/openai/protocol.py
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
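With the widened prompt annotation, all four request shapes now validate. A minimal pydantic sketch under that assumption (the PromptOnlyRequest model is hypothetical and only mirrors the changed field):

from typing import List, Union

from pydantic import BaseModel


class PromptOnlyRequest(BaseModel):
    # Mirrors the new field: a string, array of strings,
    # array of tokens, or array of token arrays.
    prompt: Union[List[int], List[List[int]], str, List[str]]


# All four shapes validate (values are illustrative).
PromptOnlyRequest(prompt="Hello, world")
PromptOnlyRequest(prompt=["Hello, world"])
PromptOnlyRequest(prompt=[1, 2, 3])
PromptOnlyRequest(prompt=[[1, 2, 3], [4, 5]])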