norm / vllm · Commits · 66c54aa9
"...git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "89b74f0aaf0a9a60b36d4241a5578c92a7cced8a"
Unverified commit 66c54aa9, authored Aug 08, 2023 by Nicolas Basile, committed via GitHub on Aug 08, 2023.

Check the max prompt length for the OpenAI completions API (#472)

Parent: 735ecfff
Showing 1 changed file with 12 additions and 5 deletions:

vllm/entrypoints/openai/api_server.py (+12 / -5)
@@ -120,7 +120,7 @@ async def check_length(request, prompt):
     token_num = len(input_ids)
 
     if token_num + request.max_tokens > max_model_len:
-        return create_error_response(
+        return input_ids, create_error_response(
             HTTPStatus.BAD_REQUEST,
             f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
@@ -129,7 +129,7 @@ async def check_length(request, prompt):
             f"Please reduce the length of the messages or completion.",
         )
     else:
-        return None
+        return input_ids, None
 
 
 @app.get("/v1/models")
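Putting the two hunks above together, check_length now returns an (input_ids, error) pair on both paths. The sketch below is a reconstruction from the diff context, not the verbatim file: the tokenizer, max_model_len, and create_error_response names, and the middle lines of the error message that the diff collapses between the two hunks, are assumed from surrounding code rather than shown here.

# Sketch only: reconstruction of check_length after this commit.
# `tokenizer`, `max_model_len`, `create_error_response`, and the HTTPStatus
# import are assumed to come from elsewhere in api_server.py.
async def check_length(request, prompt):
    input_ids = tokenizer(prompt).input_ids  # assumed; not shown in the diff
    token_num = len(input_ids)

    if token_num + request.max_tokens > max_model_len:
        # Return the token ids alongside the HTTP 400 error response.
        return input_ids, create_error_response(
            HTTPStatus.BAD_REQUEST,
            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {request.max_tokens + token_num} tokens "
            # ... (further message lines are collapsed in the diff above) ...
            f"Please reduce the length of the messages or completion.",
        )
    else:
        # Success path: the caller gets the token ids and no error, so the
        # prompt does not need to be tokenized a second time.
        return input_ids, None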
@@ -191,7 +191,7 @@ async def create_chat_completion(raw_request: Request):
             "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -215,7 +215,8 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     async def abort_request() -> None:
         await engine.abort(request_id)
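The caller-side consequence of the first hunk is that every call site must now unpack a two-tuple, even when no error occurred, and the token ids it receives are passed straight to engine.generate so the prompt is only tokenized once. A self-contained toy analogue of that contract (none of these names come from vLLM; the limit is invented for illustration):

# Toy stand-in mirroring the shape of the patched check_length contract.
from typing import List, Optional, Tuple

TOY_MAX_MODEL_LEN = 8  # pretend context window


def toy_check_length(prompt_tokens: List[int],
                     max_tokens: int) -> Tuple[List[int], Optional[str]]:
    """Always return (token_ids, error), like the patched check_length."""
    if len(prompt_tokens) + max_tokens > TOY_MAX_MODEL_LEN:
        return prompt_tokens, (
            f"maximum context length is {TOY_MAX_MODEL_LEN} tokens, "
            f"requested {len(prompt_tokens) + max_tokens}")
    return prompt_tokens, None


token_ids, error = toy_check_length([1, 2, 3], max_tokens=10)
assert error is not None                         # over the limit
token_ids, error = toy_check_length([1, 2, 3], max_tokens=2)
assert error is None and token_ids == [1, 2, 3]  # success path keeps the ids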
@@ -386,6 +387,11 @@ async def create_completion(raw_request: Request):
             prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
+    token_ids, error_check_ret = await check_length(request, prompt)
+    if error_check_ret is not None:
+        return error_check_ret
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +411,8 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
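With the length check now wired into create_completion as well, an over-long request to the completions endpoint should come back as HTTP 400 with the message built in check_length, before any generation work is scheduled. The snippet below is a hypothetical client-side probe, not part of the commit: the server URL, model name, and the use of the requests library are all assumptions.

# Hypothetical probe of the new length check; assumes a vLLM OpenAI-compatible
# server is running locally. URL and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        # Deliberately request far more tokens than any context window allows
        # so the server rejects the request up front.
        "max_tokens": 1_000_000,
    },
)
print(resp.status_code)  # expected: 400 (HTTPStatus.BAD_REQUEST)
print(resp.json())       # body should mention the model's maximum context length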