Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4da1f667
Unverified
Commit
4da1f667
authored
Feb 14, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 14, 2025
Browse files
[VLM] Keep track of whether prompt replacements have been applied (#13215)
parent
556ef7f7
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
373 additions
and
329 deletions
+373
-329
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+8
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+1
-2
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+45
-12
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+51
-39
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+99
-122
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+0
-10
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+35
-65
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+9
-4
vllm/multimodal/parse.py
vllm/multimodal/parse.py
+57
-1
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+68
-74
No files found.
vllm/model_executor/models/glm4v.py
View file @
4da1f667
...
...
@@ -484,6 +484,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class
GLM4VMultiModalProcessor
(
BaseMultiModalProcessor
[
GLM4VProcessingInfo
]):
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
...
...
vllm/model_executor/models/llava.py
View file @
4da1f667
...
...
@@ -294,7 +294,7 @@ class PixtralHFMultiModalProcessor(
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
if
pixel_values
is
not
None
:
# Before/after https://github.com/huggingface/transformers/pull/35122
if
Version
(
TRANSFORMERS_VERSION
)
<=
Version
(
"4.48.
2
"
):
if
Version
(
TRANSFORMERS_VERSION
)
<=
Version
(
"4.48.
3
"
):
images
=
mm_data
[
"images"
]
assert
isinstance
(
images
,
list
)
...
...
@@ -819,7 +819,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids
,
mm_item_counts
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
mm_placeholder_ranges
=
{
...
...
vllm/model_executor/models/llava_onevision.py
View file @
4da1f667
...
...
@@ -299,36 +299,69 @@ class LlavaOnevisionMultiModalProcessor(
mm_kwargs
=
mm_kwargs
,
)
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
# So, we process each component separately
# NOTE: No prompt replacement is applied in this case
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
video_token
=
processor
.
video_token
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
text_image_outputs
=
super
().
_call_hf_processor
(
text_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_data
=
{}
,
mm_kwargs
=
mm_kwargs
,
)
images
=
mm_data
.
pop
(
"images"
,
[])
assert
isinstance
(
images
,
list
)
if
images
:
processor_outputs
=
super
().
_call_hf_processor
(
prompt
=
image_token
*
len
(
images
),
mm_data
=
{
"images"
:
images
},
mm_kwargs
=
mm_kwargs
,
)
image_outputs
=
{
k
:
v
for
k
,
v
in
processor_outputs
.
items
()
if
k
in
(
"pixel_values"
,
"image_sizes"
)
}
else
:
image_outputs
=
{}
pixel_values_videos
=
[]
for
video
in
videos
:
item_processor_data
=
dict
(
prompt
=
video_token
,
videos
=
video
)
item_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
item_processor_data
,
prompt
=
video_token
,
mm_data
=
{
"videos"
:
video
}
,
mm_kwargs
=
mm_kwargs
,
)
pixel_values_videos
.
append
(
item_outputs
.
pop
(
"pixel_values_videos"
)[
0
])
pixel_values_videos
.
append
(
item_outputs
[
"pixel_values_videos"
][
0
])
video_outputs
=
{
"pixel_values_videos"
:
pixel_values_videos
}
combined_outputs
=
dict
(
**
text_image_outputs
,
pixel_values_videos
=
pixel_values_videos
,
text_outputs
,
**
image_outputs
,
**
video_outputs
,
)
return
BatchFeature
(
combined_outputs
)
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
base_result
=
super
().
_hf_processor_applies_repl
(
prompt_text
=
prompt_text
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
base_result
and
mm_items
.
get_count
(
"video"
,
strict
=
False
)
==
0
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
...
...
vllm/model_executor/models/minicpmo.py
View file @
4da1f667
...
...
@@ -27,8 +27,8 @@ from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.types
from
torch
import
nn
from
transformers
import
BatchFeature
from
transformers.modeling_outputs
import
BaseModelOutputWithPast
from
transformers.models.whisper.modeling_whisper
import
(
ACT2FN
,
WHISPER_ATTENTION_CLASSES
,
WhisperConfig
,
WhisperEncoder
)
...
...
@@ -37,23 +37,21 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.parse
import
(
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
VideoItem
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
PromptReplacement
)
from
vllm.multimodal.parse
import
(
AudioItem
,
DictEmbeddingItems
,
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
PromptReplacement
from
vllm.multimodal.profiling
import
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
.minicpmv
import
(
MiniCPMV2_6
,
MiniCPMVDummyInputsBuilder
,
MiniCPMVEmbeddingItems
,
MiniCPMVMultiModalDataParser
,
MiniCPMVMultiModalProcessor
,
MiniCPMVProcessingInfo
)
MiniCPMVMultiModalDataParser
,
MiniCPMVMultiModalProcessor
,
MiniCPMVProcessingInfo
,
_minicpmv_field_config
)
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
CPU_DEVICE
=
torch
.
device
(
"cpu"
)
MiniCPMOEmbeddingItems
=
MiniCPMVEmbeddingItems
class
MiniCPMOAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
...
...
@@ -103,28 +101,49 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs
]
class
MiniCPMOAudioEmbeddingItems
(
MiniCPMOEmbeddingItems
):
def
_minicpmo_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
audio_num_slices
=
hf_inputs
.
get
(
"audio_num_slices"
,
torch
.
empty
(
0
))
def
__init__
(
self
,
data
:
Dict
)
->
None
:
super
().
__init__
(
data
,
"audio"
)
audio_embeds
=
self
.
data
.
get
(
"audio_embeds"
,
None
)
if
audio_embeds
is
None
:
raise
ValueError
(
"Incorrect type of video_embeds"
,
"Got type: None"
)
self
.
data
[
"audio_embeds"
]
=
audio_embeds
return
dict
(
**
_minicpmv_field_config
(
hf_inputs
),
audio_features
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
),
audio_feature_lens
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
),
audio_num_slices
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_orders_in_mm_data
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
),
)
def
get
(
self
,
index
:
int
)
->
object
:
return
self
.
data
[
"audio_embeds"
][
index
]
class
MiniCPMOAudioEmbeddingItems
(
DictEmbeddingItems
):
def
__init__
(
self
,
data
:
Mapping
[
str
,
torch
.
Tensor
],
fields_config
:
Mapping
[
str
,
MultiModalFieldConfig
],
)
->
None
:
super
().
__init__
(
data
,
modality
=
"image"
,
fields_config
=
fields_config
,
required_fields
=
{
"audio_embeds"
},
)
class
MiniCPMOMultiModalDataParser
(
MiniCPMVMultiModalDataParser
):
def
_parse_audio_data
(
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
Vide
oItem
]],
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
Audi
oItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
MiniCPMOAudioEmbeddingItems
(
data
)
return
MiniCPMOAudioEmbeddingItems
(
data
,
fields_config
=
_minicpmo_field_config
(
data
),
)
return
super
().
_parse_audio_data
(
data
)
...
...
@@ -167,6 +186,10 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def
get_max_audio_chunks_with_most_features
(
self
)
->
int
:
return
30
def
get_max_audio_tokens
(
self
)
->
int
:
return
self
.
get_max_audio_tokens_per_chunk
(
)
*
self
.
get_max_audio_chunks_with_most_features
()
def
get_audio_len_by_num_chunks
(
self
,
num_chunks
:
int
)
->
int
:
sampling_rate
=
self
.
get_default_audio_sampling_rate
()
# exclude <audio> </audio>
...
...
@@ -194,7 +217,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
return
num_frames
class
MiniCPMODummyInputsBuilder
(
MiniCPMVDummyInputsBuilder
):
class
MiniCPMODummyInputsBuilder
(
MiniCPMVDummyInputsBuilder
[
MiniCPMOProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
...
...
@@ -222,8 +246,7 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
class
MiniCPMOMultiModalProcessor
(
MiniCPMVMultiModalProcessor
,
BaseMultiModalProcessor
[
MiniCPMOProcessingInfo
]):
MiniCPMVMultiModalProcessor
[
MiniCPMOProcessingInfo
]):
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
return
MiniCPMOMultiModalDataParser
(
...
...
@@ -369,21 +392,10 @@ class MiniCPMOMultiModalProcessor(
def
_get_mm_fields_config
(
self
,
hf_inputs
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
audio_num_slices
=
hf_inputs
.
get
(
"audio_num_slices"
,
torch
.
empty
(
0
))
return
dict
(
**
super
().
_get_mm_fields_config
(
hf_inputs
,
hf_processor_mm_kwargs
),
audio_features
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
),
audio_feature_lens
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
),
audio_num_slices
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_orders_in_mm_data
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"audio"
,
audio_num_slices
))
return
_minicpmo_field_config
(
hf_inputs
)
class
MultiModalProjector
(
nn
.
Module
):
...
...
@@ -406,7 +418,7 @@ class MultiModalProjector(nn.Module):
class
MiniCPMWhisperEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
WhisperConfig
,
layer_idx
:
int
=
None
):
def
__init__
(
self
,
config
:
WhisperConfig
,
layer_idx
:
int
):
super
().
__init__
()
self
.
embed_dim
=
config
.
d_model
self
.
self_attn
=
WHISPER_ATTENTION_CLASSES
[
...
...
vllm/model_executor/models/minicpmv.py
View file @
4da1f667
...
...
@@ -35,6 +35,7 @@ import torch.types
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
...
...
@@ -51,9 +52,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
(
ImageItem
,
ImageSize
,
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
VideoItem
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageItem
,
ImageSize
,
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
VideoItem
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
...
...
@@ -115,93 +117,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageEmbeddingInputs
]
class
MiniCPMVEmbeddingItems
(
ModalityDataItems
[
dict
[
str
,
torch
.
Tensor
],
dict
[
str
,
torch
.
Tensor
]]):
def
__init__
(
self
,
data
:
Dict
,
modality
:
str
)
->
None
:
super
().
__init__
(
data
,
modality
)
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
return
self
.
data
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{}
def
get_count
(
self
)
->
int
:
return
len
(
self
.
data
[
f
"
{
self
.
modality
}
_embeds"
])
def
get
(
self
,
index
:
int
)
->
Dict
[
str
,
torch
.
Tensor
]:
out
=
{}
for
k
,
v
in
self
.
data
.
items
():
out
[
k
]
=
v
[
index
]
return
out
class
MiniCPMVImageEmbeddingItems
(
MiniCPMVEmbeddingItems
):
def
__init__
(
self
,
data
:
Dict
)
->
None
:
super
().
__init__
(
data
,
"image"
)
image_embeds
=
self
.
data
.
get
(
"image_embeds"
,
None
)
image_sizes
=
self
.
data
.
get
(
"image_sizes"
,
None
)
if
image_embeds
is
None
:
raise
ValueError
(
"In correct type of image_embeds"
,
"Got type: None"
)
if
not
isinstance
(
image_embeds
[
0
],
torch
.
Tensor
):
raise
ValueError
(
"In correct type of image_embeds"
,
f
"Got type:
{
type
(
image_embeds
[
0
])
}
"
)
if
image_sizes
is
None
:
raise
ValueError
(
"In correct type of image_sizes"
,
"Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`"
)
if
len
(
image_embeds
[
0
].
shape
)
==
2
:
image_embeds
=
[
image_embeds
]
image_sizes
=
[
image_sizes
]
self
.
data
[
"image_embeds"
]
=
image_embeds
self
.
data
[
"image_sizes"
]
=
image_sizes
def
get_image_size
(
self
,
index
:
int
)
->
ImageSize
:
image_size
=
self
.
data
[
"image_sizes"
][
index
]
return
ImageSize
(
width
=
image_size
[
0
],
height
=
image_size
[
1
])
class
MiniCPMVVideoEmbeddingItems
(
MiniCPMVEmbeddingItems
):
def
__init__
(
self
,
data
:
Dict
)
->
None
:
super
().
__init__
(
data
,
"video"
)
video_embeds
=
self
.
data
.
get
(
"video_embeds"
,
None
)
image_sizes
=
self
.
data
.
get
(
"image_sizes"
,
None
)
num_frames
=
self
.
data
.
get
(
"num_frames"
,
None
)
if
video_embeds
is
None
:
raise
ValueError
(
"In correct type of video_embeds"
,
"Got type: None"
)
if
not
isinstance
(
video_embeds
[
0
],
torch
.
Tensor
):
raise
ValueError
(
"In correct type of video_embeds"
,
f
"Got type:
{
type
(
video_embeds
[
0
])
}
"
)
if
image_sizes
is
None
:
raise
ValueError
(
"In correct type of image_sizes"
,
"Got type: None."
"If you're using `image_size_list`, "
"please rename it to `image_sizes`"
)
if
num_frames
is
None
:
raise
ValueError
(
"In correct type of numframes"
,
"Got type: None"
)
if
len
(
video_embeds
[
0
].
shape
)
==
2
:
video_embeds
=
[
video_embeds
]
image_sizes
=
[
image_sizes
]
num_frames
=
[
num_frames
]
self
.
data
[
"video_embeds"
]
=
video_embeds
self
.
data
[
"image_sizes"
]
=
image_sizes
self
.
data
[
"num_frames"
]
=
num_frames
def
get_frame_size
(
self
,
index
:
int
)
->
ImageSize
:
frame_size
=
self
.
data
[
"image_sizes"
][
index
]
return
ImageSize
(
width
=
frame_size
[
0
],
height
=
frame_size
[
1
])
def
get_num_frames
(
self
,
index
:
int
)
->
int
:
return
self
.
data
[
"num_frames"
][
index
]
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
...
...
@@ -311,6 +226,71 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return
tuple
(
int
(
x
)
for
x
in
version_str
.
split
(
"."
))
def
_minicpmv_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
image_num_slices
=
hf_inputs
.
get
(
"image_num_slices"
,
torch
.
empty
(
0
))
video_num_slices
=
hf_inputs
.
get
(
"video_num_slices"
,
torch
.
empty
(
0
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
image_num_slices
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
video_pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_num_slices
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
class
MiniCPMVImageEmbeddingItems
(
DictEmbeddingItems
):
def
__init__
(
self
,
data
:
Mapping
[
str
,
torch
.
Tensor
],
fields_config
:
Mapping
[
str
,
MultiModalFieldConfig
],
)
->
None
:
super
().
__init__
(
data
,
modality
=
"image"
,
fields_config
=
fields_config
,
required_fields
=
{
"image_embeds"
,
"image_sizes"
},
)
def
get_image_size
(
self
,
index
:
int
)
->
ImageSize
:
image_size
=
self
.
get
(
index
)[
"image_sizes"
].
tolist
()
return
ImageSize
(
width
=
image_size
[
0
],
height
=
image_size
[
1
])
class
MiniCPMVVideoEmbeddingItems
(
DictEmbeddingItems
):
def
__init__
(
self
,
data
:
Mapping
[
str
,
torch
.
Tensor
],
fields_config
:
Mapping
[
str
,
MultiModalFieldConfig
],
)
->
None
:
super
().
__init__
(
data
,
modality
=
"video"
,
fields_config
=
fields_config
,
required_fields
=
{
"video_embeds"
,
"video_image_sizes"
},
)
def
get_frame_size
(
self
,
index
:
int
)
->
ImageSize
:
frame_size
=
self
.
get
(
index
)[
"video_image_sizes"
].
tolist
()
return
ImageSize
(
width
=
frame_size
[
0
],
height
=
frame_size
[
1
])
def
get_num_frames
(
self
,
index
:
int
)
->
int
:
return
len
(
self
.
get
(
index
)[
"video_image_sizes"
])
class
MiniCPMVMultiModalDataParser
(
MultiModalDataParser
):
def
_parse_image_data
(
...
...
@@ -318,7 +298,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
MiniCPMVImageEmbeddingItems
(
data
)
return
MiniCPMVImageEmbeddingItems
(
data
,
fields_config
=
_minicpmv_field_config
(
data
),
)
return
super
().
_parse_image_data
(
data
)
def
_parse_video_data
(
...
...
@@ -326,7 +310,11 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
MiniCPMVVideoEmbeddingItems
(
data
)
return
MiniCPMVVideoEmbeddingItems
(
data
,
fields_config
=
_minicpmv_field_config
(
data
),
)
return
super
().
_parse_video_data
(
data
)
...
...
@@ -392,10 +380,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return
self
.
get_max_video_frame_tokens
(
)
*
self
.
get_num_frames_with_most_features
(
seq_len
)
def
get_max_audio_tokens
(
self
)
->
int
:
return
self
.
get_max_audio_tokens_per_chunk
(
)
*
self
.
get_max_audio_chunks_with_most_features
()
def
get_slice_query_num
(
self
)
->
int
:
hf_config
=
self
.
get_hf_config
()
query_num
=
getattr
(
hf_config
,
"query_num"
,
64
)
...
...
@@ -476,8 +460,12 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return
ImageSize
(
width
=
image_size
,
height
=
image_size
*
num_slices
)
class
MiniCPMVDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MiniCPMVProcessingInfo
]
):
_I
=
TypeVar
(
"_I"
,
bound
=
MiniCPMVProcessingInfo
,
default
=
MiniCPMVProcessingInfo
)
class
MiniCPMVDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
def
get_dummy_processor_inputs
(
self
,
...
...
@@ -514,8 +502,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[MiniCPMVProcessingInfo]
mm_data
=
mm_data
)
class
MiniCPMVMultiModalProcessor
(
BaseMultiModalProcessor
[
MiniCPMVProcessingInfo
]):
class
MiniCPMVMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
return
MiniCPMVMultiModalDataParser
()
...
...
@@ -675,7 +662,7 @@ class MiniCPMVMultiModalProcessor(
self
.
info
.
get_video_max_slice_num
()
)
*
inputs
[
modality
][
"num_frames"
][
index
]
else
:
raise
ValueError
(
f
"Un
E
xpected modality:
{
modality
}
"
)
raise
ValueError
(
f
"Un
e
xpected modality:
{
modality
}
"
)
def
check_mm_inputs
(
self
,
inputs
:
Dict
[
str
,
object
],
matches
:
List
[
str
])
->
None
:
...
...
@@ -700,7 +687,7 @@ class MiniCPMVMultiModalProcessor(
inputs
[
"video"
][
"video_image_sizes"
][
index
],
inputs
[
"video"
][
"num_frames"
][
index
])
else
:
raise
ValueError
(
f
"Un
E
xpected modality:
{
modality
}
"
)
raise
ValueError
(
f
"Un
e
xpected modality:
{
modality
}
"
)
def
call_base_hf_processor
(
self
,
...
...
@@ -742,6 +729,14 @@ class MiniCPMVMultiModalProcessor(
}
}
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
...
...
@@ -770,28 +765,10 @@ class MiniCPMVMultiModalProcessor(
def
_get_mm_fields_config
(
self
,
hf_inputs
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_num_slices
=
hf_inputs
.
get
(
"image_num_slices"
,
torch
.
empty
(
0
))
video_num_slices
=
hf_inputs
.
get
(
"video_num_slices"
,
torch
.
empty
(
0
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
image_num_slices
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_slices
),
video_pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_num_slices
),
video_num_slices
=
MultiModalFieldConfig
.
batched
(
"video"
))
return
_minicpmv_field_config
(
hf_inputs
)
def
apply
(
self
,
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
4da1f667
...
...
@@ -243,16 +243,6 @@ class Qwen2AudioMultiModalProcessor(
)
]
def
_always_apply_prompt_replacements
(
self
)
->
bool
:
# Qwen2-Audio processor will start inserting placeholder tokens
# in an upcoming release:
# https://github.com/huggingface/transformers/pull/35534
# NOTE: `_find_placeholders_by_modality` may incorrectly think that HF
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return
not
hasattr
(
self
.
info
.
get_hf_processor
(),
"audio_token"
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen2AudioMultiModalProcessor
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
4da1f667
...
...
@@ -58,8 +58,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
from
vllm.multimodal.parse
import
(
ImageSize
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageSize
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
...
...
@@ -657,49 +658,25 @@ class Qwen2VisionTransformer(nn.Module):
return
loaded_params
class
Qwen2VLEmbeddingItems
(
ModalityDataItems
[
dict
[
str
,
torch
.
Tensor
],
dict
[
str
,
torch
.
Tensor
]]):
def
__init__
(
self
,
data
:
dict
,
modality
:
str
)
->
None
:
super
().
__init__
(
data
,
modality
)
grid_thw
=
data
[
f
"
{
modality
}
_grid_thw"
]
slice_idxs
=
[
0
]
+
grid_thw
.
prod
(
-
1
).
cumsum_
(
0
).
tolist
()
self
.
_slices
=
[
slice
(
slice_idxs
[
i
],
slice_idxs
[
i
+
1
])
for
i
in
range
(
len
(
grid_thw
))
]
def
get_count
(
self
)
->
int
:
return
len
(
self
.
data
[
f
"
{
self
.
modality
}
_grid_thw"
])
def
get
(
self
,
index
:
int
)
->
dict
[
str
,
torch
.
Tensor
]:
out
=
{}
for
k
,
v
in
self
.
data
.
items
():
if
v
!=
f
"
{
self
.
modality
}
_grid_thw"
:
v
=
v
[
self
.
_slices
[
index
]]
out
[
k
]
=
v
return
out
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{}
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
self
.
data
class
Qwen2VLImageEmbeddingItems
(
Qwen2VLEmbeddingItems
):
def
__init__
(
self
,
data
:
dict
)
->
None
:
super
().
__init__
(
data
,
"image"
)
def
_qwen2vl_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
image_grid_thw
=
hf_inputs
.
get
(
"image_grid_thw"
,
torch
.
empty
((
0
,
3
)))
image_grid_sizes
=
image_grid_thw
.
prod
(
-
1
)
class
Qwen2VLVideoEmbeddingItems
(
Qwen2VLEmbeddingItems
):
video_grid_thw
=
hf_inputs
.
get
(
"video_grid_thw"
,
torch
.
empty
((
0
,
3
)))
video_grid_sizes
=
video_grid_thw
.
prod
(
-
1
)
def
__init__
(
self
,
data
:
dict
)
->
None
:
super
().
__init__
(
data
,
"video"
)
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_grid_thw
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values_videos
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_grid_thw
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
class
Qwen2VLMultiModalDataParser
(
MultiModalDataParser
):
...
...
@@ -709,7 +686,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
Qwen2VLEmbeddingItems
(
data
,
modality
=
"image"
)
return
DictEmbeddingItems
(
data
,
modality
=
"image"
,
fields_config
=
_qwen2vl_field_config
(
data
),
required_fields
=
{
"image_embeds"
,
"image_grid_thw"
},
)
return
super
().
_parse_image_data
(
data
)
...
...
@@ -718,7 +700,12 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
if
isinstance
(
data
,
dict
):
return
Qwen2VLEmbeddingItems
(
data
,
modality
=
"video"
)
return
DictEmbeddingItems
(
data
,
modality
=
"video"
,
fields_config
=
_qwen2vl_field_config
(
data
),
required_fields
=
{
"video_embeds"
,
"video_grid_thw"
},
)
return
super
().
_parse_video_data
(
data
)
...
...
@@ -999,24 +986,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_grid_thw
=
hf_inputs
.
get
(
"image_grid_thw"
,
torch
.
empty
((
0
,
3
)))
image_grid_sizes
=
image_grid_thw
.
prod
(
-
1
)
video_grid_thw
=
hf_inputs
.
get
(
"video_grid_thw"
,
torch
.
empty
((
0
,
3
)))
video_grid_sizes
=
video_grid_thw
.
prod
(
-
1
)
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_grid_sizes
),
image_grid_thw
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values_videos
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_embeds
=
MultiModalFieldConfig
.
flat_from_sizes
(
"video"
,
video_grid_sizes
),
video_grid_thw
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
return
_qwen2vl_field_config
(
hf_inputs
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen2VLMultiModalProcessor
,
...
...
vllm/model_executor/models/qwen_vl.py
View file @
4da1f667
...
...
@@ -520,10 +520,7 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
return
_get_tokenizer_without_image_pad
(
tokenizer
)
def
get_hf_processor
(
self
)
->
QwenVLProcessor
:
tokenizer
=
self
.
ctx
.
tokenizer
assert
isinstance
(
tokenizer
,
PreTrainedTokenizer
)
return
QwenVLProcessor
(
self
.
get_hf_config
(),
tokenizer
)
return
QwenVLProcessor
(
self
.
get_hf_config
(),
self
.
get_tokenizer
())
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
...
...
@@ -605,6 +602,14 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs
=
mm_kwargs
,
)
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
...
...
vllm/multimodal/parse.py
View file @
4da1f667
...
...
@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import
numpy
as
np
import
torch
from
PIL.Image
import
Image
from
transformers
import
BatchFeature
from
typing_extensions
import
TypeAlias
,
TypeGuard
,
assert_never
from
vllm.utils
import
is_list_of
from
.audio
import
resample_audio
from
.inputs
import
(
AudioItem
,
HfAudioItem
,
HfImageItem
,
HfVideoItem
,
ImageItem
,
ModalityData
,
MultiModalDataDict
,
VideoItem
)
ImageItem
,
ModalityData
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
_T
=
TypeVar
(
"_T"
)
_I
=
TypeVar
(
"_I"
)
...
...
@@ -111,6 +113,60 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return
len
(
self
.
get
(
item_idx
))
class
DictEmbeddingItems
(
ModalityDataItems
[
Mapping
[
str
,
torch
.
Tensor
],
Mapping
[
str
,
torch
.
Tensor
]]):
"""
Base class for data items that are expressed as a dictionary of tensors.
Usually, the dictionary keys correspond to the outputs of HF processor.
"""
def
__init__
(
self
,
data
:
Mapping
[
str
,
torch
.
Tensor
],
modality
:
str
,
fields_config
:
Mapping
[
str
,
MultiModalFieldConfig
],
required_fields
:
set
[
str
],
)
->
None
:
super
().
__init__
(
data
,
modality
)
missing_required_fields
=
required_fields
-
fields_config
.
keys
()
if
missing_required_fields
:
fields
=
set
(
fields_config
.
keys
())
msg
=
f
"
{
required_fields
=
}
should be a subset of
{
fields
=
}
"
raise
ValueError
(
msg
)
missing_required_data_keys
=
required_fields
-
data
.
keys
()
if
missing_required_data_keys
:
data_keys
=
set
(
data
.
keys
())
msg
=
(
f
"The data should contain the fields:
{
required_fields
}
, "
f
"but only found the following keys:
{
data_keys
}
"
)
raise
ValueError
(
msg
)
self
.
fields_config
=
fields_config
self
.
required_fields
=
required_fields
self
.
_kwargs
=
MultiModalKwargs
.
from_hf_inputs
(
BatchFeature
(
dict
(
data
)),
fields_config
,
)
def
get_count
(
self
)
->
int
:
return
self
.
_kwargs
.
get_item_count
(
self
.
modality
)
def
get
(
self
,
index
:
int
)
->
Mapping
[
str
,
torch
.
Tensor
]:
return
{
k
:
v
.
data
for
k
,
v
in
self
.
_kwargs
.
get_item
(
self
.
modality
,
index
).
items
()
}
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{}
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
self
.
data
class
AudioProcessorItems
(
ProcessorBatchItems
[
HfAudioItem
]):
def
__init__
(
self
,
data
:
Sequence
[
HfAudioItem
])
->
None
:
...
...
vllm/multimodal/processing.py
View file @
4da1f667
...
...
@@ -23,7 +23,8 @@ from .hasher import MultiModalHasher
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
.parse
import
(
DictEmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
MultiModalDataParser
)
if
TYPE_CHECKING
:
from
.profiling
import
BaseDummyInputsBuilder
...
...
@@ -830,15 +831,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
,
)
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
"""
Return whether the HF processor applies prompt replacements.
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
are passed.
"""
return
not
any
(
isinstance
(
items
,
(
EmbeddingItems
,
DictEmbeddingItems
))
for
items
in
mm_items
.
values
())
def
_apply_hf_processor_text_mm
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
]:
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
"""
Apply the HF processor on the prompt text and multi-modal data
together.
In addition, return whether prompt replacements have been applied.
"""
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
...
...
@@ -856,7 +876,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
),
)
return
prompt_ids
,
mm_kwargs
is_repl_applied
=
self
.
_hf_processor_applies_repl
(
prompt_text
=
prompt_text
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
prompt_ids
,
mm_kwargs
,
is_repl_applied
def
_apply_hf_processor_text_only
(
self
,
prompt_text
:
str
)
->
list
[
int
]:
"""
...
...
@@ -866,7 +892,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
correspond to each other, we create dummy multi-modal items
to go along with the text.
"""
prompt_ids
,
_
=
self
.
_apply_hf_processor_text_mm
(
prompt_ids
,
_
,
_
=
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
prompt_text
,
mm_items
=
MultiModalDataItems
({}),
hf_processor_mm_kwargs
=
{},
...
...
@@ -908,7 +934,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_counts
,
)
_
,
mm_kwargs
=
self
.
_apply_hf_processor_text_mm
(
_
,
mm_kwargs
,
_
=
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
dummy_inputs
.
prompt_text
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
...
...
@@ -923,13 +949,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
*
,
enable_hf_prompt_replacement
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
]:
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
"""
Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt replacements have been applied
(for most HF processors, this should be :code:`True`).
Note:
If :code:`enable_hf_prompt_replacement=False`, the prompt should
correspond to the multi-modal items.
If :code:`enable_hf_prompt_replacement=False`, we use HF processor
to perform prompt replacement if available; HF processor requires
that the prompt corresponds to multi-modal items.
"""
if
isinstance
(
prompt
,
str
):
if
enable_hf_prompt_replacement
:
...
...
@@ -943,19 +973,19 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else
:
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt
)
mm_
missing_
kwargs
=
self
.
_apply_hf_processor_mm_only
(
mm_kwargs
=
self
.
_apply_hf_processor_mm_only
(
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
prompt_ids
,
mm_
missing_kwargs
return
prompt_ids
,
mm_
kwargs
,
False
def
_cached_apply_hf_processor
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
]:
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
...
...
@@ -992,8 +1022,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we need to pass `enable_hf_prompt_replacement=False`
prompt_ids
,
mm_missing_kwargs
=
self
.
_apply_hf_processor_main
(
# so we can't apply prompt replacements until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids
,
mm_missing_kwargs
,
is_repl_applied
,
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
mm_items
=
mm_missing_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
...
...
@@ -1036,7 +1071,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
return
prompt_ids
,
mm_kwargs
return
prompt_ids
,
mm_kwargs
,
is_repl_applied
def
_bind_and_group_repls
(
self
,
...
...
@@ -1047,18 +1082,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_repls
)
return
dict
(
full_groupby_modality
(
it
))
def
_always_apply_prompt_replacements
(
self
)
->
bool
:
"""
A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via
:meth:`_find_placeholders_by_modality`.
This is useful in cases where :meth:`_find_placeholders_by_modality`
cannot be reliably used to detect whether HF has performed processing.
"""
return
False
def
_apply_prompt_replacements
(
self
,
token_ids
:
list
[
int
],
...
...
@@ -1155,29 +1178,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self
,
mm_placeholders
:
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]],
mm_item_counts
:
Mapping
[
str
,
int
],
*
,
allow_missing
:
bool
=
False
,
)
->
Mapping
[
str
,
int
]:
missing_repl_counts
=
dict
[
str
,
int
]()
)
->
None
:
for
modality
,
item_count
in
mm_item_counts
.
items
():
placeholders
=
mm_placeholders
.
get
(
modality
,
[])
if
len
(
placeholders
)
!=
item_count
and
not
allow_missing
:
if
len
(
placeholders
)
!=
item_count
:
raise
RuntimeError
(
f
"Expected there to be
{
item_count
}
prompt replacements "
f
"corresponding to
{
item_count
}
{
modality
}
items, but
only
"
f
"found
{
len
(
placeholders
)
}
prompt replacements!
Either
"
"the prompt text has missing/incorrect tokens for "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"
instead
found
{
len
(
placeholders
)
}
prompt replacements! "
"
Either
the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`)."
)
missing_repl_counts
[
modality
]
=
item_count
-
len
(
placeholders
)
return
missing_repl_counts
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
...
...
@@ -1217,7 +1232,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else
:
mm_hashes
=
None
prompt_ids
,
mm_kwargs
=
self
.
_cached_apply_hf_processor
(
(
prompt_ids
,
mm_kwargs
,
is_repl_applied
,
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
mm_items
,
hf_processor_mm_kwargs
,
...
...
@@ -1233,51 +1252,26 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
hf_mm_placeholders
=
self
.
_find_mm_placeholders
(
if
is_repl_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
mm_prompt_repls
,
prompt_ids
,
mm_item_counts
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
if
self
.
_always_apply_prompt_replacements
():
mm_missing_repl_counts
=
mm_item_counts
mm_missing_repls
=
dict
(
mm_prompt_repls
)
else
:
mm_missing_repl_counts
=
self
.
_validate_mm_placeholders
(
hf_mm_placeholders
,
mm_item_counts
,
allow_missing
=
True
,
)
mm_missing_repls
=
dict
[
str
,
list
[
BoundPromptReplacement
]]()
for
modality
,
missing_repl_count
in
mm_missing_repl_counts
.
items
():
if
missing_repl_count
==
0
:
mm_missing_repls
[
modality
]
=
[]
elif
missing_repl_count
==
mm_item_counts
.
get
(
modality
,
0
):
mm_missing_repls
[
modality
]
=
mm_prompt_repls
[
modality
]
else
:
raise
ValueError
(
"Partial prompt replacement within "
f
"
{
modality
=
}
is not supported"
)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if
all
(
len
(
repls
)
==
0
for
repls
in
mm_missing_repls
.
values
()):
tokenizer
=
self
.
info
.
get_tokenizer
()
prompt
=
decode_tokens
(
tokenizer
,
prompt_ids
)
mm_placeholders
=
hf_mm_placeholders
else
:
(
prompt_ids
,
prompt
,
missing_
mm_placeholders
,
mm_placeholders
,
)
=
self
.
_apply_prompt_replacements
(
prompt_ids
,
mm_
missing
_repls
,
mm_
missing_repl
_counts
,
mm_
prompt
_repls
,
mm_
item
_counts
,
)
mm_placeholders
=
{
**
hf_mm_placeholders
,
**
missing_mm_placeholders
}
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
mm_placeholder_ranges
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment