xdb4_94051 / vllm · Commit dd7e8f5f (Unverified)
Authored Jan 18, 2024 by Simon Mo, committed by GitHub on Jan 18, 2024

refactor complemention api for readability (#2499)
Parent: d2a68364

Showing 6 changed files with 284 additions and 253 deletions.
tests/entrypoints/test_openai_server.py        +10   −0
vllm/entrypoints/openai/protocol.py            +45   −0
vllm/entrypoints/openai/serving_chat.py         +3  −25
vllm/entrypoints/openai/serving_completion.py +210 −215
vllm/entrypoints/openai/serving_engine.py      +15  −12
vllm/model_executor/weight_utils.py             +1   −1
tests/entrypoints/test_openai_server.py

@@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
 
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+
 
 async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     messages = [{
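The added test exercises the completions endpoint with a prompt given as token IDs rather than text. A minimal standalone sketch of the same call pattern, assuming a vLLM OpenAI-compatible server is already running locally; the base URL, API key, and model name below are illustrative placeholders, not part of this commit:

# Sketch only: assumes a locally running vLLM OpenAI-compatible server.
# base_url, api_key, and the model name are placeholder assumptions.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main():
    # Prompts may be passed as token IDs instead of a string,
    # mirroring the new "test using token IDs" case above.
    completion = await client.completions.create(
        model="facebook/opt-125m",
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)


asyncio.run(main())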
vllm/entrypoints/openai/protocol.py

@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Field
 
 from vllm.utils import random_uuid
+from vllm.sampling_params import SamplingParams
 
 
 class ErrorResponse(BaseModel):

@@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=self.max_tokens,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+        )
+
 
 class CompletionRequest(BaseModel):
     model: str

@@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self):
+        echo_without_generation = self.echo and self.max_tokens == 0
+
+        return SamplingParams(
+            n=self.n,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            logprobs=self.logprobs,
+            use_beam_search=self.use_beam_search,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+        )
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
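With these additions, each request model can translate itself into the engine's SamplingParams. A brief usage sketch, assuming vLLM at this commit is importable and that CompletionRequest mirrors the OpenAI completions schema; the model name and prompt are illustrative:

# Sketch only: assumes vLLM at this commit is installed; the model name,
# prompt text, and the `prompt` field follow the OpenAI completions schema.
from vllm.entrypoints.openai.protocol import CompletionRequest

request = CompletionRequest(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
    max_tokens=16,
    temperature=0.0,
)

# The request object now owns the translation into SamplingParams,
# instead of the serving layer rebuilding it field by field.
sampling_params = request.to_sampling_params()
print(sampling_params)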
vllm/entrypoints/openai/serving_chat.py

@@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 logger = init_logger(__name__)

@@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
                 f"Error in applying chat template from request: {str(e)}")
             return self.create_error_response(str(e))
 
-        token_ids, error_check_ret = await self._check_length(request,
-                                                              prompt=prompt)
-        if error_check_ret is not None:
-            return error_check_ret
-
         request_id = f"cmpl-{random_uuid()}"
         try:
-            spaces_between_special_tokens = request.spaces_between_special_tokens
-            sampling_params = SamplingParams(
-                n=request.n,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                repetition_penalty=request.repetition_penalty,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                min_p=request.min_p,
-                stop=request.stop,
-                stop_token_ids=request.stop_token_ids,
-                max_tokens=request.max_tokens,
-                best_of=request.best_of,
-                top_k=request.top_k,
-                ignore_eos=request.ignore_eos,
-                use_beam_search=request.use_beam_search,
-                skip_special_tokens=request.skip_special_tokens,
-                spaces_between_special_tokens=spaces_between_special_tokens,
-            )
+            token_ids = self._validate_prompt_and_tokenize(request,
+                                                           prompt=prompt)
+            sampling_params = request.to_sampling_params()
         except ValueError as e:
             return self.create_error_response(str(e))
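On the chat side, sampling fields sent by the client now reach the engine through ChatCompletionRequest.to_sampling_params() instead of a hand-built SamplingParams. A short client-side sketch of a request that exercises that path; again the server URL, API key, and model name are placeholder assumptions:

# Sketch only: placeholder server URL, API key, and model name.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main():
    # temperature/top_p/max_tokens land in ChatCompletionRequest and are
    # converted server-side through request.to_sampling_params().
    chat = await client.chat.completions.create(
        model="facebook/opt-125m",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        temperature=0.7,
        top_p=0.9,
        max_tokens=16,
    )
    print(chat.choices[0].message.content)


asyncio.run(main())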
vllm/entrypoints/openai/serving_completion.py

(This diff is collapsed and is not shown here.)
vllm/entrypoints/openai/serving_engine.py

 import asyncio
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -104,27 +104,30 @@ class OpenAIServing:
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
-    async def _check_length(
-        self,
-        request: Union[ChatCompletionRequest, CompletionRequest],
-        prompt: Optional[str] = None,
-        prompt_ids: Optional[List[int]] = None
-    ) -> Tuple[List[int], Optional[ErrorResponse]]:
-        assert (not (prompt is None and prompt_ids is None)
-                and not (prompt is not None and prompt_ids is not None)
-                ), "Either prompt or prompt_ids should be provided."
+    def _validate_prompt_and_tokenize(
+            self,
+            request: Union[ChatCompletionRequest, CompletionRequest],
+            prompt: Optional[str] = None,
+            prompt_ids: Optional[List[int]] = None) -> List[int]:
+        if not (prompt or prompt_ids):
+            raise ValueError(
+                "Either prompt or prompt_ids should be provided.")
+        if (prompt and prompt_ids):
+            raise ValueError(
+                "Only one of prompt or prompt_ids should be provided.")
+
         input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
             prompt).input_ids
         token_num = len(input_ids)
 
         if request.max_tokens is None:
             request.max_tokens = self.max_model_len - token_num
 
         if token_num + request.max_tokens > self.max_model_len:
-            return input_ids, self.create_error_response(
+            raise ValueError(
                 f"This model's maximum context length is {self.max_model_len} tokens. "
                 f"However, you requested {request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids, None
+            return input_ids
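The behavioral change here is that prompt and length validation now raises ValueError (which the serving layer catches and converts into an error response) instead of returning a (token_ids, error) tuple. A simplified, self-contained sketch of that validation pattern with a stub tokenizer; the names below are illustrative and not vLLM's API:

# Simplified sketch of the validate-and-tokenize pattern; not vLLM's actual API.
from typing import Callable, List, Optional


def validate_prompt_and_tokenize(tokenize: Callable[[str], List[int]],
                                 max_model_len: int,
                                 max_tokens: Optional[int],
                                 prompt: Optional[str] = None,
                                 prompt_ids: Optional[List[int]] = None
                                 ) -> List[int]:
    # Exactly one of prompt / prompt_ids must be given.
    if not (prompt or prompt_ids):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    if prompt and prompt_ids:
        raise ValueError("Only one of prompt or prompt_ids should be provided.")

    input_ids = prompt_ids if prompt_ids is not None else tokenize(prompt)
    token_num = len(input_ids)

    # Default max_tokens to whatever room is left in the context window.
    if max_tokens is None:
        max_tokens = max_model_len - token_num

    if token_num + max_tokens > max_model_len:
        raise ValueError(
            f"This model's maximum context length is {max_model_len} tokens, "
            f"but {token_num + max_tokens} tokens were requested.")
    return input_ids


# Example with a fake whitespace "tokenizer" and a small context window.
ids = validate_prompt_and_tokenize(lambda s: list(range(len(s.split()))),
                                   max_model_len=8,
                                   max_tokens=4,
                                   prompt="hello there world")
print(ids)  # [0, 1, 2]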
vllm/model_executor/weight_utils.py

@@ -163,7 +163,7 @@ def prepare_hf_model_weights(
             use_safetensors = True
             break
-    logger.info(f"Downloading model weights {allow_patterns}")
+    logger.info(f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):