Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
739b61a3
Unverified
Commit
739b61a3
authored
Jul 23, 2024
by
Cyrus Leung
Committed by
GitHub
Jul 22, 2024
Browse files
[Frontend] Refactor prompt processing (#4028)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
89c1c6a1
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
86 additions
and
56 deletions
+86
-56
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+79
-27
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-4
vllm/inputs/data.py
vllm/inputs/data.py
+1
-23
vllm/sequence.py
vllm/sequence.py
+3
-2
No files found.
vllm/entrypoints/openai/serving_tokenization.py
View file @
739b61a3
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
load_chat_template
,
load_chat_template
,
parse_chat_message_content
)
parse_chat_message_content
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
from
vllm.entrypoints.openai.protocol
import
(
DetokenizeRequest
,
DetokenizeResponse
,
DetokenizeResponse
,
ErrorResponse
,
TokenizeChatRequest
,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
)
TokenizeResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.utils
import
random_uuid
class
OpenAIServingTokenization
(
OpenAIServing
):
class
OpenAIServingTokenization
(
OpenAIServing
):
def
__init__
(
self
,
def
__init__
(
engine
:
AsyncLLMEngine
,
self
,
model_config
:
ModelConfig
,
engine
:
AsyncLLMEngine
,
served_model_names
:
List
[
str
],
model_config
:
ModelConfig
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]]
=
None
,
served_model_names
:
List
[
str
],
chat_template
:
Optional
[
str
]
=
None
):
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
):
super
().
__init__
(
engine
=
engine
,
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
served_model_names
=
served_model_names
,
lora_modules
=
lora_modules
)
lora_modules
=
lora_modules
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
# If this is None we use the tokenizer's default chat template
# If this is None we use the tokenizer's default chat template
self
.
chat_template
=
load_chat_template
(
chat_template
)
self
.
chat_template
=
load_chat_template
(
chat_template
)
async
def
create_tokenize
(
self
,
async
def
create_tokenize
(
request
:
TokenizeRequest
)
->
TokenizeResponse
:
self
,
request
:
TokenizeRequest
,
)
->
Union
[
TokenizeResponse
,
ErrorResponse
]:
error_check_ret
=
await
self
.
_check_model
(
request
)
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
if
error_check_ret
is
not
None
:
return
error_check_ret
return
error_check_ret
if
not
(
request
.
prompt
or
request
.
messages
):
request_id
=
f
"tokn-
{
random_uuid
()
}
"
return
self
.
create_error_response
(
"Either `prompt` or `messages` should be provided."
)
if
(
request
.
prompt
and
request
.
messages
):
(
return
self
.
create_error_response
(
lora_request
,
"Only one of `prompt` or `messages` should be provided."
)
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
_
,
lora_request
=
self
.
_maybe_get_adapter
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
if
request
.
messages
:
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
:
List
[
ConversationMessage
]
=
[]
for
message
in
request
.
messages
:
for
message
in
request
.
messages
:
result
=
parse_chat_message_content
(
message
,
self
.
model_config
,
result
=
parse_chat_message_content
(
message
,
model_config
,
tokenizer
)
tokenizer
)
conversation
.
extend
(
result
.
messages
)
conversation
.
extend
(
result
.
messages
)
request
.
prompt
=
tokenizer
.
apply_chat_template
(
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
conversation
=
conversation
,
conversation
=
conversation
,
tokenize
=
False
,
tokenize
=
False
,
chat_template
=
self
.
chat_template
)
chat_template
=
self
.
chat_template
)
assert
isinstance
(
prompt
,
str
)
else
:
prompt
=
request
.
prompt
self
.
_log_inputs
(
request_id
,
prompt
,
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
(
input_ids
,
input_text
)
=
await
self
.
_validate_prompt_and_tokenize
(
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input
=
self
.
_tokenize_prompt_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt
=
request
.
prompt
,
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
add_special_tokens
=
request
.
add_special_tokens
,
)
input_ids
=
prompt_input
[
"prompt_token_ids"
]
return
TokenizeResponse
(
tokens
=
input_ids
,
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
)
->
DetokenizeResponse
:
self
,
request
:
DetokenizeRequest
,
)
->
Union
[
DetokenizeResponse
,
ErrorResponse
]:
error_check_ret
=
await
self
.
_check_model
(
request
)
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
if
error_check_ret
is
not
None
:
return
error_check_ret
return
error_check_ret
_
,
lora_request
=
self
.
_maybe_get_adapter
(
request
)
request_id
=
f
"tokn-
{
random_uuid
()
}
"
(
lora_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine
.
get_tokenizer
(
lora_request
)
(
input_ids
,
input_text
)
=
await
self
.
_validate_prompt_and_tokenize
(
request
,
tokenizer
,
prompt_ids
=
request
.
tokens
)
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
params
=
None
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
if
prompt_adapter_request
is
not
None
:
raise
NotImplementedError
(
"Prompt adapter is not supported "
"for tokenization"
)
prompt_input
=
self
.
_tokenize_prompt_input
(
request
,
tokenizer
,
request
.
tokens
,
)
input_text
=
prompt_input
[
"prompt"
]
return
DetokenizeResponse
(
prompt
=
input_text
)
return
DetokenizeResponse
(
prompt
=
input_text
)
vllm/inputs/__init__.py
View file @
739b61a3
from
.data
import
(
LLMInputs
,
ParsedText
,
ParsedTokens
,
PromptInputs
,
from
.data
import
(
LLMInputs
,
ParsedText
,
ParsedTokens
,
PromptInputs
,
PromptStrictInputs
,
TextPrompt
,
TextTokensPrompt
,
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
TokensPrompt
,
parse_and_batch_prompt
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
...
@@ -14,6 +13,6 @@ See also:
...
@@ -14,6 +13,6 @@ See also:
__all__
=
[
__all__
=
[
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"
TextTokensPrompt"
,
"PromptStrictInputs"
,
"PromptInputs
"
,
"TokensPrompt"
,
"
PromptInputs"
,
"LLMInputs"
,
"INPUT_REGISTRY
"
,
"LLMInputs"
,
"INPUT_REGISTRY"
,
"InputContext"
,
"InputRegistry"
"InputContext"
,
"InputRegistry"
]
]
vllm/inputs/data.py
View file @
739b61a3
...
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
...
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
"""
"""
class
TextTokensPrompt
(
TypedDict
):
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt
:
str
"""The prompt text."""
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt."""
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
The inputs to the LLM, which can take one of the following forms:
The inputs to the LLM, which can take one of the following forms:
...
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
...
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
"""
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
TextTokensPrompt
]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class
LLMInputs
(
TypedDict
):
class
LLMInputs
(
TypedDict
):
"""
"""
...
...
vllm/sequence.py
View file @
739b61a3
...
@@ -5,7 +5,8 @@ import math
...
@@ -5,7 +5,8 @@ import math
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
)
import
torch
import
torch
...
@@ -438,7 +439,7 @@ class SequenceGroup:
...
@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings
:
Optional
[
List
[
float
]]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
trace_headers
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment