Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc0d9dfc
Unverified
Commit
fc0d9dfc
authored
May 16, 2024
by
Cyrus Leung
Committed by
GitHub
May 15, 2024
Browse files
[Frontend] Re-enable custom roles in Chat Completions API (#4758)
parent
361c461a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
108 additions
and
26 deletions
+108
-26
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+30
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+36
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+42
-24
No files found.
tests/entrypoints/test_openai_server.py
View file @
fc0d9dfc
...
@@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
...
@@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
assert
content
==
"2"
assert
content
==
"2"
async
def
test_custom_role
(
server
,
client
:
openai
.
AsyncOpenAI
):
# Not sure how the model handles custom roles so we just check that
# both string and complex message content are handled in the same way
resp1
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"my-custom-role"
,
"content"
:
"what is 1+1?"
,
}],
# type: ignore
temperature
=
0
,
seed
=
0
)
resp2
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
[{
"role"
:
"my-custom-role"
,
"content"
:
[{
"type"
:
"text"
,
"text"
:
"what is 1+1?"
}]
}],
# type: ignore
temperature
=
0
,
seed
=
0
)
content1
=
resp1
.
choices
[
0
].
message
.
content
content2
=
resp2
.
choices
[
0
].
message
.
content
assert
content1
==
content2
async
def
test_guided_grammar
(
server
,
client
:
openai
.
AsyncOpenAI
):
async
def
test_guided_grammar
(
server
,
client
:
openai
.
AsyncOpenAI
):
simple_sql_grammar
=
"""
simple_sql_grammar
=
"""
start: select_statement
start: select_statement
...
...
vllm/entrypoints/openai/protocol.py
View file @
fc0d9dfc
...
@@ -3,16 +3,50 @@
...
@@ -3,16 +3,50 @@
import
time
import
time
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
openai.types.chat
import
torch
import
torch
from
openai.types.chat
import
ChatCompletionMessageParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
# pydantic needs the TypedDict from typing_extensions
from
typing_extensions
import
Annotated
,
Required
,
TypedDict
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
class
CustomChatCompletionContentPartParam
(
TypedDict
,
total
=
False
):
__pydantic_config__
=
ConfigDict
(
extra
=
"allow"
)
# type: ignore
type
:
Required
[
str
]
"""The type of the content part."""
ChatCompletionContentPartParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionContentPartParam
,
CustomChatCompletionContentPartParam
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
"""The role of the message's author."""
content
:
Union
[
str
,
List
[
ChatCompletionContentPartParam
]]
"""The contents of the message."""
name
:
str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
ChatCompletionMessageParam
=
Union
[
openai
.
types
.
chat
.
ChatCompletionMessageParam
,
CustomChatCompletionMessageParam
]
class
OpenAIBaseModel
(
BaseModel
):
class
OpenAIBaseModel
(
BaseModel
):
# OpenAI API does not allow extra fields
# OpenAI API does not allow extra fields
model_config
=
ConfigDict
(
extra
=
"forbid"
)
model_config
=
ConfigDict
(
extra
=
"forbid"
)
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
fc0d9dfc
import
codecs
import
codecs
import
time
import
time
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Iterable
,
List
,
from
dataclasses
import
dataclass
Optional
,
Tuple
,
TypedDict
,
Union
,
final
)
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Iterable
,
List
,
Optional
,
TypedDict
,
Union
,
cast
,
final
)
from
fastapi
import
Request
from
fastapi
import
Request
from
openai.types.chat
import
(
ChatCompletionContentPartParam
,
from
openai.types.chat
import
ChatCompletionContentPartTextParam
ChatCompletionRole
)
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionContentPartParam
,
ChatCompletionMessageParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
...
@@ -31,6 +32,11 @@ class ConversationMessage(TypedDict):
...
@@ -31,6 +32,11 @@ class ConversationMessage(TypedDict):
content
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
class
OpenAIServingChat
(
OpenAIServing
):
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -77,27 +83,40 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -77,27 +83,40 @@ class OpenAIServingChat(OpenAIServing):
logger
.
warning
(
logger
.
warning
(
"No chat template provided. Chat API will not work."
)
"No chat template provided. Chat API will not work."
)
def
_parse_chat_message_content
(
def
_parse_chat_message_content
_parts
(
self
,
self
,
role
:
ChatCompletionRole
,
role
:
str
,
content
:
Optional
[
Union
[
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
Iterable
[
ChatCompletionContentPartParam
]]],
)
->
ChatMessageParseResult
:
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
object
]]]:
if
content
is
None
:
return
[],
[]
if
isinstance
(
content
,
str
):
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)],
[]
texts
:
List
[
str
]
=
[]
texts
:
List
[
str
]
=
[]
for
_
,
part
in
enumerate
(
content
):
if
part
[
"type"
]
==
"text"
:
for
_
,
part
in
enumerate
(
parts
):
text
=
part
[
"text"
]
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
texts
.
append
(
text
)
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part
[
'type'
]
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
"
\n
"
.
join
(
texts
))]
return
ChatMessageParseResult
(
messages
=
messages
)
def
_parse_chat_message_content
(
self
,
message
:
ChatCompletionMessageParam
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
)
return
[
ConversationMessage
(
role
=
role
,
content
=
"
\n
"
.
join
(
texts
))],
[]
return
self
.
_parse_chat_message_content_parts
(
role
,
content
)
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
self
,
request
:
ChatCompletionRequest
,
raw_request
:
Request
...
@@ -119,11 +138,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -119,11 +138,10 @@ class OpenAIServingChat(OpenAIServing):
try
:
try
:
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
:
List
[
ConversationMessage
]
=
[]
for
m
in
request
.
messages
:
for
msg
in
request
.
messages
:
messages
,
_
=
self
.
_parse_chat_message_content
(
parsed_msg
=
self
.
_parse_chat_message_content
(
msg
)
m
[
"role"
],
m
[
"content"
])
conversation
.
extend
(
messages
)
conversation
.
extend
(
parsed_msg
.
messages
)
prompt
=
self
.
tokenizer
.
apply_chat_template
(
prompt
=
self
.
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
conversation
=
conversation
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment