Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc55d130
Unverified
Commit
bc55d130
authored
Feb 13, 2025
by
Isotr0py
Committed by
GitHub
Feb 12, 2025
Browse files
[VLM] Implement merged multimodal processor for Mllama (#11427)
parent
d88c8666
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
456 additions
and
233 deletions
+456
-233
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+67
-4
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+11
-2
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+83
-7
vllm/inputs/registry.py
vllm/inputs/registry.py
+2
-1
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+203
-205
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+16
-0
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+57
-3
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+17
-11
No files found.
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
bc55d130
...
@@ -7,11 +7,11 @@ import torch
...
@@ -7,11 +7,11 @@ import torch
from
transformers
import
(
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
,
from
transformers
import
(
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
)
BatchEncoding
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
global_force_attn_backend_context_manager
)
from
vllm.model_executor.models.mllama
import
(
MLLAMA_IMAGE_TOKEN_ID
,
from
vllm.model_executor.models.mllama
import
MllamaForConditionalGeneration
MllamaForConditionalGeneration
)
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
...
@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
...
@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
3
_LIMIT_IMAGE_PER_PROMPT
=
3
MLLAMA_IMAGE_TOKEN_ID
=
128256
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
...
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
...
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
def
test_explicit_implicit_prompt
(
image_assets
:
_ImageAssets
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
):
stop_sign
=
image_assets
[
0
].
pil_image
# yapf: disable
prompts
=
[
# explicit prompt
{
"encoder_prompt"
:
{
"prompt"
:
"<|image|>"
,
"multi_modal_data"
:
{
"image"
:
stop_sign
},
},
"decoder_prompt"
:
{
"prompt_token_ids"
:
[
128000
,
791
,
2262
,
315
,
279
,
2217
,
220
,
128256
,
374
],
# noqa: E501
}
},
{
"encoder_prompt"
:
"Not <|image|>"
,
"decoder_prompt"
:
"The color of the sky is blue but sometimes it can also be"
,
# noqa: E501
},
# implicit prompt
{
"prompt"
:
"<|begin_of_text|>The content of the image <|image|> is"
,
# noqa: E501
"multi_modal_data"
:
{
"image"
:
stop_sign
},
},
{
"prompt"
:
"The color of the sky is blue but sometimes it can also be"
,
# noqa: E501
},
]
# yapf: enable
llm
=
LLM
(
model
=
model
,
dtype
=
dtype
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
tensor_parallel_size
=
1
,
enforce_eager
=
True
,
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
max_tokens
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
n_prompts
=
len
(
prompts
)
explicit_outputs
=
outputs
[:
n_prompts
//
2
]
implicit_outputs
=
outputs
[
n_prompts
//
2
:]
for
exp_output
,
imp_output
in
zip
(
explicit_outputs
,
implicit_outputs
):
assert
exp_output
.
outputs
[
0
].
text
==
imp_output
.
outputs
[
0
].
text
@
large_gpu_test
(
min_gb
=
48
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
...
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
...
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
images
=
images
)
images
=
images
)
class
DummyModel
:
image_token_id
=
MLLAMA_IMAGE_TOKEN_ID
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"input_indices_and_output"
,
"input_indices_and_output"
,
...
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
...
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
)
)
dummy
:
dict
[
str
,
str
]
=
{}
dummy
=
DummyModel
()
cross_attention_mask
,
kv_range_for_decode
=
MllamaForConditionalGeneration
\
cross_attention_mask
,
kv_range_for_decode
=
MllamaForConditionalGeneration
\
.
get_cross_attention_mask
(
dummy
,
.
get_cross_attention_mask
(
dummy
,
...
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
...
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
)
)
dummy
:
dict
[
str
,
str
]
=
{}
dummy
=
DummyModel
()
full_text_row_masked_out_mask
=
MllamaForConditionalGeneration
\
full_text_row_masked_out_mask
=
MllamaForConditionalGeneration
\
.
get_full_text_row_masked_out_mask
(
dummy
,
.
get_full_text_row_masked_out_mask
(
dummy
,
...
...
tests/models/multimodal/processing/test_common.py
View file @
bc55d130
...
@@ -85,6 +85,14 @@ def _test_processing_correctness(
...
@@ -85,6 +85,14 @@ def _test_processing_correctness(
partial
(
random_audio
,
rng
,
min_len
=
512
,
max_len
=
1024
,
sr
=
16000
),
partial
(
random_audio
,
rng
,
min_len
=
512
,
max_len
=
1024
,
sr
=
16000
),
}
}
tokenizer_encode_kwargs
=
{}
if
model_config
.
hf_config
.
model_type
==
"mllama"
:
# For Mllama, tokenizer will always add bos_token at the beginning of
# prompt by default, causing hf_processor outputs incorrect token ids.
# So we need use `add_special_tokens=False` here to leave bos_token
# to be added by the processor.
tokenizer_encode_kwargs
=
{
"add_special_tokens"
:
False
}
for
batch_idx
in
range
(
num_batches
):
for
batch_idx
in
range
(
num_batches
):
mm_data
=
{
mm_data
=
{
k
:
k
:
...
@@ -122,7 +130,7 @@ def _test_processing_correctness(
...
@@ -122,7 +130,7 @@ def _test_processing_correctness(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
baseline_tokenized_result
=
baseline_processor
.
apply
(
baseline_tokenized_result
=
baseline_processor
.
apply
(
tokenizer
.
encode
(
prompt
),
tokenizer
.
encode
(
prompt
,
**
tokenizer_encode_kwargs
),
mm_data
=
mm_data
,
mm_data
=
mm_data
,
hf_processor_mm_kwargs
=
{},
hf_processor_mm_kwargs
=
{},
)
)
...
@@ -131,7 +139,7 @@ def _test_processing_correctness(
...
@@ -131,7 +139,7 @@ def _test_processing_correctness(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
cached_tokenized_result
=
cached_processor
.
apply
(
cached_tokenized_result
=
cached_processor
.
apply
(
tokenizer
.
encode
(
prompt
),
tokenizer
.
encode
(
prompt
,
**
tokenizer_encode_kwargs
),
mm_data
=
mm_data
,
mm_data
=
mm_data
,
hf_processor_mm_kwargs
=
{},
hf_processor_mm_kwargs
=
{},
)
)
...
@@ -155,6 +163,7 @@ def _test_processing_correctness(
...
@@ -155,6 +163,7 @@ def _test_processing_correctness(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
,
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
,
"meta-llama/Llama-3.2-11B-Vision-Instruct"
,
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
"mistral-community/pixtral-12b"
,
"mistral-community/pixtral-12b"
,
"openbmb/MiniCPM-o-2_6"
,
"openbmb/MiniCPM-o-2_6"
,
...
...
vllm/inputs/preprocess.py
View file @
bc55d130
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
,
Union
,
cast
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
...
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
...
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalInputs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
...
@@ -495,6 +496,51 @@ class InputPreprocessor:
...
@@ -495,6 +496,51 @@ class InputPreprocessor:
decoder
=
decoder_inputs
,
decoder
=
decoder_inputs
,
)
)
def
_separate_enc_dec_inputs_from_mm_processor_outputs
(
self
,
inputs
:
SingletonInputs
,
decoder_inputs_to_override
:
Optional
[
SingletonInputs
]
=
None
,
)
->
Tuple
[
SingletonInputs
,
SingletonInputs
]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs
:
SingletonInputs
decoder_inputs
:
SingletonInputs
if
inputs
[
"type"
]
==
"multimodal"
:
# Multimodal data inputs
assert
(
"encoder_prompt"
in
inputs
and
"encoder_prompt_token_ids"
in
inputs
)
inputs
=
cast
(
MultiModalEncDecInputs
,
inputs
)
encoder_inputs
=
token_inputs
(
prompt
=
inputs
[
"encoder_prompt"
],
prompt_token_ids
=
inputs
[
"encoder_prompt_token_ids"
],
)
if
decoder_inputs_to_override
is
not
None
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
decoder_inputs_to_override
.
get
(
"prompt"
,
""
),
prompt_token_ids
=
decoder_inputs_to_override
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
else
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
inputs
[
"prompt"
],
prompt_token_ids
=
inputs
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
elif
inputs
[
"type"
]
==
"token"
:
# Text-only inputs
encoder_inputs
=
token_inputs
(
prompt
=
""
,
prompt_token_ids
=
[])
decoder_inputs
=
decoder_inputs_to_override
or
inputs
else
:
assert_never
(
inputs
)
# type: ignore[arg-type]
return
encoder_inputs
,
decoder_inputs
def
_process_encoder_decoder_prompt
(
def
_process_encoder_decoder_prompt
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
...
@@ -539,7 +585,6 @@ class InputPreprocessor:
...
@@ -539,7 +585,6 @@ class InputPreprocessor:
prompt
[
"encoder_prompt"
],
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
decoder_inputs
=
None
else
:
else
:
...
@@ -547,13 +592,28 @@ class InputPreprocessor:
...
@@ -547,13 +592,28 @@ class InputPreprocessor:
decoder_input
,
decoder_input
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
else
:
encoder_
inputs
=
self
.
_prompt_to_llm_inputs
(
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
...
@@ -583,13 +643,29 @@ class InputPreprocessor:
...
@@ -583,13 +643,29 @@ class InputPreprocessor:
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
encoder_task
,
decoder_task
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
else
:
encoder_
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
decoder_inputs
=
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_inputs
,
decoder_inputs
)
...
...
vllm/inputs/registry.py
View file @
bc55d130
...
@@ -350,7 +350,8 @@ class InputRegistry:
...
@@ -350,7 +350,8 @@ class InputRegistry:
)
)
processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
)
processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
)
profiler
=
MultiModalProfiler
(
processor
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
,
is_encoder_data
=
is_encoder_data
)
else
:
else
:
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
is_encoder_data
:
if
is_encoder_data
:
...
...
vllm/model_executor/models/mllama.py
View file @
bc55d130
This diff is collapsed.
Click to expand it.
vllm/multimodal/inputs.py
View file @
bc55d130
...
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
...
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
For each modality, information about the placeholder tokens in
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
:code:`prompt_token_ids`.
"""
"""
class
MultiModalEncDecInputs
(
MultiModalInputs
):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""
encoder_prompt
:
str
"""The processed encoder prompt text."""
encoder_prompt_token_ids
:
list
[
int
]
"""The processed token IDs of the encoder prompt."""
encoder_token_type_ids
:
NotRequired
[
list
[
int
]]
"""The token type IDs of the encoder prompt."""
vllm/multimodal/processing.py
View file @
bc55d130
...
@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
...
@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
from
vllm.utils
import
LRUCache
,
flatten_2d_lists
,
full_groupby
from
vllm.utils
import
LRUCache
,
flatten_2d_lists
,
full_groupby
from
.hasher
import
MultiModalHasher
from
.hasher
import
MultiModalHasher
from
.inputs
import
(
MultiModalDataDict
,
MultiModal
FieldConfig
,
from
.inputs
import
(
MultiModalDataDict
,
MultiModal
EncDecInputs
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes
=
mm_hashes
,
mm_hashes
=
mm_hashes
,
mm_placeholders
=
mm_placeholder_ranges
,
mm_placeholders
=
mm_placeholder_ranges
,
)
)
class
EncDecMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
@
abstractmethod
def
create_encoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
"""Create input prompt for the encoder."""
raise
NotImplementedError
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer
=
self
.
info
.
get_tokenizer
()
if
isinstance
(
prompt
,
str
):
decoder_prompt
=
prompt
decoder_prompt_ids
=
encode_tokens
(
tokenizer
,
prompt
,
add_special_tokens
=
False
)
else
:
decoder_prompt
=
decode_tokens
(
tokenizer
,
prompt
)
decoder_prompt_ids
=
prompt
mm_inputs
=
MultiModalEncDecInputs
(
encoder_prompt
=
encoder_inputs
[
"prompt"
],
encoder_prompt_token_ids
=
encoder_inputs
[
"prompt_token_ids"
],
**
encoder_inputs
)
mm_inputs
.
update
({
"prompt"
:
decoder_prompt
,
"prompt_token_ids"
:
decoder_prompt_ids
})
return
mm_inputs
vllm/multimodal/profiling.py
View file @
bc55d130
...
@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
def
get_dummy_data
(
self
,
seq_len
:
int
,
is_encoder_data
:
bool
=
False
,
)
->
DummyData
:
# Avoid circular import
# Avoid circular import
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
...
@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
total_len
=
len
(
prompt_token_ids
)
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
if
(
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
)
or
is_encoder_data
:
logger
.
warning
(
if
total_len
>
seq_len
:
"The context length (%d) of the model is too short "
logger
.
warning
(
"to hold the multi-modal embeddings in the worst case "
"The context length (%d) of the model is too short "
"(%d tokens in total, out of which %s are reserved for "
"to hold the multi-modal embeddings in the worst case "
"multi-modal embeddings). This may cause certain multi-modal "
"(%d tokens in total, out of which %s are reserved for "
"inputs to fail during inference, even when the input text is "
"multi-modal embeddings). This may cause certain "
"short. To avoid this, you should increase `max_model_len`, "
"multi-modal inputs to fail during inference, even when "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
"the input text is short. To avoid this, you should "
total_len
,
total_placeholders_by_modality
)
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment