Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cb234955
Unverified
Commit
cb234955
authored
May 02, 2025
by
Cyrus Leung
Committed by
GitHub
May 02, 2025
Browse files
[Misc] Clean up input processing (#17582)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3a500cd0
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
357 additions
and
283 deletions
+357
-283
tests/models/multimodal/pooling/test_intern_vit.py
tests/models/multimodal/pooling/test_intern_vit.py
+8
-6
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+0
-4
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-27
vllm/engine/protocol.py
vllm/engine/protocol.py
+3
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-2
vllm/inputs/data.py
vllm/inputs/data.py
+18
-5
vllm/inputs/parse.py
vllm/inputs/parse.py
+8
-19
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+302
-215
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+7
-5
No files found.
tests/models/multimodal/pooling/test_intern_vit.py
View file @
cb234955
...
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
...
@@ -6,6 +6,7 @@ from huggingface_hub import snapshot_download
from
transformers
import
AutoConfig
,
AutoModel
,
CLIPImageProcessor
from
transformers
import
AutoConfig
,
AutoModel
,
CLIPImageProcessor
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
....conftest
import
ImageTestAssets
from
....conftest
import
ImageTestAssets
...
@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
...
@@ -14,6 +15,7 @@ from ....conftest import ImageTestAssets
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
DOWNLOAD_PATTERN
=
[
"*.json"
,
"*.py"
,
"*.safetensors"
,
"*.txt"
,
"*.model"
]
@
torch
.
inference_mode
()
def
run_intern_vit_test
(
def
run_intern_vit_test
(
image_assets
:
ImageTestAssets
,
image_assets
:
ImageTestAssets
,
model_id
:
str
,
model_id
:
str
,
...
@@ -21,11 +23,12 @@ def run_intern_vit_test(
...
@@ -21,11 +23,12 @@ def run_intern_vit_test(
dtype
:
str
,
dtype
:
str
,
):
):
model
=
snapshot_download
(
model_id
,
allow_patterns
=
DOWNLOAD_PATTERN
)
model
=
snapshot_download
(
model_id
,
allow_patterns
=
DOWNLOAD_PATTERN
)
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
img_processor
=
CLIPImageProcessor
.
from_pretrained
(
model
)
img_processor
=
CLIPImageProcessor
.
from_pretrained
(
model
)
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
pixel_values
=
[
pixel_values
=
[
img_processor
(
images
,
return_tensors
=
'pt'
).
pixel_values
.
to
(
dtype
)
img_processor
(
images
,
return_tensors
=
'pt'
).
pixel_values
.
to
(
torch_
dtype
)
for
images
in
images
for
images
in
images
]
]
...
@@ -34,7 +37,7 @@ def run_intern_vit_test(
...
@@ -34,7 +37,7 @@ def run_intern_vit_test(
config
.
norm_type
=
"rms_norm"
config
.
norm_type
=
"rms_norm"
hf_model
=
AutoModel
.
from_pretrained
(
model
,
hf_model
=
AutoModel
.
from_pretrained
(
model
,
torch_dtype
=
dtype
,
torch_dtype
=
torch_
dtype
,
trust_remote_code
=
True
).
to
(
"cuda"
)
trust_remote_code
=
True
).
to
(
"cuda"
)
hf_outputs_per_image
=
[
hf_outputs_per_image
=
[
hf_model
(
pixel_value
.
to
(
"cuda"
)).
last_hidden_state
hf_model
(
pixel_value
.
to
(
"cuda"
)).
last_hidden_state
...
@@ -48,7 +51,7 @@ def run_intern_vit_test(
...
@@ -48,7 +51,7 @@ def run_intern_vit_test(
del
hf_model
del
hf_model
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
vllm_model
=
vllm_model
.
to
(
"cuda"
,
dtype
)
vllm_model
=
vllm_model
.
to
(
"cuda"
,
torch_
dtype
)
vllm_outputs_per_image
=
[
vllm_outputs_per_image
=
[
vllm_model
(
pixel_values
=
pixel_value
.
to
(
"cuda"
))
vllm_model
(
pixel_values
=
pixel_value
.
to
(
"cuda"
))
for
pixel_value
in
pixel_values
for
pixel_value
in
pixel_values
...
@@ -66,9 +69,8 @@ def run_intern_vit_test(
...
@@ -66,9 +69,8 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px"
,
"OpenGVLab/InternViT-300M-448px"
,
"OpenGVLab/InternViT-6B-448px-V1-5"
,
"OpenGVLab/InternViT-6B-448px-V1-5"
,
])
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
torch
.
inference_mode
()
def
test_models
(
dist_init
,
image_assets
,
model_id
,
dtype
:
str
)
->
None
:
def
test_models
(
image_assets
,
model_id
,
dtype
:
str
)
->
None
:
run_intern_vit_test
(
run_intern_vit_test
(
image_assets
,
image_assets
,
model_id
,
model_id
,
...
...
vllm/engine/async_llm_engine.py
View file @
cb234955
...
@@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -497,10 +497,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt
[
"prompt_token_ids"
]
=
[
0
prompt
[
"prompt_token_ids"
]
=
[
0
]
*
prompt
[
"prompt_embeds"
].
shape
[
-
2
]
]
*
prompt
[
"prompt_embeds"
].
shape
[
-
2
]
if
self
.
tokenizer
is
not
None
:
tokenizer
=
await
self
.
get_tokenizer_async
(
lora_request
)
self
.
_validate_token_prompt
(
prompt
,
tokenizer
=
tokenizer
)
processed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
processed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
prompt
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
...
vllm/engine/llm_engine.py
View file @
cb234955
...
@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
...
@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors
as
get_openai_logits_processors
)
get_logits_processors
as
get_openai_logits_processors
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.inputs
import
ProcessorInputs
,
PromptType
,
SingletonInputs
from
vllm.inputs
import
ProcessorInputs
,
PromptType
,
SingletonInputs
from
vllm.inputs.parse
import
is_token_prompt
,
split_enc_dec_inputs
from
vllm.inputs.parse
import
split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logits_process
import
get_bad_words_logits_processors
from
vllm.logits_process
import
get_bad_words_logits_processors
...
@@ -759,11 +759,6 @@ class LLMEngine:
...
@@ -759,11 +759,6 @@ class LLMEngine:
seq_len
=
prompt
[
"prompt_embeds"
].
shape
[
0
]
seq_len
=
prompt
[
"prompt_embeds"
].
shape
[
0
]
prompt
[
"prompt_token_ids"
]
=
[
0
]
*
seq_len
prompt
[
"prompt_token_ids"
]
=
[
0
]
*
seq_len
if
self
.
tokenizer
is
not
None
:
self
.
_validate_token_prompt
(
prompt
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
))
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
...
@@ -782,27 +777,6 @@ class LLMEngine:
...
@@ -782,27 +777,6 @@ class LLMEngine:
priority
=
priority
,
priority
=
priority
,
)
)
def
_validate_token_prompt
(
self
,
prompt
:
PromptType
,
tokenizer
:
AnyTokenizer
):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if
is_token_prompt
(
prompt
):
prompt_ids
=
prompt
[
"prompt_token_ids"
]
if
len
(
prompt_ids
)
==
0
:
# Empty prompt check is handled later
return
max_input_id
=
max
(
prompt_ids
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
"Token id {} is out of vocabulary"
.
format
(
max_input_id
))
def
_create_sequence_group_with_sampling
(
def
_create_sequence_group_with_sampling
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -2049,6 +2023,12 @@ class LLMEngine:
...
@@ -2049,6 +2023,12 @@ class LLMEngine:
else
:
else
:
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
raise
ValueError
(
f
"The
{
prompt_type
}
prompt cannot be empty"
)
if
tokenizer
is
not
None
:
max_input_id
=
max
(
prompt_ids
,
default
=
0
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
f
"Token id
{
max_input_id
}
is out of vocabulary"
)
max_prompt_len
=
self
.
model_config
.
max_model_len
max_prompt_len
=
self
.
model_config
.
max_model_len
if
len
(
prompt_ids
)
>
max_prompt_len
:
if
len
(
prompt_ids
)
>
max_prompt_len
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
...
...
vllm/engine/protocol.py
View file @
cb234955
...
@@ -83,6 +83,9 @@ class EngineClient(ABC):
...
@@ -83,6 +83,9 @@ class EngineClient(ABC):
else
:
else
:
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
if
processed_inputs
[
"type"
]
==
"embeds"
:
raise
NotImplementedError
prompt_token_ids
=
processed_inputs
[
"prompt_token_ids"
]
prompt_token_ids
=
processed_inputs
[
"prompt_token_ids"
]
prompt_text
=
processed_inputs
.
get
(
"prompt"
)
prompt_text
=
processed_inputs
.
get
(
"prompt"
)
multi_modal_data
=
processed_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
processed_inputs
.
get
(
"multi_modal_data"
)
...
...
vllm/entrypoints/llm.py
View file @
cb234955
...
@@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
...
@@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens
)
_validate_score_input_lens
)
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
PromptType
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
is_token_prompt
,
parse_and_batch_prompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding.guided_fields
import
(
from
vllm.model_executor.guided_decoding.guided_fields
import
(
...
@@ -567,10 +567,12 @@ class LLM:
...
@@ -567,10 +567,12 @@ class LLM:
mm_kwargs
[
"mm_processor_kwargs"
]
=
prompt
[
mm_kwargs
[
"mm_processor_kwargs"
]
=
prompt
[
"mm_processor_kwargs"
]
"mm_processor_kwargs"
]
if
is_token_prompt
(
prompt
):
if
"prompt_token_ids"
in
prompt
:
prompt
=
cast
(
TokensPrompt
,
prompt
)
# Needed for mypy
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
,
logprobs
=
None
,
**
mm_kwargs
))
BeamSearchInstance
(
prompt_tokens
,
logprobs
=
None
,
**
mm_kwargs
))
...
...
vllm/inputs/data.py
View file @
cb234955
...
@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
...
@@ -70,6 +70,11 @@ class EmbedsPrompt(TypedDict):
prompt_embeds
:
torch
.
Tensor
prompt_embeds
:
torch
.
Tensor
"""The embeddings of the prompt."""
"""The embeddings of the prompt."""
cache_salt
:
NotRequired
[
str
]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
EmbedsPrompt
]
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
EmbedsPrompt
]
"""
"""
...
@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
...
@@ -195,13 +200,21 @@ class EmbedsInputs(TypedDict):
prompt_embeds
:
torch
.
Tensor
prompt_embeds
:
torch
.
Tensor
"""The embeddings of the prompt."""
"""The embeddings of the prompt."""
cache_salt
:
NotRequired
[
str
]
"""
Optional cache salt to be used for prefix caching.
"""
def
embeds_inputs
(
prompt_embeds
:
torch
.
Tensor
)
->
EmbedsInputs
:
def
embeds_inputs
(
prompt_embeds
:
torch
.
Tensor
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
EmbedsInputs
:
"""Construct :class:`EmbedsInputs` from optional values."""
"""Construct :class:`EmbedsInputs` from optional values."""
inputs
=
EmbedsInputs
(
inputs
=
EmbedsInputs
(
type
=
"embeds"
,
prompt_embeds
=
prompt_embeds
)
type
=
"embeds"
,
prompt_embeds
=
prompt_embeds
,
if
cache_salt
is
not
None
:
)
inputs
[
"cache_salt"
]
=
cache_salt
return
inputs
return
inputs
...
...
vllm/inputs/parse.py
View file @
cb234955
...
@@ -6,9 +6,9 @@ from typing_extensions import TypeIs
...
@@ -6,9 +6,9 @@ from typing_extensions import TypeIs
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
.data
import
(
EmbedsInputs
,
EmbedsPrompt
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
EmbedsPrompt
,
ExplicitEncoderDecoderPrompt
,
ProcessorInputs
,
ProcessorInputs
,
PromptType
,
SingletonInputs
,
PromptType
,
SingletonInputs
,
SingletonPrompt
,
TextPrompt
,
SingletonPrompt
,
TextPrompt
,
TokensPrompt
)
TokensPrompt
)
class
ParsedText
(
TypedDict
):
class
ParsedText
(
TypedDict
):
...
@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
...
@@ -90,6 +90,10 @@ class ParsedEmbedsPrompt(TypedDict):
content
:
EmbedsPrompt
content
:
EmbedsPrompt
ParsedSingletonPrompt
=
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
,
ParsedEmbedsPrompt
]
@
overload
@
overload
def
parse_singleton_prompt
(
prompt
:
str
)
->
ParsedStrPrompt
:
def
parse_singleton_prompt
(
prompt
:
str
)
->
ParsedStrPrompt
:
...
...
...
@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
@@ -110,10 +114,7 @@ def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
...
def
parse_singleton_prompt
(
def
parse_singleton_prompt
(
prompt
:
SingletonPrompt
)
->
ParsedSingletonPrompt
:
prompt
:
SingletonPrompt
,
)
->
Union
[
ParsedStrPrompt
,
ParsedTextPrompt
,
ParsedTokensPrompt
,
ParsedEmbedsPrompt
]:
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
return
ParsedStrPrompt
(
type
=
"str"
,
content
=
prompt
)
elif
isinstance
(
prompt
,
dict
):
elif
isinstance
(
prompt
,
dict
):
...
@@ -131,23 +132,11 @@ def parse_singleton_prompt(
...
@@ -131,23 +132,11 @@ def parse_singleton_prompt(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
)
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
)
def
is_token_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
TokensPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"prompt_token_ids"
in
prompt
def
is_embeds_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
EmbedsPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"prompt_embeds"
in
prompt
def
is_explicit_encoder_decoder_prompt
(
def
is_explicit_encoder_decoder_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
prompt
:
PromptType
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
def
is_embeds_inputs
(
inputs
:
SingletonInputs
)
->
TypeIs
[
EmbedsInputs
]:
return
isinstance
(
inputs
,
dict
)
and
inputs
[
"type"
]
==
"embeds"
def
split_enc_dec_inputs
(
def
split_enc_dec_inputs
(
inputs
:
ProcessorInputs
,
inputs
:
ProcessorInputs
,
)
->
tuple
[
Optional
[
SingletonInputs
],
SingletonInputs
]:
)
->
tuple
[
Optional
[
SingletonInputs
],
SingletonInputs
]:
...
...
vllm/inputs/preprocess.py
View file @
cb234955
...
@@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...
@@ -14,14 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
.data
import
(
DecoderOnlyInputs
,
EmbedsInputs
,
EncoderDecoderInputs
,
from
.data
import
(
DecoderOnlyInputs
,
EmbedsInputs
,
EmbedsPrompt
,
ProcessorInputs
,
PromptType
,
SingletonInputs
,
EncoderDecoderInputs
,
ProcessorInputs
,
PromptType
,
SingletonPrompt
,
TokenInputs
,
embeds_inputs
,
token_inputs
)
SingletonInputs
,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
from
.parse
import
(
ParsedEmbedsPrompt
,
ParsedStrPrompt
,
ParsedTextPrompt
,
TokensPrompt
,
embeds_inputs
,
token_inputs
)
ParsedTokensPrompt
,
is_embeds_inputs
,
from
.parse
import
is_explicit_encoder_decoder_prompt
,
parse_singleton_prompt
is_explicit_encoder_decoder_prompt
,
parse_singleton_prompt
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -140,13 +140,10 @@ class InputPreprocessor:
...
@@ -140,13 +140,10 @@ class InputPreprocessor:
"""
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
Based on:
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
https://github.com/huggingface/transformers/blob/
specifically,
4037a2b5b1278736e566aec12e169100275545ea/
`GenerationMixin._prepare_decoder_input_ids_for_generation()`.
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments:
Arguments:
...
@@ -183,6 +180,23 @@ class InputPreprocessor:
...
@@ -183,6 +180,23 @@ class InputPreprocessor:
return
prompt_token_ids
return
prompt_token_ids
def
_get_tokenization_kw
(
self
,
overrides
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
dict
[
str
,
Any
]:
kwargs
=
dict
[
str
,
Any
]()
if
self
.
model_config
.
hf_config
.
model_type
==
"whisper"
:
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs
[
"add_special_tokens"
]
=
False
if
overrides
:
kwargs
.
update
(
overrides
)
return
kwargs
def
_tokenize_prompt
(
def
_tokenize_prompt
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
...
@@ -194,18 +208,11 @@ class InputPreprocessor:
...
@@ -194,18 +208,11 @@ class InputPreprocessor:
corresponding token IDs.
corresponding token IDs.
"""
"""
tokenizer
=
self
.
get_tokenizer_group
()
tokenizer
=
self
.
get_tokenizer_group
()
if
tokenization_kwargs
is
None
:
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
tokenization_kwargs
=
{}
if
self
.
model_config
.
hf_config
.
model_type
==
"whisper"
:
encoder_config
=
self
.
model_config
.
encoder_config
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs
[
"add_special_tokens"
]
=
False
if
(
self
.
model_config
.
encoder_config
is
not
None
if
encoder_config
and
encoder_config
.
get
(
"do_lower_case"
,
False
):
and
self
.
model_config
.
encoder_config
.
get
(
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
prompt
=
prompt
.
lower
()
return
tokenizer
.
encode
(
prompt
=
prompt
,
return
tokenizer
.
encode
(
prompt
=
prompt
,
...
@@ -220,18 +227,36 @@ class InputPreprocessor:
...
@@ -220,18 +227,36 @@ class InputPreprocessor:
)
->
list
[
int
]:
)
->
list
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer
=
self
.
get_tokenizer_group
()
tokenizer
=
self
.
get_tokenizer_group
()
if
tokenization_kwargs
is
None
:
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
tokenization_kwargs
=
{}
if
self
.
model_config
.
hf_config
.
model_type
==
"whisper"
:
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs
[
"add_special_tokens"
]
=
False
return
await
tokenizer
.
encode_async
(
prompt
=
prompt
,
return
await
tokenizer
.
encode_async
(
prompt
=
prompt
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
**
tokenization_kwargs
)
**
tokenization_kwargs
)
def
_get_mm_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
],
)
->
AnyTokenizer
:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if
not
self
.
tokenizer
:
return
cast
(
AnyTokenizer
,
object
())
# Dummy
tokenizer_group
=
self
.
get_tokenizer_group
()
return
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
async
def
_get_mm_tokenizer_async
(
self
,
lora_request
:
Optional
[
LoRARequest
],
)
->
AnyTokenizer
:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if
not
self
.
tokenizer
:
return
cast
(
AnyTokenizer
,
object
())
# Dummy
tokenizer_group
=
self
.
get_tokenizer_group
()
return
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
def
_process_multimodal
(
def
_process_multimodal
(
self
,
self
,
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
...
@@ -244,13 +269,7 @@ class InputPreprocessor:
...
@@ -244,13 +269,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
returning the corresponding token IDs and metadata.
"""
"""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer
=
self
.
_get_mm_tokenizer
(
lora_request
)
# initialized without a tokenizer while using also multi-modal input
if
not
self
.
tokenizer
:
tokenizer
=
object
()
# Dummy
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
tokenizer
=
tokenizer
)
...
@@ -270,14 +289,7 @@ class InputPreprocessor:
...
@@ -270,14 +289,7 @@ class InputPreprocessor:
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
"""Async version of :meth:`_process_multimodal`."""
"""Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
# initialized without a tokenizer while using also multi-modal input
if
not
self
.
tokenizer
:
tokenizer
=
object
()
# Dummy
else
:
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
tokenizer
=
tokenizer
)
...
@@ -287,28 +299,160 @@ class InputPreprocessor:
...
@@ -287,28 +299,160 @@ class InputPreprocessor:
return
mm_processor
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
,
return
mm_processor
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
,
return_mm_hashes
)
return_mm_hashes
)
def
_get_prompt_data
(
self
,
parsed_prompt
:
Union
[
ParsedStrPrompt
,
def
_process_embeds
(
ParsedTextPrompt
,
self
,
ParsedTokensPrompt
]):
parsed_content
:
EmbedsPrompt
,
prompt_text
=
None
)
->
EmbedsInputs
:
prompt_token_ids
=
None
if
envs
.
VLLM_USE_V1
:
token_type_ids
=
None
raise
ValueError
(
"prompt_embeds is only available in V0."
)
cache_salt
=
None
prompt_embeds
=
parsed_content
[
"prompt_embeds"
]
if
parsed_prompt
[
"type"
]
==
"str"
:
# prompt_embeds must be (seq_len, hidden_size), but if the user
prompt_text
=
parsed_prompt
[
"content"
]
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if
prompt_embeds
.
ndim
==
3
:
prompt_embeds
=
prompt_embeds
.
squeeze
(
dim
=
0
)
if
prompt_embeds
.
ndim
!=
2
:
raise
ValueError
(
"prompt_embeds must be of shape (seq_len, hidden_size)."
)
return
embeds_inputs
(
prompt_embeds
=
prompt_embeds
,
cache_salt
=
parsed_content
.
get
(
"cache_salt"
))
async
def
_process_embeds_async
(
self
,
parsed_content
:
EmbedsPrompt
,
)
->
EmbedsInputs
:
return
self
.
_process_embeds
(
parsed_content
)
def
_process_tokens
(
self
,
parsed_content
:
TokensPrompt
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_token_ids
=
parsed_content
[
"prompt_token_ids"
]
token_type_ids
=
parsed_content
.
get
(
"token_type_ids"
)
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
inputs
=
self
.
_process_multimodal
(
prompt_token_ids
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
else
:
else
:
cache_salt
=
parsed_prompt
[
"content"
].
get
(
"cache_salt"
)
inputs
=
token_inputs
(
if
parsed_prompt
[
"type"
]
==
"text"
:
prompt_token_ids
=
prompt_token_ids
,
prompt_text
=
parsed_prompt
[
"content"
][
"prompt"
]
token_type_ids
=
token_type_ids
,
elif
parsed_prompt
[
"type"
]
==
"tokens"
:
)
prompt_token_ids
=
parsed_prompt
[
"content"
].
get
(
"prompt_token_ids"
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
token_type_ids
=
parsed_prompt
[
"content"
].
get
(
"token_type_ids"
)
inputs
[
"cache_salt"
]
=
cache_salt
else
:
assert_never
(
parsed_prompt
)
return
inputs
async
def
_process_tokens_async
(
self
,
parsed_content
:
TokensPrompt
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_token_ids
=
parsed_content
[
"prompt_token_ids"
]
token_type_ids
=
parsed_content
.
get
(
"token_type_ids"
)
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
inputs
=
await
self
.
_process_multimodal_async
(
prompt_token_ids
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
else
:
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
inputs
[
"cache_salt"
]
=
cache_salt
return
inputs
def
_process_text
(
self
,
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_text
=
parsed_content
[
"prompt"
]
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
inputs
=
self
.
_process_multimodal
(
prompt_text
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
else
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
inputs
=
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
inputs
[
"cache_salt"
]
=
cache_salt
return
inputs
return
prompt_text
,
prompt_token_ids
,
token_type_ids
,
cache_salt
async
def
_process_text_async
(
self
,
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_text
=
parsed_content
[
"prompt"
]
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
inputs
=
await
self
.
_process_multimodal_async
(
prompt_text
,
multi_modal_data
,
parsed_content
.
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
else
:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
lora_request
=
lora_request
,
tokenization_kwargs
=
tokenization_kwargs
,
)
inputs
=
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
inputs
[
"cache_salt"
]
=
cache_salt
return
inputs
def
_prompt_to_llm_inputs
(
def
_prompt_to_llm_inputs
(
self
,
self
,
...
@@ -333,39 +477,28 @@ class InputPreprocessor:
...
@@ -333,39 +477,28 @@ class InputPreprocessor:
parsed
=
parse_singleton_prompt
(
prompt
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"embeds"
:
if
parsed
[
"type"
]
==
"embeds"
:
return
self
.
_process_prompt_embeds
(
parsed
)
return
self
.
_process_embeds
(
parsed
[
"content"
])
if
parsed
[
"type"
]
==
"tokens"
:
prompt_text
,
prompt_token_ids
,
token_type_ids
,
cache_salt
=
\
return
self
.
_process_tokens
(
self
.
_get_prompt_data
(
parsed
)
parsed
[
"content"
],
# If multimodal data is present, process and return immediately
if
parsed
[
"type"
]
!=
"str"
and
parsed
[
"content"
].
get
(
"multi_modal_data"
)
is
not
None
:
inputs
=
self
.
_process_multimodal
(
prompt_text
if
prompt_text
is
not
None
else
prompt_token_ids
,
parsed
[
"content"
][
"multi_modal_data"
],
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
if
cache_salt
is
not
None
:
if
parsed
[
"type"
]
==
"text"
:
inputs
[
"cache_salt"
]
=
cache_salt
return
self
.
_process_text
(
return
inputs
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
if
prompt_token_ids
is
None
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
if
parsed
[
"type"
]
==
"str"
:
return
self
.
_process_text
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
cache_salt
=
cache_salt
,
)
assert_never
(
parsed
)
assert_never
(
parsed
)
async
def
_prompt_to_llm_inputs_async
(
async
def
_prompt_to_llm_inputs_async
(
...
@@ -375,79 +508,49 @@ class InputPreprocessor:
...
@@ -375,79 +508,49 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
)
->
SingletonInputs
:
"""Async version of :meth:`_
extract_prompt_componen
ts`."""
"""Async version of :meth:`_
prompt_to_llm_inpu
ts`."""
parsed
=
parse_singleton_prompt
(
prompt
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"embeds"
:
if
parsed
[
"type"
]
==
"embeds"
:
return
self
.
_process_prompt_embeds
(
parsed
)
return
await
self
.
_process_embeds_async
(
parsed
[
"content"
])
if
parsed
[
"type"
]
==
"tokens"
:
prompt_text
,
prompt_token_ids
,
token_type_ids
,
cache_salt
=
\
return
await
self
.
_process_tokens_async
(
self
.
_get_prompt_data
(
parsed
)
parsed
[
"content"
],
if
parsed
[
"type"
]
!=
"str"
and
parsed
[
"content"
].
get
(
"multi_modal_data"
)
is
not
None
:
inputs
=
await
self
.
_process_multimodal_async
(
prompt_token_ids
if
prompt_text
is
None
else
prompt_text
,
parsed
[
"content"
][
"multi_modal_data"
],
parsed
[
"content"
].
get
(
"mm_processor_kwargs"
),
lora_request
=
lora_request
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
if
cache_salt
is
not
None
:
if
parsed
[
"type"
]
==
"text"
:
inputs
[
"cache_salt"
]
=
cache_salt
return
await
self
.
_process_text_async
(
return
inputs
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
if
prompt_token_ids
is
None
:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
if
parsed
[
"type"
]
==
"str"
:
return
await
self
.
_process_text_async
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
cache_salt
=
cache_salt
,
)
def
_process_prompt_embeds
(
self
,
parsed
:
ParsedEmbedsPrompt
)
->
EmbedsInputs
:
if
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"prompt_embeds is only available in V0."
)
prompt_embeds_content
=
parsed
[
"content"
]
prompt_embeds
=
prompt_embeds_content
[
"prompt_embeds"
]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if
prompt_embeds
.
ndim
==
3
and
prompt_embeds
.
shape
[
0
]
==
1
:
prompt_embeds
=
prompt_embeds
.
squeeze
(
dim
=
0
)
if
prompt_embeds
.
ndim
!=
2
:
raise
ValueError
(
"prompt_embeds must be of shape (seq_len, hidden_size)."
)
return
embeds_inputs
(
prompt_embeds
=
prompt_embeds
)
assert_never
(
parsed
)
assert_never
(
parsed
)
def
_build_enc_dec_llm_inputs
(
def
_build_enc_dec_llm_inputs
(
self
,
self
,
encoder_inputs
:
Union
[
TokenInputs
,
MultiModal
Inputs
]
,
encoder_inputs
:
Singleton
Inputs
,
decoder_inputs
:
Optional
[
Union
[
TokenInputs
,
MultiModal
Inputs
]
]
,
decoder_inputs
:
Optional
[
Singleton
Inputs
],
)
->
EncoderDecoderInputs
:
)
->
EncoderDecoderInputs
:
if
(
encoder_inputs
[
"type"
]
==
"token"
if
(
encoder_inputs
[
"type"
]
==
"embeds"
or
encoder_inputs
[
"type"
]
==
"multimodal"
):
or
decoder_inputs
and
decoder_inputs
[
"type"
]
==
"embeds"
):
pass
raise
ValueError
(
"Embedding inputs are not supported for encoder-"
else
:
"decoder models"
)
assert_never
(
encoder_inputs
)
# type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible
# Needed for mypy
assert
"prompt_token_ids"
in
encoder_inputs
encoder_inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
],
encoder_inputs
)
decoder_inputs
=
cast
(
Optional
[
Union
[
TokenInputs
,
MultiModalInputs
]],
decoder_inputs
)
if
decoder_inputs
is
None
:
if
decoder_inputs
is
None
:
if
self
.
model_config
.
hf_config
.
model_type
==
"whisper"
:
if
self
.
model_config
.
hf_config
.
model_type
==
"whisper"
:
...
@@ -460,74 +563,78 @@ class InputPreprocessor:
...
@@ -460,74 +563,78 @@ class InputPreprocessor:
dec_token_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
dec_token_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
None
)
None
)
decoder_inputs
=
token_inputs
(
dec_token_ids
)
decoder_inputs
=
token_inputs
(
dec_token_ids
)
elif
(
decoder_inputs
[
"type"
]
==
"token"
else
:
or
decoder_inputs
[
"type"
]
==
"multimodal"
):
dec_token_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_inputs
[
"prompt_token_ids"
])
decoder_inputs
[
"prompt_token_ids"
]
=
dec_token_ids
if
"multi_modal_data"
in
decoder_inputs
:
if
"multi_modal_data"
in
decoder_inputs
:
raise
ValueError
(
"Multi-modal decoder inputs of encoder-"
raise
ValueError
(
"Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet"
)
"decoder models are not supported yet"
)
else
:
assert_never
(
encoder_inputs
)
# type: ignore[arg-type]
dec_token_ids
=
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_inputs
[
"prompt_token_ids"
])
decoder_inputs
[
"prompt_token_ids"
]
=
dec_token_ids
return
EncoderDecoderInputs
(
return
EncoderDecoderInputs
(
encoder
=
encoder_inputs
,
encoder
=
encoder_inputs
,
decoder
=
decoder_inputs
,
decoder
=
decoder_inputs
,
)
)
def
_s
eparate
_enc_dec_
inputs_from_mm_processor_out
puts
(
def
_s
plit
_enc_dec_
mm_in
puts
(
self
,
self
,
inputs
:
SingletonInputs
,
inputs
:
Union
[
SingletonInputs
,
MultiModalEncDecInputs
],
decoder_inputs_to_override
:
Optional
[
Union
[
TokenInputs
,
decoder_inputs_to_override
:
Optional
[
SingletonInputs
]
=
None
,
MultiModalInputs
]]
=
None
,
)
->
tuple
[
SingletonInputs
,
SingletonInputs
]:
)
->
tuple
[
SingletonInputs
,
SingletonInputs
]:
"""
"""
For encoder/decoder models only:
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
"""
if
(
inputs
[
"type"
]
==
"embeds"
or
decoder_inputs_to_override
and
decoder_inputs_to_override
[
"type"
]
==
"embeds"
):
raise
ValueError
(
"Embedding inputs are not supported for encoder-"
"decoder models"
)
# Needed for mypy
inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
,
MultiModalEncDecInputs
],
inputs
,
)
decoder_inputs_to_override
=
cast
(
Optional
[
Union
[
TokenInputs
,
MultiModalInputs
]],
decoder_inputs_to_override
,
)
encoder_inputs
:
SingletonInputs
encoder_inputs
:
SingletonInputs
decoder_inputs
:
SingletonInputs
decoder_inputs
:
SingletonInputs
if
inputs
[
"type"
]
==
"multimodal"
:
# Multimodal data inputs
if
inputs
[
"type"
]
==
"multimodal"
:
# Multimodal data inputs
assert
(
"encoder_prompt"
in
inputs
if
not
(
"encoder_prompt"
in
inputs
and
"encoder_prompt_token_ids"
in
inputs
)
and
"encoder_prompt_token_ids"
in
inputs
):
raise
RuntimeError
(
"You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models."
)
inputs
=
cast
(
MultiModalEncDecInputs
,
inputs
)
inputs
=
cast
(
MultiModalEncDecInputs
,
inputs
)
encoder_inputs
=
token_inputs
(
encoder_inputs
=
token_inputs
(
prompt
=
inputs
[
"encoder_prompt"
],
prompt
=
inputs
[
"encoder_prompt"
],
prompt_token_ids
=
inputs
[
"encoder_prompt_token_ids"
],
prompt_token_ids
=
inputs
[
"encoder_prompt_token_ids"
],
)
)
if
decoder_inputs_to_override
is
not
None
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
decoder_inputs_to_override
.
get
(
"prompt"
,
""
),
prompt_token_ids
=
decoder_inputs_to_override
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_hashes
=
inputs
[
"mm_hashes"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
else
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
inputs
[
"prompt"
],
prompt_token_ids
=
inputs
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_hashes
=
inputs
[
"mm_hashes"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
cache_salt
=
inputs
.
get
(
"cache_salt"
)
decoder_prompt_inputs
=
decoder_inputs_to_override
or
inputs
if
cache_salt
is
not
None
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
decoder_prompt_inputs
.
get
(
"prompt"
,
""
),
prompt_token_ids
=
decoder_prompt_inputs
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_hashes
=
inputs
[
"mm_hashes"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
if
cache_salt
:
=
inputs
.
get
(
"cache_salt"
):
decoder_inputs
[
"cache_salt"
]
=
cache_salt
decoder_inputs
[
"cache_salt"
]
=
cache_salt
elif
inputs
[
"type"
]
==
"token"
:
elif
inputs
[
"type"
]
==
"token"
:
# Text-only inputs
# Text-only inputs
encoder_inputs
=
token_inputs
(
prompt
=
""
,
prompt_token_ids
=
[])
encoder_inputs
=
token_inputs
(
prompt
=
""
,
prompt_token_ids
=
[])
decoder_inputs
=
decoder_inputs_to_override
or
inputs
decoder_inputs
=
decoder_inputs_to_override
or
inputs
else
:
else
:
assert_never
(
inputs
)
# type: ignore[arg-type]
assert_never
(
inputs
)
# type: ignore[arg-type]
return
encoder_inputs
,
decoder_inputs
return
encoder_inputs
,
decoder_inputs
def
_process_encoder_decoder_prompt
(
def
_process_encoder_decoder_prompt
(
...
@@ -580,11 +687,9 @@ class InputPreprocessor:
...
@@ -580,11 +687,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
assert
decoder_inputs
is
None
or
not
is_embeds_inputs
(
decoder_inputs
)
encoder_inputs
,
decoder_inputs
=
(
encoder_inputs
,
decoder_inputs
=
(
self
.
_s
eparate
_enc_dec_inputs
_from_mm_processor_out
puts
(
self
.
_s
plit
_enc_dec_
mm_
inputs
(
encoder_in
puts
,
encoder_inputs
,
decoder_inputs
))
decoder_inputs
))
else
:
else
:
inputs
=
self
.
_prompt_to_llm_inputs
(
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
prompt
,
...
@@ -593,16 +698,11 @@ class InputPreprocessor:
...
@@ -593,16 +698,11 @@ class InputPreprocessor:
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
self
.
_split_enc_dec_mm_inputs
(
inputs
))
inputs
))
else
:
else
:
encoder_inputs
=
inputs
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert
not
is_embeds_inputs
(
encoder_inputs
)
assert
decoder_inputs
is
None
or
not
is_embeds_inputs
(
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
async
def
_process_encoder_decoder_prompt_async
(
async
def
_process_encoder_decoder_prompt_async
(
...
@@ -635,11 +735,9 @@ class InputPreprocessor:
...
@@ -635,11 +735,9 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
assert
decoder_inputs
is
None
or
not
is_embeds_inputs
(
decoder_inputs
)
encoder_inputs
,
decoder_inputs
=
(
encoder_inputs
,
decoder_inputs
=
(
self
.
_s
eparate
_enc_dec_inputs
_from_mm_processor_out
puts
(
self
.
_s
plit
_enc_dec_
mm_
inputs
(
encoder_in
puts
,
encoder_inputs
,
decoder_inputs
))
decoder_inputs
))
else
:
else
:
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
prompt
,
...
@@ -648,16 +746,11 @@ class InputPreprocessor:
...
@@ -648,16 +746,11 @@ class InputPreprocessor:
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
self
.
_split_enc_dec_mm_inputs
(
inputs
))
inputs
))
else
:
else
:
encoder_inputs
=
inputs
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert
not
is_embeds_inputs
(
encoder_inputs
)
assert
decoder_inputs
is
None
or
not
is_embeds_inputs
(
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
def
_build_decoder_only_llm_inputs
(
def
_build_decoder_only_llm_inputs
(
...
@@ -665,19 +758,13 @@ class InputPreprocessor:
...
@@ -665,19 +758,13 @@ class InputPreprocessor:
prompt_inputs
:
DecoderOnlyInputs
,
prompt_inputs
:
DecoderOnlyInputs
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
if
(
prompt_inputs
[
"type"
]
==
"token"
if
"prompt_token_ids"
in
prompt_inputs
:
or
prompt_inputs
[
"type"
]
==
"multimodal"
):
prompt_inputs
=
cast
(
Union
[
TokenInputs
,
MultiModalInputs
],
# Mypy does not do type inference well with typedicts and Literal
prompt_inputs
)
# Needed for mypy
# values
assert
not
is_embeds_inputs
(
prompt_inputs
)
prompt_inputs
[
"prompt_token_ids"
]
=
self
.
_apply_prompt_adapter
(
prompt_inputs
[
"prompt_token_ids"
]
=
self
.
_apply_prompt_adapter
(
prompt_inputs
[
"prompt_token_ids"
],
prompt_inputs
[
"prompt_token_ids"
],
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
)
elif
(
prompt_inputs
[
"type"
]
==
"embeds"
):
pass
else
:
assert_never
(
prompt_inputs
)
# type: ignore[arg-type]
return
prompt_inputs
return
prompt_inputs
...
...
vllm/multimodal/processing.py
View file @
cb234955
...
@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1670,15 +1670,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
placeholders
=
mm_placeholders
.
get
(
modality
,
[])
placeholders
=
mm_placeholders
.
get
(
modality
,
[])
if
len
(
placeholders
)
!=
item_count
:
if
len
(
placeholders
)
!=
item_count
:
# NOTE: If you are a model developer, this can also arise from
# an inconsistency between `_call_hf_processor` and
# `_get_mm_fields_config` implementations
raise
RuntimeError
(
raise
RuntimeError
(
f
"Expected there to be
{
item_count
}
prompt updates "
f
"Expected there to be
{
item_count
}
prompt updates "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"instead found
{
len
(
placeholders
)
}
prompt updates! "
f
"instead found
{
len
(
placeholders
)
}
prompt updates! "
"Either the prompt text has missing/incorrect tokens for "
"This is likely because you forgot to include input "
"multi-modal inputs, or there is a problem with your "
"placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
"implementation of merged multi-modal processor for this "
"in the prompt. If the model has a chat template, make "
"model (usually arising from an inconsistency between "
"sure you have applied it before calling `LLM.generate`."
)
"`_call_hf_processor` and `_get_prompt_updates`)."
)
def
_maybe_apply_prompt_updates
(
def
_maybe_apply_prompt_updates
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment