Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a3d2a5b
Unverified
Commit
7a3d2a5b
authored
Jul 16, 2024
by
sasha0552
Committed by
GitHub
Jul 16, 2024
Browse files
[Frontend] Support for chat completions input in the tokenize endpoint (#5923)
parent
d9701151
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
386 additions
and
244 deletions
+386
-244
tests/async_engine/test_chat_template.py
tests/async_engine/test_chat_template.py
+5
-9
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+0
-49
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+128
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+8
-2
vllm/entrypoints/openai/chat_utils.py
vllm/entrypoints/openai/chat_utils.py
+156
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-3
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+10
-151
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-30
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+73
-0
No files found.
tests/async_engine/test_chat_template.py
View file @
7a3d2a5b
...
@@ -4,8 +4,8 @@ from dataclasses import dataclass
...
@@ -4,8 +4,8 @@ from dataclasses import dataclass
import
pytest
import
pytest
from
vllm.entrypoints.openai.chat_utils
import
load_chat_template
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
chatml_jinja_path
=
pathlib
.
Path
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
chatml_jinja_path
=
pathlib
.
Path
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
...
@@ -64,8 +64,7 @@ def test_load_chat_template():
...
@@ -64,8 +64,7 @@ def test_load_chat_template():
# Testing chatml template
# Testing chatml template
tokenizer
=
MockTokenizer
()
tokenizer
=
MockTokenizer
()
mock_serving_chat
=
MockServingChat
(
tokenizer
)
mock_serving_chat
=
MockServingChat
(
tokenizer
)
OpenAIServingChat
.
_load_chat_template
(
mock_serving_chat
,
load_chat_template
(
mock_serving_chat
,
chat_template
=
chatml_jinja_path
)
chat_template
=
chatml_jinja_path
)
template_content
=
tokenizer
.
chat_template
template_content
=
tokenizer
.
chat_template
...
@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
...
@@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike():
mock_serving_chat
=
MockServingChat
(
tokenizer
)
mock_serving_chat
=
MockServingChat
(
tokenizer
)
with
pytest
.
raises
(
ValueError
,
match
=
"looks like a file path"
):
with
pytest
.
raises
(
ValueError
,
match
=
"looks like a file path"
):
OpenAIServingChat
.
_load_chat_template
(
mock_serving_chat
,
load_chat_template
(
mock_serving_chat
,
chat_template
=
template
)
chat_template
=
template
)
def
test_no_load_chat_template_literallike
():
def
test_no_load_chat_template_literallike
():
...
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
...
@@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike():
tokenizer
=
MockTokenizer
()
tokenizer
=
MockTokenizer
()
mock_serving_chat
=
MockServingChat
(
tokenizer
)
mock_serving_chat
=
MockServingChat
(
tokenizer
)
OpenAIServingChat
.
_load_chat_template
(
mock_serving_chat
,
load_chat_template
(
mock_serving_chat
,
chat_template
=
template
)
chat_template
=
template
)
template_content
=
tokenizer
.
chat_template
template_content
=
tokenizer
.
chat_template
assert
template_content
==
template
assert
template_content
==
template
...
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Initialize the tokenizer
# Initialize the tokenizer
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model
)
mock_serving_chat
=
MockServingChat
(
tokenizer
)
mock_serving_chat
=
MockServingChat
(
tokenizer
)
OpenAIServingChat
.
_load_chat_template
(
mock_serving_chat
,
load_chat_template
(
mock_serving_chat
,
chat_template
=
template
)
chat_template
=
template
)
# Create a mock request object using keyword arguments
# Create a mock request object using keyword arguments
mock_request
=
ChatCompletionRequest
(
mock_request
=
ChatCompletionRequest
(
...
...
tests/entrypoints/openai/test_completion.py
View file @
7a3d2a5b
...
@@ -6,7 +6,6 @@ from typing import List
...
@@ -6,7 +6,6 @@ from typing import List
import
jsonschema
import
jsonschema
import
openai
# use the official client for correctness check
import
openai
# use the official client for correctness check
import
pytest
import
pytest
import
requests
# downloading lora to test lora requests
# downloading lora to test lora requests
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
openai
import
BadRequestError
from
openai
import
BadRequestError
...
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
...
@@ -636,51 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
prompt
=
"Give an example string that fits this regex"
,
prompt
=
"Give an example string that fits this regex"
,
extra_body
=
dict
(
guided_regex
=
sample_regex
,
extra_body
=
dict
(
guided_regex
=
sample_regex
,
guided_json
=
sample_json_schema
))
guided_json
=
sample_json_schema
))
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_tokenize
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_name
,
tokenizer_mode
=
"fast"
)
for
add_special
in
[
False
,
True
]:
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_special_tokens"
:
add_special
,
"model"
:
model_name
,
"prompt"
:
prompt
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_detokenize
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
]
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_name
,
tokenizer_mode
=
"fast"
)
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
response
=
requests
.
post
(
base_url
+
"detokenize"
,
json
=
{
"model"
:
model_name
,
"tokens"
:
tokens
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
tests/entrypoints/openai/test_tokenization.py
0 → 100644
View file @
7a3d2a5b
import
openai
# use the official client for correctness check
import
pytest
import
requests
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
...utils
import
RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
with
RemoteOpenAIServer
([
"--model"
,
MODEL_NAME
,
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
"--max-num-seqs"
,
"128"
,
])
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
def
client
(
server
):
return
server
.
get_async_client
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_tokenize_completions
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_name
,
tokenizer_mode
=
"fast"
)
for
add_special
in
[
False
,
True
]:
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_special_tokens"
:
add_special
,
"model"
:
model_name
,
"prompt"
:
prompt
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_tokenize_chat
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_name
,
tokenizer_mode
=
"fast"
)
for
add_generation
in
[
False
,
True
]:
for
add_special
in
[
False
,
True
]:
conversation
=
[{
"role"
:
"user"
,
"content"
:
"Hi there!"
},
{
"role"
:
"assistant"
,
"content"
:
"Nice to meet you!"
},
{
"role"
:
"user"
,
"content"
:
"Can I ask a question?"
}]
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
add_generation
,
conversation
=
conversation
,
tokenize
=
False
)
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_generation_prompt"
:
add_generation
,
"add_special_tokens"
:
add_special
,
"messages"
:
conversation
,
"model"
:
model_name
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_detokenize
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
].
strip
(
"/"
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model_name
,
tokenizer_mode
=
"fast"
)
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
response
=
requests
.
post
(
base_url
+
"/detokenize"
,
json
=
{
"model"
:
model_name
,
"tokens"
:
tokens
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
vllm/entrypoints/openai/api_server.py
View file @
7a3d2a5b
...
@@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -46,6 +48,7 @@ engine_args: AsyncEngineArgs
...
@@ -46,6 +48,7 @@ engine_args: AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
logger
=
init_logger
(
'vllm.entrypoints.openai.api_server'
)
logger
=
init_logger
(
'vllm.entrypoints.openai.api_server'
)
...
@@ -86,7 +89,7 @@ async def health() -> Response:
...
@@ -86,7 +89,7 @@ async def health() -> Response:
@
router
.
post
(
"/tokenize"
)
@
router
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
async
def
tokenize
(
request
:
TokenizeRequest
):
generator
=
await
openai_serving_
comple
tion
.
create_tokenize
(
request
)
generator
=
await
openai_serving_
tokeniza
tion
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):
...
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):
@
router
.
post
(
"/detokenize"
)
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
async
def
detokenize
(
request
:
DetokenizeRequest
):
generator
=
await
openai_serving_
comple
tion
.
create_detokenize
(
request
)
generator
=
await
openai_serving_
tokeniza
tion
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
...
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
global
openai_serving_chat
global
openai_serving_chat
global
openai_serving_completion
global
openai_serving_completion
global
openai_serving_embedding
global
openai_serving_embedding
global
openai_serving_tokenization
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
served_model_names
,
served_model_names
,
...
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
...
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
args
.
prompt_adapters
)
args
.
prompt_adapters
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
)
served_model_names
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
engine
,
model_config
,
served_model_names
,
args
.
chat_template
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
logger
.
info
(
"Available routes are:"
)
logger
.
info
(
"Available routes are:"
)
...
...
vllm/entrypoints/openai/chat_utils.py
0 → 100644
View file @
7a3d2a5b
import
codecs
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
Awaitable
,
Iterable
,
List
,
Optional
,
TypedDict
,
cast
,
final
from
openai.types.chat
import
(
ChatCompletionContentPartImageParam
,
ChatCompletionContentPartTextParam
)
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionContentPartParam
,
ChatCompletionMessageParam
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
logger
=
init_logger
(
__name__
)
@
final
# So that it should be compatible with Dict[str, str]
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
default_factory
=
list
)
def
load_chat_template
(
engine
:
OpenAIServing
,
chat_template
:
Optional
[
str
]):
tokenizer
=
engine
.
tokenizer
if
chat_template
is
not
None
:
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
tokenizer
.
chat_template
=
f
.
read
()
except
OSError
as
e
:
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer
.
chat_template
=
codecs
.
decode
(
chat_template
,
"unicode_escape"
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
tokenizer
.
chat_template
)
elif
tokenizer
.
chat_template
is
not
None
:
logger
.
info
(
"Using default chat template:
\n
%s"
,
tokenizer
.
chat_template
)
else
:
logger
.
warning
(
"No chat template provided. Chat API will not work."
)
@
lru_cache
(
maxsize
=
None
)
def
_image_token_str
(
engine
:
OpenAIServing
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
engine
.
model_config
.
hf_config
.
model_type
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
engine
.
tokenizer
.
decode
(
engine
.
model_config
.
hf_config
.
image_token_index
)
else
:
raise
TypeError
(
"Unknown model type: {model_type}"
)
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def
_get_full_image_text_prompt
(
engine
:
OpenAIServing
,
image_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return
f
"
{
image_token_str
}
\n
{
text_prompt
}
"
def
_parse_chat_message_content_parts
(
engine
:
OpenAIServing
,
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple 'image_url' input is currently not supported."
)
image_url
=
cast
(
ChatCompletionContentPartImageParam
,
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported and "
"will be ignored."
)
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
mm_futures
:
image_token_str
=
_image_token_str
(
engine
)
if
image_token_str
is
not
None
:
if
image_token_str
in
text_prompt
:
logger
.
warning
(
"Detected image token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
_get_full_image_text_prompt
(
engine
,
image_token_str
=
image_token_str
,
text_prompt
=
text_prompt
,
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
parse_chat_message_content
(
engine
:
OpenAIServing
,
message
:
ChatCompletionMessageParam
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[],
mm_futures
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
return
_parse_chat_message_content_parts
(
engine
,
role
,
content
)
vllm/entrypoints/openai/protocol.py
View file @
7a3d2a5b
...
@@ -738,15 +738,17 @@ class BatchRequestOutput(OpenAIBaseModel):
...
@@ -738,15 +738,17 @@ class BatchRequestOutput(OpenAIBaseModel):
class
TokenizeRequest
(
OpenAIBaseModel
):
class
TokenizeRequest
(
OpenAIBaseModel
):
add_generation_prompt
:
bool
=
Field
(
default
=
True
)
add_special_tokens
:
bool
=
Field
(
default
=
False
)
prompt
:
Optional
[
str
]
=
Field
(
default
=
None
)
messages
:
Optional
[
List
[
ChatCompletionMessageParam
]]
=
Field
(
default
=
None
)
model
:
str
model
:
str
prompt
:
str
add_special_tokens
:
bool
=
Field
(
default
=
True
)
class
TokenizeResponse
(
OpenAIBaseModel
):
class
TokenizeResponse
(
OpenAIBaseModel
):
tokens
:
List
[
int
]
count
:
int
count
:
int
max_model_len
:
int
max_model_len
:
int
tokens
:
List
[
int
]
class
DetokenizeRequest
(
OpenAIBaseModel
):
class
DetokenizeRequest
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
7a3d2a5b
import
codecs
import
time
import
time
from
dataclasses
import
dataclass
,
field
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
List
,
from
functools
import
cached_property
Optional
)
from
typing
import
(
AsyncGenerator
,
AsyncIterator
,
Awaitable
,
Dict
,
Iterable
,
List
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
TypedDict
,
Union
,
cast
,
final
from
typing
import
Union
from
fastapi
import
Request
from
fastapi
import
Request
from
openai.types.chat
import
(
ChatCompletionContentPartImageParam
,
ChatCompletionContentPartTextParam
)
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionContentPartParam
,
ChatCompletionLogProb
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ChatCompletionLogProbsContent
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionMessageParam
,
ChatCompletionNamedToolChoiceParam
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionResponseChoice
,
ChatCompletionResponseStreamChoice
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
ChatCompletionStreamResponse
,
ChatMessage
,
DeltaMessage
,
ErrorResponse
,
...
@@ -28,7 +25,6 @@ from vllm.logger import init_logger
...
@@ -28,7 +25,6 @@ from vllm.logger import init_logger
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
get_guided_decoding_logits_processor
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
@@ -38,19 +34,6 @@ from vllm.utils import random_uuid
...
@@ -38,19 +34,6 @@ from vllm.utils import random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
final
# So that it should be compatible with Dict[str, str]
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
@
dataclass
(
frozen
=
True
)
class
ChatMessageParseResult
:
messages
:
List
[
ConversationMessage
]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
field
(
default_factory
=
list
)
class
OpenAIServingChat
(
OpenAIServing
):
class
OpenAIServingChat
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -66,131 +49,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -66,131 +49,7 @@ class OpenAIServingChat(OpenAIServing):
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
)
self
.
response_role
=
response_role
self
.
response_role
=
response_role
self
.
_load_chat_template
(
chat_template
)
load_chat_template
(
self
,
chat_template
)
def
_load_chat_template
(
self
,
chat_template
:
Optional
[
str
]):
tokenizer
=
self
.
tokenizer
if
chat_template
is
not
None
:
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
tokenizer
.
chat_template
=
f
.
read
()
except
OSError
as
e
:
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer
.
chat_template
=
codecs
.
decode
(
chat_template
,
"unicode_escape"
)
logger
.
info
(
"Using supplied chat template:
\n
%s"
,
tokenizer
.
chat_template
)
elif
tokenizer
.
chat_template
is
not
None
:
logger
.
info
(
"Using default chat template:
\n
%s"
,
tokenizer
.
chat_template
)
else
:
logger
.
warning
(
"No chat template provided. Chat API will not work."
)
@
cached_property
def
image_token_str
(
self
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
self
.
model_config
.
hf_config
.
model_type
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
if
model_type
in
(
"blip-2"
,
"chatglm"
,
"fuyu"
,
"minicpmv"
,
"paligemma"
):
# These models do not use image tokens in the prompt
return
None
if
model_type
.
startswith
(
"llava"
):
return
self
.
tokenizer
.
decode
(
self
.
model_config
.
hf_config
.
image_token_index
)
else
:
raise
TypeError
(
"Unknown model type: {model_type}"
)
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def
_get_full_image_text_prompt
(
self
,
image_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine image and text prompts for vision language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return
f
"
{
image_token_str
}
\n
{
text_prompt
}
"
def
_parse_chat_message_content_parts
(
self
,
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple 'image_url' input is currently not supported."
)
image_url
=
cast
(
ChatCompletionContentPartImageParam
,
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported and "
"will be ignored."
)
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
mm_futures
:
image_token_str
=
self
.
image_token_str
if
image_token_str
is
not
None
:
if
image_token_str
in
text_prompt
:
logger
.
warning
(
"Detected image token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
self
.
_get_full_image_text_prompt
(
image_token_str
=
image_token_str
,
text_prompt
=
text_prompt
,
)
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
text_prompt
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
mm_futures
)
def
_parse_chat_message_content
(
self
,
message
:
ChatCompletionMessageParam
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
if
content
is
None
:
return
ChatMessageParseResult
(
messages
=
[],
mm_futures
=
[])
if
isinstance
(
content
,
str
):
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
return
self
.
_parse_chat_message_content_parts
(
role
,
content
)
async
def
create_chat_completion
(
async
def
create_chat_completion
(
self
,
self
,
...
@@ -216,7 +75,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -216,7 +75,7 @@ class OpenAIServingChat(OpenAIServing):
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
for
msg
in
request
.
messages
:
for
msg
in
request
.
messages
:
chat_parsed_result
=
self
.
_
parse_chat_message_content
(
msg
)
chat_parsed_result
=
parse_chat_message_content
(
self
,
msg
)
conversation
.
extend
(
chat_parsed_result
.
messages
)
conversation
.
extend
(
chat_parsed_result
.
messages
)
mm_futures
.
extend
(
chat_parsed_result
.
mm_futures
)
mm_futures
.
extend
(
chat_parsed_result
.
mm_futures
)
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
7a3d2a5b
...
@@ -16,10 +16,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
...
@@ -16,10 +16,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
CompletionStreamResponse
,
DetokenizeRequest
,
UsageInfo
)
DetokenizeResponse
,
TokenizeRequest
,
TokenizeResponse
,
UsageInfo
)
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
,
OpenAIServing
,
...
@@ -457,29 +454,3 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -457,29 +454,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens
=
out_tokens
,
tokens
=
out_tokens
,
top_logprobs
=
out_top_logprobs
,
top_logprobs
=
out_top_logprobs
,
)
)
async
def
create_tokenize
(
self
,
request
:
TokenizeRequest
)
->
TokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
request
.
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
)
->
DetokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
request
.
tokens
)
return
DetokenizeResponse
(
prompt
=
input_text
)
vllm/entrypoints/openai/serving_tokenization.py
0 → 100644
View file @
7a3d2a5b
from
typing
import
List
,
Optional
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
parse_chat_message_content
)
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
DetokenizeResponse
,
TokenizeRequest
,
TokenizeResponse
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
class
OpenAIServingTokenization
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
],
chat_template
:
Optional
[
str
]
=
None
):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
)
load_chat_template
(
self
,
chat_template
)
async
def
create_tokenize
(
self
,
request
:
TokenizeRequest
)
->
TokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
if
not
(
request
.
prompt
or
request
.
messages
):
return
self
.
create_error_response
(
"Either `prompt` or `messages` should be provided."
)
if
(
request
.
prompt
and
request
.
messages
):
return
self
.
create_error_response
(
"Only one of `prompt` or `messages` should be provided."
)
if
request
.
messages
:
conversation
:
List
[
ConversationMessage
]
=
[]
for
message
in
request
.
messages
:
conversation
.
extend
(
parse_chat_message_content
(
self
,
message
).
messages
)
request
.
prompt
=
self
.
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
conversation
=
conversation
,
tokenize
=
False
)
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
request
.
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
)
->
DetokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
request
.
tokens
)
return
DetokenizeResponse
(
prompt
=
input_text
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment