Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc55d130
"vscode:/vscode.git/clone" did not exist on "8fb2c135be35d4d07b3ef25e36c51733345a01eb"
Unverified
Commit
bc55d130
authored
Feb 13, 2025
by
Isotr0py
Committed by
GitHub
Feb 12, 2025
Browse files
[VLM] Implement merged multimodal processor for Mllama (#11427)
parent
d88c8666
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
456 additions
and
233 deletions
+456
-233
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+67
-4
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+11
-2
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+83
-7
vllm/inputs/registry.py
vllm/inputs/registry.py
+2
-1
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+203
-205
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+16
-0
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+57
-3
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+17
-11
No files found.
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
bc55d130
...
...
@@ -7,11 +7,11 @@ import torch
from
transformers
import
(
AutoConfig
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.model_executor.models.mllama
import
(
MLLAMA_IMAGE_TOKEN_ID
,
MllamaForConditionalGeneration
)
from
vllm.model_executor.models.mllama
import
MllamaForConditionalGeneration
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
...
...
@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
from
...utils
import
check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT
=
3
MLLAMA_IMAGE_TOKEN_ID
=
128256
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
...
...
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
def
test_explicit_implicit_prompt
(
image_assets
:
_ImageAssets
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
):
stop_sign
=
image_assets
[
0
].
pil_image
# yapf: disable
prompts
=
[
# explicit prompt
{
"encoder_prompt"
:
{
"prompt"
:
"<|image|>"
,
"multi_modal_data"
:
{
"image"
:
stop_sign
},
},
"decoder_prompt"
:
{
"prompt_token_ids"
:
[
128000
,
791
,
2262
,
315
,
279
,
2217
,
220
,
128256
,
374
],
# noqa: E501
}
},
{
"encoder_prompt"
:
"Not <|image|>"
,
"decoder_prompt"
:
"The color of the sky is blue but sometimes it can also be"
,
# noqa: E501
},
# implicit prompt
{
"prompt"
:
"<|begin_of_text|>The content of the image <|image|> is"
,
# noqa: E501
"multi_modal_data"
:
{
"image"
:
stop_sign
},
},
{
"prompt"
:
"The color of the sky is blue but sometimes it can also be"
,
# noqa: E501
},
]
# yapf: enable
llm
=
LLM
(
model
=
model
,
dtype
=
dtype
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
tensor_parallel_size
=
1
,
enforce_eager
=
True
,
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
max_tokens
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
n_prompts
=
len
(
prompts
)
explicit_outputs
=
outputs
[:
n_prompts
//
2
]
implicit_outputs
=
outputs
[
n_prompts
//
2
:]
for
exp_output
,
imp_output
in
zip
(
explicit_outputs
,
implicit_outputs
):
assert
exp_output
.
outputs
[
0
].
text
==
imp_output
.
outputs
[
0
].
text
@
large_gpu_test
(
min_gb
=
48
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
...
...
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
images
=
images
)
class
DummyModel
:
image_token_id
=
MLLAMA_IMAGE_TOKEN_ID
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"input_indices_and_output"
,
...
...
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
use_cuda_graph
=
False
,
)
dummy
:
dict
[
str
,
str
]
=
{}
dummy
=
DummyModel
()
cross_attention_mask
,
kv_range_for_decode
=
MllamaForConditionalGeneration
\
.
get_cross_attention_mask
(
dummy
,
...
...
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
use_cuda_graph
=
False
,
)
dummy
:
dict
[
str
,
str
]
=
{}
dummy
=
DummyModel
()
full_text_row_masked_out_mask
=
MllamaForConditionalGeneration
\
.
get_full_text_row_masked_out_mask
(
dummy
,
...
...
tests/models/multimodal/processing/test_common.py
View file @
bc55d130
...
...
@@ -85,6 +85,14 @@ def _test_processing_correctness(
partial
(
random_audio
,
rng
,
min_len
=
512
,
max_len
=
1024
,
sr
=
16000
),
}
tokenizer_encode_kwargs
=
{}
if
model_config
.
hf_config
.
model_type
==
"mllama"
:
# For Mllama, tokenizer will always add bos_token at the beginning of
# prompt by default, causing hf_processor outputs incorrect token ids.
# So we need use `add_special_tokens=False` here to leave bos_token
# to be added by the processor.
tokenizer_encode_kwargs
=
{
"add_special_tokens"
:
False
}
for
batch_idx
in
range
(
num_batches
):
mm_data
=
{
k
:
...
...
@@ -122,7 +130,7 @@ def _test_processing_correctness(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
baseline_tokenized_result
=
baseline_processor
.
apply
(
tokenizer
.
encode
(
prompt
),
tokenizer
.
encode
(
prompt
,
**
tokenizer_encode_kwargs
),
mm_data
=
mm_data
,
hf_processor_mm_kwargs
=
{},
)
...
...
@@ -131,7 +139,7 @@ def _test_processing_correctness(
f
"Failed (
{
batch_idx
=
}
,
{
prompt
=
}
,
{
mm_data
=
}
)"
)
cached_tokenized_result
=
cached_processor
.
apply
(
tokenizer
.
encode
(
prompt
),
tokenizer
.
encode
(
prompt
,
**
tokenizer_encode_kwargs
),
mm_data
=
mm_data
,
hf_processor_mm_kwargs
=
{},
)
...
...
@@ -155,6 +163,7 @@ def _test_processing_correctness(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
,
"meta-llama/Llama-3.2-11B-Vision-Instruct"
,
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
"mistral-community/pixtral-12b"
,
"openbmb/MiniCPM-o-2_6"
,
...
...
vllm/inputs/preprocess.py
View file @
bc55d130
# SPDX-License-Identifier: Apache-2.0
import
asyncio
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
,
Union
,
cast
from
typing_extensions
import
assert_never
...
...
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalInputs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
...
...
@@ -495,6 +496,51 @@ class InputPreprocessor:
decoder
=
decoder_inputs
,
)
def
_separate_enc_dec_inputs_from_mm_processor_outputs
(
self
,
inputs
:
SingletonInputs
,
decoder_inputs_to_override
:
Optional
[
SingletonInputs
]
=
None
,
)
->
Tuple
[
SingletonInputs
,
SingletonInputs
]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
encoder_inputs
:
SingletonInputs
decoder_inputs
:
SingletonInputs
if
inputs
[
"type"
]
==
"multimodal"
:
# Multimodal data inputs
assert
(
"encoder_prompt"
in
inputs
and
"encoder_prompt_token_ids"
in
inputs
)
inputs
=
cast
(
MultiModalEncDecInputs
,
inputs
)
encoder_inputs
=
token_inputs
(
prompt
=
inputs
[
"encoder_prompt"
],
prompt_token_ids
=
inputs
[
"encoder_prompt_token_ids"
],
)
if
decoder_inputs_to_override
is
not
None
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
decoder_inputs_to_override
.
get
(
"prompt"
,
""
),
prompt_token_ids
=
decoder_inputs_to_override
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
else
:
decoder_inputs
=
MultiModalInputs
(
type
=
"multimodal"
,
prompt
=
inputs
[
"prompt"
],
prompt_token_ids
=
inputs
[
"prompt_token_ids"
],
mm_kwargs
=
inputs
[
"mm_kwargs"
],
mm_placeholders
=
inputs
[
"mm_placeholders"
],
)
elif
inputs
[
"type"
]
==
"token"
:
# Text-only inputs
encoder_inputs
=
token_inputs
(
prompt
=
""
,
prompt_token_ids
=
[])
decoder_inputs
=
decoder_inputs_to_override
or
inputs
else
:
assert_never
(
inputs
)
# type: ignore[arg-type]
return
encoder_inputs
,
decoder_inputs
def
_process_encoder_decoder_prompt
(
self
,
prompt
:
PromptType
,
...
...
@@ -539,7 +585,6 @@ class InputPreprocessor:
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
else
:
...
...
@@ -547,11 +592,26 @@ class InputPreprocessor:
decoder_input
,
request_id
=
request_id
,
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
encoder_
inputs
=
self
.
_prompt_to_llm_inputs
(
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
...
...
@@ -583,11 +643,27 @@ class InputPreprocessor:
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
encoder_
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
inputs
))
else
:
encoder_inputs
=
inputs
decoder_inputs
=
None
...
...
vllm/inputs/registry.py
View file @
bc55d130
...
...
@@ -350,7 +350,8 @@ class InputRegistry:
)
processor
=
mm_registry
.
create_processor
(
model_config
,
tokenizer
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
)
dummy_data
=
profiler
.
get_dummy_data
(
seq_len
,
is_encoder_data
=
is_encoder_data
)
else
:
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
is_encoder_data
:
...
...
vllm/model_executor/models/mllama.py
View file @
bc55d130
...
...
@@ -23,14 +23,15 @@ import torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
import
transformers.models.mllama.configuration_mllama
as
config_mllama
from
PIL
import
Image
from
PIL
.Image
import
Image
from
torch
import
nn
from
transformers
import
BatchFeature
,
MllamaConfig
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
from
transformers.models.mllama.processing_mllama
import
(
get_cross_attention_token_mask
)
MllamaProcessor
,
get_cross_attention_token_mask
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
...
...
@@ -38,8 +39,6 @@ from vllm.attention.ops.paged_attn import PagedAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DummyData
,
EncoderDecoderInputs
,
InputContext
,
TokenInputs
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -54,8 +53,13 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
SequenceData
from
vllm.utils
import
is_list_of
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataDict
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
...
...
@@ -63,8 +67,6 @@ from .llama import LlamaDecoderLayer, LlamaMLP
from
.utils
import
maybe_prefix
logger
=
init_logger
(
__name__
)
MLLAMA_IMAGE_TOKEN_ID
=
128256
MLLAMA_IMAGE_TOKEN
=
"<|image|>"
class
MllamaImagePixelInputs
(
TypedDict
):
...
...
@@ -81,158 +83,191 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs
def
_get_num_image_in_last_group
(
prompt_token_ids
:
List
[
int
])
->
int
:
num_images
=
0
for
token_id
in
prompt_token_ids
[::
-
1
]:
if
token_id
==
MLLAMA_IMAGE_TOKEN_ID
:
num_images
+=
1
elif
num_images
>
0
:
break
return
num_images
def
calc_token_per_chunk
(
image_size
:
int
)
->
int
:
assert
image_size
%
14
==
0
,
"chunk size should be multiple of 14"
token_per_chunk
=
(
image_size
//
14
)
**
2
+
1
return
token_per_chunk
def
input_processor_for_mllama
(
ctx
:
InputContext
,
inputs
:
EncoderDecoderInputs
,
)
->
EncoderDecoderInputs
:
# Example input to processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000],
# },
# }
class
MllamaProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
)
->
MllamaConfig
:
return
self
.
ctx
.
get_hf_config
(
MllamaConfig
)
def
get_hf_processor
(
self
)
->
MllamaProcessor
:
return
self
.
ctx
.
get_hf_processor
(
MllamaProcessor
)
# move encoder prompt to decoder
dec_inputs
=
TokenInputs
(
**
inputs
[
"encoder"
])
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
multi_modal_data
=
dec_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
# text-only
return
EncoderDecoderInputs
(
encoder
=
token_inputs
([]),
decoder
=
dec_inputs
,
def
get_token_per_chunk_from_config
(
self
)
->
int
:
image_size
=
self
.
get_hf_config
().
vision_config
.
image_size
return
calc_token_per_chunk
(
image_size
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
vision_config
=
self
.
get_hf_config
().
vision_config
token_per_chunk
=
self
.
get_token_per_chunk_from_config
()
mm_max_tokens
=
vision_config
.
max_num_tiles
*
token_per_chunk
return
{
"image"
:
mm_max_tokens
}
def
get_num_tiles_per_image
(
self
,
image_height
:
int
,
image_width
:
int
)
->
int
:
vision_config
=
self
.
get_hf_config
().
vision_config
max_num_tiles
=
vision_config
.
max_num_tiles
image_size
=
vision_config
.
image_size
tiled_height
,
tiled_width
=
get_optimal_tiled_canvas
(
image_height
,
image_width
,
max_num_tiles
,
tile_size
=
image_size
,
)
num_tiles_height
=
tiled_height
//
image_size
num_tiles_width
=
tiled_width
//
image_size
return
num_tiles_height
*
num_tiles_width
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
image_data
=
[
image_data
]
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
vision_config
=
self
.
get_hf_config
().
vision_config
image_size
=
vision_config
.
image_size
max_num_tiles
=
vision_config
.
max_num_tiles
# Result in the max possible feature size (h:w = 16:1)
return
ImageSize
(
height
=
max_num_tiles
*
image_size
,
width
=
image_size
)
assert
is_list_of
(
image_data
,
Image
.
Image
)
num_image_tokens
=
dec_inputs
[
'prompt_token_ids'
].
count
(
MLLAMA_IMAGE_TOKEN_ID
)
if
num_image_tokens
!=
len
(
image_data
):
raise
ValueError
(
f
"The number of image tokens (
{
num_image_tokens
}
) must be"
f
" the same as the number of images (
{
len
(
image_data
)
}
)"
)
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images
=
_get_num_image_in_last_group
(
dec_inputs
[
"prompt_token_ids"
])
hf_config
=
ctx
.
model_config
.
hf_config
vision_config
=
hf_config
.
vision_config
num_tiles
=
0
for
image
in
image_data
[::
-
1
]:
width
,
height
=
image
.
size
tile_size
=
vision_config
.
image_size
canvas_height
,
canvas_width
=
get_optimal_tiled_canvas
(
image_height
=
height
,
image_width
=
width
,
max_image_tiles
=
vision_config
.
max_num_til
es
,
tile_size
=
tile_size
,
class
MllamaDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MllamaProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
hf_processor
=
self
.
info
.
get_hf_processor
()
image_token
:
str
=
hf_processor
.
image_token
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_imag
es
,
mm_data
=
mm_data
,
)
num_tiles_height
=
canvas_height
//
tile_size
num_tiles_width
=
canvas_width
//
tile_size
num_tiles
+=
num_tiles_height
*
num_tiles_width
num_decode_images
-=
1
if
num_decode_images
==
0
:
break
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert
vision_config
.
image_size
%
14
==
0
,
\
"chunk size should be multiple of 14"
token_per_chunk
=
(
vision_config
.
image_size
//
14
)
**
2
+
1
num_tokens
=
num_tiles
*
token_per_chunk
# Example output from processor:
class
MllamaMultiModalProcessor
(
EncDecMultiModalProcessor
[
MllamaProcessingInfo
]
):
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
tokenizer
=
self
.
info
.
get_tokenizer
()
if
mm_data
:
num_tiles
=
[
self
.
info
.
get_num_tiles_per_image
(
img
.
height
,
img
.
width
)
for
img
in
mm_data
[
"images"
]
]
processed_outputs
=
super
().
_call_hf_processor
(
prompt
,
mm_data
,
mm_kwargs
)
processed_outputs
[
"num_tiles"
]
=
torch
.
tensor
(
num_tiles
)
for
k
in
(
'pixel_values'
,
'aspect_ratio_ids'
,
"aspect_ratio_mask"
):
processed_outputs
[
k
]
=
processed_outputs
[
k
].
squeeze
(
0
)
# Example input to encoder and decoder:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128
256, ..., 128256],
# 'prompt': '<|image|><|
image|>...<|image|>',
# 'prompt_token_ids': [128256, 128
000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|
begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# 'prompt_token_ids': [128000],
# },
# }
return
EncoderDecoderInputs
(
encoder
=
token_inputs
(
prompt_token_ids
=
[
MLLAMA_IMAGE_TOKEN_ID
]
*
num_tokens
,
prompt
=
MLLAMA_IMAGE_TOKEN
*
num_tokens
,
multi_modal_data
=
multi_modal_data
,
),
decoder
=
dec_inputs
,
)
def
get_max_mllama_image_tokens
(
ctx
:
InputContext
)
->
int
:
hf_config
=
ctx
.
model_config
.
hf_config
token_per_chunk
=
(
hf_config
.
vision_config
.
image_size
//
14
)
**
2
+
1
return
hf_config
.
vision_config
.
max_num_tiles
*
token_per_chunk
def
dummy_decoder_seq_data
(
seq_len
:
int
,
num_images
:
int
):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert
seq_len
>=
num_images
,
\
"seq_len should be greater than or equal to num_images"
processed_token_ids
=
processed_outputs
.
pop
(
"input_ids"
)
start_idx
,
end_idx
=
0
,
processed_token_ids
.
size
(
1
)
processed_prompt_text
=
tokenizer
.
decode
(
processed_token_ids
[
0
])
hf_processor
=
self
.
info
.
get_hf_processor
()
bos_token
=
hf_processor
.
bos_token
# Remove the bos_token from the start of prompt,
# because we all know there would be image_token.
if
processed_prompt_text
.
startswith
(
bos_token
):
start_idx
+=
1
# Remove the bos_token from the end of prompt,
# because text is empty in this case.
if
processed_prompt_text
.
endswith
(
bos_token
):
end_idx
-=
1
processed_outputs
[
"input_ids"
]
=
processed_token_ids
[:,
start_idx
:
end_idx
]
else
:
processed_outputs
=
tokenizer
(
prompt
,
add_special_tokens
=
False
,
return_tensors
=
"pt"
)
return
processed_outputs
return
SequenceData
.
from_prompt_token_counts
(
(
MLLAMA_IMAGE_TOKEN_ID
,
num_images
),
(
0
,
seq_len
-
num_images
),
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
aspect_ratio_ids
=
MultiModalFieldConfig
.
batched
(
"image"
),
aspect_ratio_mask
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_tiles
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
dummy_encoder_seq_data
(
ctx
:
InputContext
,
num_images
:
int
):
num_tokens
=
get_max_mllama_image_tokens
(
ctx
)
*
num_images
return
SequenceData
.
from_prompt_token_counts
(
(
MLLAMA_IMAGE_TOKEN_ID
,
num_tokens
))
def
dummy_image
(
num_images
:
int
,
):
width
=
height
=
1024
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
dummy_decoder_data_for_mllama
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
return
DummyData
(
dummy_decoder_seq_data
(
seq_len
,
num_images
))
def
dummy_encoder_data_for_mllama
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
num_images
=
mm_counts
[
"image"
]
return
DummyData
(
dummy_encoder_seq_data
(
ctx
,
num_images
),
dummy_image
(
num_images
))
def
create_encoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
data
=
mm_data
.
get
(
"image"
,
[])
num_images
=
1
if
isinstance
(
data
,
Image
)
else
len
(
data
)
image_token_id
=
self
.
info
.
get_hf_config
().
image_token_index
return
[
image_token_id
]
*
num_images
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
token_per_chunk
=
self
.
info
.
get_token_per_chunk_from_config
()
image_token_id
=
self
.
info
.
get_hf_config
().
image_token_index
def
get_replacement_mllama
(
item_idx
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
num_tile
=
self
.
info
.
get_num_tiles_per_image
(
image_height
=
image_size
.
height
,
image_width
=
image_size
.
width
,
)
num_tokens
=
num_tile
*
token_per_chunk
return
[
image_token_id
]
*
num_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement_mllama
,
)
]
def
_prepare_aspect_ratio_attention_mask
(
...
...
@@ -1107,11 +1142,9 @@ class MllamaForCausalLM(nn.Module):
return
hidden_states
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_mllama_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_decoder_data_for_mllama
)
@
INPUT_REGISTRY
.
register_dummy_encoder_data
(
dummy_encoder_data_for_mllama
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_mllama
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
MllamaMultiModalProcessor
,
info
=
MllamaProcessingInfo
,
dummy_inputs
=
MllamaDummyInputsBuilder
)
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
...
...
@@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
:
MllamaConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
quant_config
=
quant_config
self
.
vocab_size
=
config
.
text_config
.
vocab_size
...
...
@@ -1130,6 +1163,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
pad_token_id
=
\
config
.
pad_token_id
if
config
.
pad_token_id
is
not
None
else
-
1
self
.
image_size
=
config
.
vision_config
.
image_size
self
.
image_token_id
=
config
.
image_token_index
self
.
vision_model
=
MllamaVisionModel
(
config
.
vision_config
,
quant_config
,
...
...
@@ -1204,48 +1238,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
pixel_values
is
not
None
:
assert
aspect_ratio_ids
is
not
None
assert
aspect_ratio_mask
is
not
None
max_num_images
=
max
([
len
(
x
[
0
])
for
x
in
pixel_values
])
if
max_num_images
==
0
:
raise
ValueError
(
"No images provided."
)
max_num_tiles
=
max
(
max
([
len
(
x
)
for
x
in
y
[
0
]])
for
y
in
pixel_values
)
device
=
next
(
self
.
multi_modal_projector
.
parameters
()).
device
bsz
=
len
(
pixel_values
)
out_num_tiles
=
[]
out_images
=
torch
.
zeros
(
bsz
,
max_num_images
,
max_num_tiles
,
3
,
self
.
image_size
,
self
.
image_size
,
dtype
=
torch
.
float32
,
device
=
device
,
)
out_ar_ids
=
torch
.
ones
(
bsz
,
max_num_images
,
dtype
=
torch
.
int64
,
device
=
device
)
out_ar_mask
=
torch
.
zeros
(
bsz
,
max_num_images
,
max_num_tiles
,
dtype
=
torch
.
int64
,
device
=
device
)
for
b
in
range
(
len
(
pixel_values
)):
_num_tiles
=
[]
for
i
in
range
(
len
(
pixel_values
[
b
][
0
])):
img
=
pixel_values
[
b
][
0
][
i
]
out_images
[
b
,
i
,
:
img
.
shape
[
0
]]
=
img
out_ar_ids
[
b
,
i
]
=
aspect_ratio_ids
[
b
][
0
][
i
]
out_ar_mask
[
b
,
i
]
=
aspect_ratio_mask
[
b
][
0
][
i
]
_num_tiles
.
append
(
img
.
shape
[
0
])
out_num_tiles
.
append
(
_num_tiles
)
return
MllamaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
out_imag
es
,
aspect_ratio_ids
=
out_ar
_ids
,
aspect_ratio_mask
=
out_ar
_mask
,
data
=
pixel_valu
es
,
aspect_ratio_ids
=
aspect_ratio
_ids
,
aspect_ratio_mask
=
aspect_ratio
_mask
,
)
if
image_embeds
is
not
None
:
...
...
@@ -1312,7 +1310,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
batch_token_ids
.
append
(
token_ids
[
start
:
start
+
seq_len
])
start
+=
seq_len
sparse_mask
=
[
get_cross_attention_token_mask
(
t
,
MLLAMA_IMAGE_TOKEN_ID
)
get_cross_attention_token_mask
(
t
,
self
.
image_token_id
)
for
t
in
batch_token_ids
]
...
...
@@ -1384,8 +1382,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor
=
kwargs
.
pop
(
"num_tiles"
)
num_tiles
=
[
t
[
0
]
.
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
(
self
.
image_size
//
14
)
**
2
+
1
num_tiles
=
[
t
.
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
calc_token_per_chunk
(
self
.
image_size
)
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
]
...
...
vllm/multimodal/inputs.py
View file @
bc55d130
...
...
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""
class
MultiModalEncDecInputs
(
MultiModalInputs
):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""
encoder_prompt
:
str
"""The processed encoder prompt text."""
encoder_prompt_token_ids
:
list
[
int
]
"""The processed token IDs of the encoder prompt."""
encoder_token_type_ids
:
NotRequired
[
list
[
int
]]
"""The token type IDs of the encoder prompt."""
vllm/multimodal/processing.py
View file @
bc55d130
...
...
@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
from
vllm.utils
import
LRUCache
,
flatten_2d_lists
,
full_groupby
from
.hasher
import
MultiModalHasher
from
.inputs
import
(
MultiModalDataDict
,
MultiModal
FieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
.inputs
import
(
MultiModalDataDict
,
MultiModal
EncDecInputs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
if
TYPE_CHECKING
:
...
...
@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes
=
mm_hashes
,
mm_placeholders
=
mm_placeholder_ranges
,
)
class
EncDecMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
@
abstractmethod
def
create_encoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
"""Create input prompt for the encoder."""
raise
NotImplementedError
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer
=
self
.
info
.
get_tokenizer
()
if
isinstance
(
prompt
,
str
):
decoder_prompt
=
prompt
decoder_prompt_ids
=
encode_tokens
(
tokenizer
,
prompt
,
add_special_tokens
=
False
)
else
:
decoder_prompt
=
decode_tokens
(
tokenizer
,
prompt
)
decoder_prompt_ids
=
prompt
mm_inputs
=
MultiModalEncDecInputs
(
encoder_prompt
=
encoder_inputs
[
"prompt"
],
encoder_prompt_token_ids
=
encoder_inputs
[
"prompt_token_ids"
],
**
encoder_inputs
)
mm_inputs
.
update
({
"prompt"
:
decoder_prompt
,
"prompt_token_ids"
:
decoder_prompt_ids
})
return
mm_inputs
vllm/multimodal/profiling.py
View file @
bc55d130
...
...
@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
def
get_dummy_data
(
self
,
seq_len
:
int
,
is_encoder_data
:
bool
=
False
,
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
...
...
@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
if
(
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
)
or
is_encoder_data
:
if
total_len
>
seq_len
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment