Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8947bc3c
Unverified
Commit
8947bc3c
authored
Apr 27, 2024
by
Cyrus Leung
Committed by
GitHub
Apr 27, 2024
Browse files
[Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355)
parent
12628d3c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
113 additions
and
55 deletions
+113
-55
requirements-common.txt
requirements-common.txt
+1
-0
requirements-dev.txt
requirements-dev.txt
+0
-1
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+16
-0
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+2
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+35
-29
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+44
-11
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+5
-4
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+10
-8
No files found.
requirements-common.txt
View file @
8947bc3c
...
@@ -8,6 +8,7 @@ py-cpuinfo
...
@@ -8,6 +8,7 @@ py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
fastapi
openai
uvicorn[standard]
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
prometheus_client >= 0.18.0
...
...
requirements-dev.txt
View file @
8947bc3c
...
@@ -21,7 +21,6 @@ pytest-rerunfailures
...
@@ -21,7 +21,6 @@ pytest-rerunfailures
pytest-shard
pytest-shard
httpx
httpx
einops # required for MPT
einops # required for MPT
openai
requests
requests
ray
ray
peft
peft
...
...
tests/entrypoints/test_openai_server.py
View file @
8947bc3c
...
@@ -15,6 +15,7 @@ import ray
...
@@ -15,6 +15,7 @@ import ray
import
requests
import
requests
# downloading lora to test lora requests
# downloading lora to test lora requests
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
openai
import
BadRequestError
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
...
@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert
loaded
==
{
"result"
:
2
},
loaded
assert
loaded
==
{
"result"
:
2
},
loaded
async
def
test_extra_fields
(
server
,
client
:
openai
.
AsyncOpenAI
):
with
pytest
.
raises
(
BadRequestError
)
as
exc_info
:
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant."
,
"extra_field"
:
"0"
,
}],
# type: ignore
temperature
=
0
,
seed
=
0
)
assert
"extra_forbidden"
in
exc_info
.
value
.
message
async
def
test_guided_grammar
(
server
,
client
:
openai
.
AsyncOpenAI
):
async
def
test_guided_grammar
(
server
,
client
:
openai
.
AsyncOpenAI
):
simple_sql_grammar
=
"""
simple_sql_grammar
=
"""
start: select_statement
start: select_statement
...
...
vllm/entrypoints/openai/cli_args.py
View file @
8947bc3c
...
@@ -9,7 +9,7 @@ import json
...
@@ -9,7 +9,7 @@ import json
import
ssl
import
ssl
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.serving_engine
import
LoRA
from
vllm.entrypoints.openai.serving_engine
import
LoRA
ModulePath
class
LoRAParserAction
(
argparse
.
Action
):
class
LoRAParserAction
(
argparse
.
Action
):
...
@@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
...
@@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
lora_list
=
[]
lora_list
=
[]
for
item
in
values
:
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRA
(
name
,
path
))
lora_list
.
append
(
LoRA
ModulePath
(
name
,
path
))
setattr
(
namespace
,
self
.
dest
,
lora_list
)
setattr
(
namespace
,
self
.
dest
,
lora_list
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
8947bc3c
...
@@ -4,14 +4,20 @@ import time
...
@@ -4,14 +4,20 @@ import time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
pydantic
import
BaseModel
,
Field
,
model_validator
from
openai.types.chat
import
ChatCompletionMessageParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
class
ErrorResponse
(
BaseModel
):
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
model_config
=
ConfigDict
(
extra
=
"forbid"
)
class
ErrorResponse
(
OpenAIBaseModel
):
object
:
str
=
"error"
object
:
str
=
"error"
message
:
str
message
:
str
type
:
str
type
:
str
...
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
...
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
code
:
int
code
:
int
class
ModelPermission
(
BaseModel
):
class
ModelPermission
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"modelperm-
{
random_uuid
()
}
"
)
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"modelperm-
{
random_uuid
()
}
"
)
object
:
str
=
"model_permission"
object
:
str
=
"model_permission"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
...
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
is_blocking
:
bool
=
False
is_blocking
:
bool
=
False
class
ModelCard
(
BaseModel
):
class
ModelCard
(
OpenAI
BaseModel
):
id
:
str
id
:
str
object
:
str
=
"model"
object
:
str
=
"model"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
...
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
permission
:
List
[
ModelPermission
]
=
Field
(
default_factory
=
list
)
permission
:
List
[
ModelPermission
]
=
Field
(
default_factory
=
list
)
class
ModelList
(
BaseModel
):
class
ModelList
(
OpenAI
BaseModel
):
object
:
str
=
"list"
object
:
str
=
"list"
data
:
List
[
ModelCard
]
=
Field
(
default_factory
=
list
)
data
:
List
[
ModelCard
]
=
Field
(
default_factory
=
list
)
class
UsageInfo
(
BaseModel
):
class
UsageInfo
(
OpenAI
BaseModel
):
prompt_tokens
:
int
=
0
prompt_tokens
:
int
=
0
total_tokens
:
int
=
0
total_tokens
:
int
=
0
completion_tokens
:
Optional
[
int
]
=
0
completion_tokens
:
Optional
[
int
]
=
0
class
ResponseFormat
(
BaseModel
):
class
ResponseFormat
(
OpenAI
BaseModel
):
# type must be "json_object" or "text"
# type must be "json_object" or "text"
type
:
Literal
[
"text"
,
"json_object"
]
type
:
Literal
[
"text"
,
"json_object"
]
class
ChatCompletionRequest
(
BaseModel
):
class
ChatCompletionRequest
(
OpenAI
BaseModel
):
# Ordered by official OpenAI API documentation
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
# https://platform.openai.com/docs/api-reference/chat/create
messages
:
List
[
Dict
[
str
,
str
]
]
messages
:
List
[
ChatCompletionMessageParam
]
model
:
str
model
:
str
frequency_penalty
:
Optional
[
float
]
=
0.0
frequency_penalty
:
Optional
[
float
]
=
0.0
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
logit_bias
:
Optional
[
Dict
[
str
,
float
]]
=
None
...
@@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
...
@@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
return
data
return
data
class
CompletionRequest
(
BaseModel
):
class
CompletionRequest
(
OpenAI
BaseModel
):
# Ordered by official OpenAI API documentation
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
# https://platform.openai.com/docs/api-reference/completions/create
model
:
str
model
:
str
...
@@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
...
@@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
return
data
return
data
class
LogProbs
(
BaseModel
):
class
LogProbs
(
OpenAI
BaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
tokens
:
List
[
str
]
=
Field
(
default_factory
=
list
)
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
str
,
float
]]]]
=
None
top_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
str
,
float
]]]]
=
None
class
CompletionResponseChoice
(
BaseModel
):
class
CompletionResponseChoice
(
OpenAI
BaseModel
):
index
:
int
index
:
int
text
:
str
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
stop_reason
:
Optional
[
Union
[
int
,
str
]
]
=
Field
(
default
=
None
,
default
=
None
,
description
=
(
description
=
(
"The stop string or token id that caused the completion "
"The stop string or token id that caused the completion "
...
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
...
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
)
)
class
CompletionResponse
(
BaseModel
):
class
CompletionResponse
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"text_completion"
object
:
str
=
"text_completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
...
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
usage
:
UsageInfo
usage
:
UsageInfo
class
CompletionResponseStreamChoice
(
BaseModel
):
class
CompletionResponseStreamChoice
(
OpenAI
BaseModel
):
index
:
int
index
:
int
text
:
str
text
:
str
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
Field
(
stop_reason
:
Optional
[
Union
[
int
,
str
]
]
=
Field
(
default
=
None
,
default
=
None
,
description
=
(
description
=
(
"The stop string or token id that caused the completion "
"The stop string or token id that caused the completion "
...
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
...
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
)
)
class
CompletionStreamResponse
(
BaseModel
):
class
CompletionStreamResponse
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"text_completion"
object
:
str
=
"text_completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
...
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
ChatMessage
(
BaseModel
):
class
ChatMessage
(
OpenAI
BaseModel
):
role
:
str
role
:
str
content
:
str
content
:
str
class
ChatCompletionResponseChoice
(
BaseModel
):
class
ChatCompletionResponseChoice
(
OpenAI
BaseModel
):
index
:
int
index
:
int
message
:
ChatMessage
message
:
ChatMessage
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]
]
=
None
class
ChatCompletionResponse
(
BaseModel
):
class
ChatCompletionResponse
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-
{
random_uuid
()
}
"
)
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"chat.completion"
object
:
str
=
"chat.completion"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
...
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
usage
:
UsageInfo
usage
:
UsageInfo
class
DeltaMessage
(
BaseModel
):
class
DeltaMessage
(
OpenAI
BaseModel
):
role
:
Optional
[
str
]
=
None
role
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
content
:
Optional
[
str
]
=
None
class
ChatCompletionResponseStreamChoice
(
BaseModel
):
class
ChatCompletionResponseStreamChoice
(
OpenAI
BaseModel
):
index
:
int
index
:
int
delta
:
DeltaMessage
delta
:
DeltaMessage
logprobs
:
Optional
[
LogProbs
]
=
None
logprobs
:
Optional
[
LogProbs
]
=
None
finish_reason
:
Optional
[
Literal
[
"stop"
,
"length"
]
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
None
,
int
,
str
]
=
None
stop_reason
:
Optional
[
Union
[
int
,
str
]
]
=
None
class
ChatCompletionStreamResponse
(
BaseModel
):
class
ChatCompletionStreamResponse
(
OpenAI
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-
{
random_uuid
()
}
"
)
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"chatcmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"chat.completion.chunk"
object
:
str
=
"chat.completion.chunk"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
8947bc3c
import
codecs
import
codecs
import
time
import
time
from
typing
import
AsyncGenerator
,
AsyncIterator
,
List
,
Optional
,
Union
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
TypedDict
,
Union
,
final
)
from
fastapi
import
Request
from
fastapi
import
Request
from
openai.types.chat
import
(
ChatCompletionContentPartParam
,
ChatCompletionRole
)
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
...
@@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
UsageInfo
)
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
LoRA
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
...
@@ -20,20 +24,41 @@ from vllm.utils import random_uuid
...
@@ -20,20 +24,41 @@ from vllm.utils import random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
final
# So that it should be compatible with Dict[str, str]
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
class
OpenAIServingChat
(
OpenAIServing
):
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
response_role
:
str
,
response_role
:
str
,
lora_modules
:
Optional
[
List
[
LoRA
]]
=
None
,
lora_modules
:
Optional
[
List
[
LoRA
ModulePath
]]
=
None
,
chat_template
=
None
):
chat_template
:
Optional
[
str
]
=
None
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
engine
=
engine
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
)
self
.
response_role
=
response_role
self
.
response_role
=
response_role
self
.
_load_chat_template
(
chat_template
)
self
.
_load_chat_template
(
chat_template
)
def
_parse_chat_message_content
(
self
,
role
:
ChatCompletionRole
,
content
:
Optional
[
Union
[
str
,
Iterable
[
ChatCompletionContentPartParam
]]],
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
object
]]]:
if
content
is
None
:
return
[],
[]
if
isinstance
(
content
,
str
):
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)],
[]
# To be implemented: https://github.com/vllm-project/vllm/pull/3467
# To be implemented: https://github.com/vllm-project/vllm/pull/4200
raise
NotImplementedError
(
"Complex input not supported yet"
)
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
],
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
],
...
@@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
return
error_check_ret
return
error_check_ret
try
:
try
:
conversation
:
List
[
ConversationMessage
]
=
[]
for
m
in
request
.
messages
:
messages
,
_
=
self
.
_parse_chat_message_content
(
m
[
"role"
],
m
[
"content"
])
conversation
.
extend
(
messages
)
prompt
=
self
.
tokenizer
.
apply_chat_template
(
prompt
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
request
.
messages
,
conversation
=
conversation
,
tokenize
=
False
,
tokenize
=
False
,
add_generation_prompt
=
request
.
add_generation_prompt
)
add_generation_prompt
=
request
.
add_generation_prompt
,
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
@@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):
async
def
chat_completion_stream_generator
(
async
def
chat_completion_stream_generator
(
self
,
request
:
ChatCompletionRequest
,
self
,
request
:
ChatCompletionRequest
,
result_generator
:
AsyncIterator
[
RequestOutput
],
request_id
:
str
result_generator
:
AsyncIterator
[
RequestOutput
],
)
->
Union
[
ErrorResponse
,
AsyncGenerator
[
str
,
None
]]:
request_id
:
str
)
->
AsyncGenerator
[
str
,
None
]:
model_name
=
self
.
served_model_names
[
0
]
model_name
=
self
.
served_model_names
[
0
]
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
chunk_object_type
=
"chat.completion.chunk"
chunk_object_type
=
"chat.completion.chunk"
...
@@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):
model_name
=
self
.
served_model_names
[
0
]
model_name
=
self
.
served_model_names
[
0
]
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
final_res
:
RequestOutput
=
None
final_res
:
Optional
[
RequestOutput
]
=
None
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
...
@@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):
return
response
return
response
def
_load_chat_template
(
self
,
chat_template
):
def
_load_chat_template
(
self
,
chat_template
:
Optional
[
str
]
):
tokenizer
=
self
.
tokenizer
tokenizer
=
self
.
tokenizer
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
8947bc3c
...
@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
...
@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponseStreamChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
CompletionStreamResponse
,
LogProbs
,
UsageInfo
)
LogProbs
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
LoRA
,
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
...
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
served_model_names
:
List
[
str
],
lora_modules
:
Optional
[
List
[
LoRA
]]
=
None
):
lora_modules
:
Optional
[
List
[
LoRA
ModulePath
]]
=
None
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
engine
=
engine
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
)
...
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
=
[]
generators
:
List
[
AsyncIterator
[
RequestOutput
]]
=
[]
try
:
try
:
sampling_params
=
request
.
to_sampling_params
()
sampling_params
=
request
.
to_sampling_params
()
lora_request
=
self
.
_maybe_get_lora
(
request
)
lora_request
=
self
.
_maybe_get_lora
(
request
)
...
@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts
=
len
(
prompts
))
num_prompts
=
len
(
prompts
))
# Non-streaming response
# Non-streaming response
final_res_batch
:
RequestOutput
=
[
None
]
*
len
(
prompts
)
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
try
:
try
:
async
for
i
,
res
in
result_generator
:
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
if
await
raw_request
.
is_disconnected
():
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
8947bc3c
...
@@ -22,17 +22,15 @@ logger = init_logger(__name__)
...
@@ -22,17 +22,15 @@ logger = init_logger(__name__)
@
dataclass
@
dataclass
class
LoRA
:
class
LoRA
ModulePath
:
name
:
str
name
:
str
local_path
:
str
local_path
:
str
class
OpenAIServing
:
class
OpenAIServing
:
def
__init__
(
self
,
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
engine
:
AsyncLLMEngine
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]):
served_model_names
:
List
[
str
],
lora_modules
=
Optional
[
List
[
LoRA
]]):
self
.
engine
=
engine
self
.
engine
=
engine
self
.
served_model_names
=
served_model_names
self
.
served_model_names
=
served_model_names
if
lora_modules
is
None
:
if
lora_modules
is
None
:
...
@@ -158,7 +156,9 @@ class OpenAIServing:
...
@@ -158,7 +156,9 @@ class OpenAIServing:
})
})
return
json_str
return
json_str
async
def
_check_model
(
self
,
request
)
->
Optional
[
ErrorResponse
]:
async
def
_check_model
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
if
request
.
model
in
self
.
served_model_names
:
return
None
return
None
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
lora_requests
]:
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
lora_requests
]:
...
@@ -168,14 +168,16 @@ class OpenAIServing:
...
@@ -168,14 +168,16 @@ class OpenAIServing:
err_type
=
"NotFoundError"
,
err_type
=
"NotFoundError"
,
status_code
=
HTTPStatus
.
NOT_FOUND
)
status_code
=
HTTPStatus
.
NOT_FOUND
)
def
_maybe_get_lora
(
self
,
request
)
->
Optional
[
LoRARequest
]:
def
_maybe_get_lora
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
]
)
->
Optional
[
LoRARequest
]:
if
request
.
model
in
self
.
served_model_names
:
if
request
.
model
in
self
.
served_model_names
:
return
None
return
None
for
lora
in
self
.
lora_requests
:
for
lora
in
self
.
lora_requests
:
if
request
.
model
==
lora
.
lora_name
:
if
request
.
model
==
lora
.
lora_name
:
return
lora
return
lora
# if _check_model has been called earlier, this will be unreachable
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
"The model `{request.model}` does not exist."
)
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
def
_validate_prompt_and_tokenize
(
def
_validate_prompt_and_tokenize
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment