Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1579b22
Unverified
Commit
f1579b22
authored
Feb 28, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 27, 2025
Browse files
[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
78648758
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
363 additions
and
245 deletions
+363
-245
vllm/model_executor/models/nvlm_d.py
vllm/model_executor/models/nvlm_d.py
+7
-6
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+10
-11
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+9
-15
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+6
-6
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+9
-7
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+8
-7
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+6
-5
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+5
-5
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+303
-183
No files found.
vllm/model_executor/models/nvlm_d.py
View file @
f1579b22
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
from
typing
import
Mapping
,
Optional
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
Prompt
Replacement
Details
)
Prompt
Update
Details
)
from
vllm.multimodal.profiling
import
ProcessorInputs
from
vllm.multimodal.profiling
import
ProcessorInputs
from
.intern_vit
import
InternVisionModel
from
.intern_vit
import
InternVisionModel
...
@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
...
@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
if
"image_num_patches"
in
out_mm_kwargs
:
if
"image_num_patches"
in
out_mm_kwargs
:
...
@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
...
@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if
num_patches
is
not
None
:
if
num_patches
is
not
None
:
assert
isinstance
(
num_patches
,
int
)
assert
isinstance
(
num_patches
,
int
)
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
hf_processor
.
get_image_repl_full
(
feature_size
,
full
=
hf_processor
.
get_image_repl_full
(
feature_size
,
num_patches
)
+
"
\n
"
,
num_patches
)
+
"
\n
"
,
features
=
hf_processor
.
get_image_repl_features
(
features
=
hf_processor
.
get_image_repl_features
(
...
...
vllm/model_executor/models/phi3v.py
View file @
f1579b22
...
@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
...
@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
BoundPromptUpdate
,
BoundPromptReplacement
,
PlaceholderFeaturesInfo
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
Prompt
Replacement
Details
)
Prompt
Update
Details
)
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
...
@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_tokens
:
list
[
str
]
=
hf_processor
.
img_tokens
# type: ignore
image_tokens
:
list
[
str
]
=
hf_processor
.
img_tokens
# type: ignore
...
@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
...
@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens
=
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
image_tokens
=
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
image_tokens
+
[
bos_token_id
],
full
=
image_tokens
+
[
bos_token_id
],
features
=
image_tokens
,
features
=
image_tokens
,
)
)
...
@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
...
@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
)
for
image_token
in
image_tokens
[:
num_images
]
)
for
image_token
in
image_tokens
[:
num_images
]
]
]
def
_apply_prompt_
replacement
s
(
def
_apply_prompt_
update
s
(
self
,
self
,
token_ids
:
list
[
int
],
token_ids
:
list
[
int
],
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
token_ids
,
text
,
placeholders
=
super
().
_apply_prompt_
replacement
s
(
token_ids
,
text
,
placeholders
=
super
().
_apply_prompt_
update
s
(
token_ids
=
token_ids
,
token_ids
=
token_ids
,
mm_prompt_
repl
s
=
mm_prompt_
repl
s
,
mm_prompt_
update
s
=
mm_prompt_
update
s
,
mm_item_counts
=
mm_item_counts
,
mm_item_counts
=
mm_item_counts
,
)
)
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
f1579b22
...
@@ -15,7 +15,8 @@
...
@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
typing
import
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...
@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs
,
MultiModalKwargs
)
MultiModalInputs
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
Prompt
Replacement
)
BaseProcessingInfo
,
Prompt
Update
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
PoolingSequenceGroupOutput
)
PoolingSequenceGroupOutput
)
...
@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
...
@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
pass
return
{
"image"
:
0
}
class
PrithviGeoSpatialMAEInputBuilder
(
class
PrithviGeoSpatialMAEInputBuilder
(
...
@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
pass
return
[]
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
pass
def
apply
(
def
apply
(
self
,
self
,
...
@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
):
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
):
""" Prithvi Masked Autoencoder"""
""" Prithvi Masked Autoencoder"""
def
_instantiate_model
(
self
,
config
:
dict
)
->
nn
.
Module
|
None
:
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]
:
# We might be able/need to support different tasks with this same model
# We might be able/need to support different tasks with this same model
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
...
@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
...
@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE."
)
"by PrithviGeospatialMAE."
)
def
_parse_and_validate_multimodal_data
(
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
self
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
f1579b22
...
@@ -21,9 +21,9 @@
...
@@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
from
typing
import
Any
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
...
@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser
)
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
Prompt
Replacement
Details
)
Prompt
Update
,
PromptUpdate
Details
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
),
feature_attention_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
vocab
=
tokenizer
.
get_vocab
()
...
@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens
=
[
audio_token_id
]
*
num_features
audio_tokens
=
[
audio_token_id
]
*
num_features
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
[
audio_bos_id
]
+
audio_tokens
+
[
audio_eos_id
],
full
=
[
audio_bos_id
]
+
audio_tokens
+
[
audio_eos_id
],
features
=
audio_tokens
,
features
=
audio_tokens
,
)
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
f1579b22
...
@@ -23,9 +23,10 @@
...
@@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
,
partial
from
functools
import
cached_property
,
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Se
t
,
from
typing
import
(
Any
,
Callable
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDic
t
,
Tuple
,
Type
,
TypedDict
,
Union
)
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
...
@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems
,
MultiModalDataItems
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
...
@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self
,
self
,
in_features
:
int
,
in_features
:
int
,
hidden_features
:
int
,
hidden_features
:
int
,
act_layer
:
T
ype
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
t
ype
[
nn
.
Module
]
=
QuickGELU
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
):
):
...
@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim
:
int
,
dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
,
mlp_ratio
:
float
,
act_layer
:
T
ype
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
t
ype
[
nn
.
Module
]
=
QuickGELU
,
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
...
@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
...
@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self
.
info
.
_get_image_processor_kwargs
(
**
mm_kwargs
),
self
.
info
.
_get_image_processor_kwargs
(
**
mm_kwargs
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_processor
=
self
.
info
.
get_image_processor
(
image_processor
=
self
.
info
.
get_image_processor
(
**
hf_processor_mm_kwargs
)
**
hf_processor_mm_kwargs
)
...
...
vllm/model_executor/models/qwen_vl.py
View file @
f1579b22
...
@@ -9,9 +9,10 @@ import copy
...
@@ -9,9 +9,10 @@ import copy
import
math
import
math
import
re
import
re
import
unicodedata
import
unicodedata
from
collections.abc
import
Collection
,
Mapping
,
Sequence
from
collections.abc
import
Set
as
AbstractSet
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
from
typing
import
(
AbstractSet
,
Callable
,
Collection
,
List
,
Literal
,
Mapping
,
from
typing
import
Callable
,
List
,
Literal
,
Optional
,
TypedDict
,
Union
Optional
,
TypedDict
,
Union
)
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
...
@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
Prompt
Replacement
Details
)
Prompt
Update
,
PromptUpdate
Details
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
...
@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs
=
mm_kwargs
,
mm_kwargs
=
mm_kwargs
,
)
)
def
_hf_processor_applies_
repl
(
def
_hf_processor_applies_
updates
(
self
,
self
,
prompt_text
:
str
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
...
@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
...
@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
special_tokens
:
dict
[
str
,
special_tokens
:
dict
[
str
,
int
]
=
tokenizer
.
special_tokens
# type: ignore
int
]
=
tokenizer
.
special_tokens
# type: ignore
...
@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
...
@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
[
img_start_id
,
img_end_id
],
target
=
[
img_start_id
,
img_end_id
],
replacement
=
Prompt
Replacement
Details
(
replacement
=
Prompt
Update
Details
(
full
=
[
img_start_id
]
+
image_tokens
+
[
img_end_id
],
full
=
[
img_start_id
]
+
image_tokens
+
[
img_end_id
],
features
=
image_tokens
,
features
=
image_tokens
,
),
),
...
...
vllm/model_executor/models/ultravox.py
View file @
f1579b22
...
@@ -3,9 +3,9 @@
...
@@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
"""PyTorch Ultravox model."""
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
TypedDict
,
Union
)
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
...
@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
...
@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors
)
NestedTensors
)
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
...
@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
...
@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
vocab
=
tokenizer
.
get_vocab
()
...
...
vllm/model_executor/models/whisper.py
View file @
f1579b22
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
math
import
math
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
Union
)
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
...
@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser
)
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptReplacement
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
...
@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
...
@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
))
return
dict
(
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
))
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
num_tokens
=
self
.
info
.
get_max_audio_tokens
()
num_tokens
=
self
.
info
.
get_max_audio_tokens
()
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
...
vllm/multimodal/processing.py
View file @
f1579b22
...
@@ -6,11 +6,14 @@ from collections import defaultdict
...
@@ -6,11 +6,14 @@ from collections import defaultdict
from
collections.abc
import
(
Callable
,
Generator
,
ItemsView
,
Iterable
,
Mapping
,
from
collections.abc
import
(
Callable
,
Generator
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
)
Sequence
)
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
functools
import
lru_cache
from
functools
import
lru_cache
from
itertools
import
groupby
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
)
TypeVar
,
Union
,
cast
)
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
...
@@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]
...
@@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]
@
dataclass
@
dataclass
class
Prompt
Replacement
Details
:
class
Prompt
Update
Details
:
"""Details about the
replacement
token sequence or text."""
"""Details about the token sequence or text
that are part of the update
."""
full
:
PromptSeq
full
:
PromptSeq
"""The full
replacem
ent."""
"""The full
cont
ent."""
features
:
PromptSeq
features
:
PromptSeq
"""
"""
The part of the
replacem
ent that corresponds to feature placeholders;
The part of the
cont
ent that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
this will be replaced by the output of the vision encoder during model
inference.
inference.
"""
"""
@
staticmethod
@
staticmethod
def
from_seq
(
seq
:
PromptSeq
)
->
"Prompt
Replacement
Details"
:
def
from_seq
(
seq
:
PromptSeq
)
->
"Prompt
Update
Details"
:
return
Prompt
Replacement
Details
(
full
=
seq
,
features
=
seq
)
return
Prompt
Update
Details
(
full
=
seq
,
features
=
seq
)
Prompt
Repl
=
Union
[
PromptSeq
,
Prompt
Replacement
Details
]
Prompt
UpdateInfo
=
Union
[
PromptSeq
,
Prompt
Update
Details
]
"""
"""
The
replacement
token sequence or text.
The token sequence or text
that are part of the update
.
If only part of the
replacem
ent corresponds to feature placeholders, you can
If only part of the
cont
ent corresponds to feature placeholders, you can
use :class:`Prompt
Replacement
Details` to specify which part.
use :class:`Prompt
Update
Details` to specify which part.
"""
"""
PromptUpdateContent
=
Union
[
Callable
[[
int
],
PromptUpdateInfo
],
PromptUpdateInfo
]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
class
UpdateMode
(
str
,
Enum
):
INSERT
=
"insert"
REPLACE
=
"replace"
@
dataclass
class
PromptUpdate
:
"""
Defines how to update a prompt with placeholder tokens.
"""
modality
:
str
"""The modality for which the update is made."""
target
:
PromptSeq
"""The token sequence (or text) to update."""
@
property
@
abstractmethod
def
content
(
self
)
->
PromptUpdateContent
:
"""The placeholder tokens that are part of the update."""
raise
NotImplementedError
@
property
@
abstractmethod
def
mode
(
self
)
->
UpdateMode
:
"""Defines how to update the prompt."""
raise
NotImplementedError
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptUpdate"
:
return
BoundPromptUpdate
(
_origin
=
self
,
tokenizer
=
tokenizer
,
)
@
dataclass
@
dataclass
class
PromptReplacement
:
class
PromptInsertion
(
PromptUpdate
):
"""
Defines how to insert placeholder tokens into a prompt.
Example:
For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target="",
insertion="<image>" * image_feature_size,
)
As above, but insert after the ``<s>`` token:
.. code-block:: python
PromptInsertion(
modality="image",
target="<s>",
insertion="<image>" * image_feature_size,
)
"""
insertion
:
PromptUpdateContent
=
field
(
repr
=
False
)
"""
Given the index of the processed item within :attr:`modality`,
output the token sequence (or text) to insert right after :attr:`target`.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
@
property
def
content
(
self
)
->
PromptUpdateContent
:
return
self
.
insertion
@
property
def
mode
(
self
)
->
UpdateMode
:
return
UpdateMode
.
INSERT
@
dataclass
class
PromptReplacement
(
PromptUpdate
):
"""
"""
Defines how to replace portions of an input prompt with placeholder tokens.
Defines how to replace portions of an input prompt with placeholder tokens.
...
@@ -93,7 +190,7 @@ class PromptReplacement:
...
@@ -93,7 +190,7 @@ class PromptReplacement:
PromptReplacement(
PromptReplacement(
modality="image",
modality="image",
target="<image>",
target="<image>",
replacement=Prompt
Replacement
Details(
replacement=Prompt
Update
Details(
full="".join([
full="".join([
"<image_bos>",
"<image_bos>",
"<image>" * image_feature_size,
"<image>" * image_feature_size,
...
@@ -111,7 +208,7 @@ class PromptReplacement:
...
@@ -111,7 +208,7 @@ class PromptReplacement:
PromptReplacement(
PromptReplacement(
modality="image",
modality="image",
target=[image_token_id],
target=[image_token_id],
replacement=Prompt
Replacement
Details(
replacement=Prompt
Update
Details(
full=([image_bos_id] + [image_token_id] * image_feature_size
full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]),
+ [image_eos_id]),
features=[image_token_id] * image_feature_size,
features=[image_token_id] * image_feature_size,
...
@@ -119,29 +216,22 @@ class PromptReplacement:
...
@@ -119,29 +216,22 @@ class PromptReplacement:
)
)
"""
"""
modality
:
str
replacement
:
PromptUpdateContent
=
field
(
repr
=
False
)
"""The modality for which the replacement is made."""
target
:
PromptSeq
"""The token sequence (or text) to find and replace."""
replacement
:
Union
[
Callable
[[
int
],
PromptRepl
],
PromptRepl
]
=
field
(
repr
=
False
)
"""
"""
Given the index of the processed item within :attr:`modality`,
Given the index of the processed item within :attr:`modality`,
output the
replacement
token sequence (or text).
output the token sequence (or text)
to replace :attr:`target`
.
For convenience, you can directly pass in the
replacement
token sequence
For convenience, you can directly pass in the token sequence
(or text)
(or text)
instead of a function if it does not depend on the input.
instead of a function if it does not depend on the input.
"""
"""
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptReplacement"
:
@
property
return
BoundPromptReplacem
ent
(
def
content
(
self
)
->
PromptUpdateCont
ent
:
tokenizer
=
tokenizer
,
return
self
.
replacement
modality
=
self
.
modality
,
_target
=
self
.
target
,
@
property
_replacement
=
self
.
replacement
,
def
mode
(
self
)
->
UpdateMode
:
)
return
UpdateMode
.
REPLACE
@
lru_cache
(
maxsize
=
2048
)
@
lru_cache
(
maxsize
=
2048
)
...
@@ -232,64 +322,73 @@ class _BoundPromptSequence:
...
@@ -232,64 +322,73 @@ class _BoundPromptSequence:
@
dataclass
@
dataclass
class
_BoundPrompt
ReplacementGroup
:
class
_BoundPrompt
Content
:
full
:
_BoundPromptSequence
full
:
_BoundPromptSequence
features
:
_BoundPromptSequence
features
:
_BoundPromptSequence
@
dataclass
@
dataclass
class
BoundPrompt
Replacement
:
class
BoundPrompt
Update
:
"""
"""
A :class:`Prompt
Replacement
` bound to a tokenizer to automatically
A :class:`Prompt
Update
` bound to a tokenizer to automatically
convert
convert
:attr:`target` and the result of :meth:`get_
replacem
ent` between
:attr:`target` and the result of :meth:`get_
cont
ent` between
token sequence and text representations.
token sequence and text representations.
"""
"""
_origin
:
PromptUpdate
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
modality
:
str
_target
:
PromptSeq
_replacement
:
Union
[
Callable
[[
int
],
PromptRepl
],
PromptRepl
]
=
field
(
repr
=
False
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
self
.
_replacement_cache
=
dict
[
int
,
_BoundPromptReplacementGroup
]()
self
.
_content_cache
=
dict
[
int
,
_BoundPromptContent
]()
@
property
def
modality
(
self
)
->
str
:
return
self
.
_origin
.
modality
@
property
@
property
def
target
(
self
)
->
_BoundPromptSequence
:
def
target
(
self
)
->
_BoundPromptSequence
:
"""The token sequence (or text) to find and replace."""
"""The token sequence (or text) to update."""
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
self
.
_target
)
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
self
.
_origin
.
target
)
def
get_replacement
(
self
,
item_idx
:
int
)
->
_BoundPromptReplacementGroup
:
@
property
def
content
(
self
)
->
PromptUpdateContent
:
"""The placeholder tokens that are part of the update."""
return
self
.
_origin
.
content
@
property
def
mode
(
self
)
->
UpdateMode
:
"""Defines how to update the prompt."""
return
self
.
_origin
.
mode
def
get_content
(
self
,
item_idx
:
int
)
->
_BoundPromptContent
:
"""
"""
Given the index of the processed item within :attr:`modality`,
Given the index of the processed item within :attr:`modality`,
output the
replacement
token sequence (or text).
output the token sequence (or text)
to update
.
"""
"""
replacem
ent
=
self
.
_replacem
ent
cont
ent
=
self
.
cont
ent
if
callable
(
replacem
ent
):
if
callable
(
cont
ent
):
cache_key
=
item_idx
cache_key
=
item_idx
if
cache_key
in
self
.
_
replacem
ent_cache
:
if
cache_key
in
self
.
_
cont
ent_cache
:
return
self
.
_
replacem
ent_cache
[
cache_key
]
return
self
.
_
cont
ent_cache
[
cache_key
]
replacement
=
replacem
ent
(
item_idx
)
content
=
cont
ent
(
item_idx
)
else
:
else
:
cache_key
=
None
cache_key
=
None
if
not
isinstance
(
replacem
ent
,
Prompt
Replacement
Details
):
if
not
isinstance
(
cont
ent
,
Prompt
Update
Details
):
replacem
ent
=
Prompt
Replacement
Details
.
from_seq
(
replacem
ent
)
cont
ent
=
Prompt
Update
Details
.
from_seq
(
cont
ent
)
bound_full
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
bound_full
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
replacem
ent
.
full
)
cont
ent
.
full
)
bound_features
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
bound_features
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
replacement
.
features
)
content
.
features
)
bound_replacement
=
_BoundPromptReplacementGroup
(
bound_content
=
_BoundPromptContent
(
full
=
bound_full
,
full
=
bound_full
,
features
=
bound_features
)
features
=
bound_features
,
)
if
cache_key
is
not
None
:
if
cache_key
is
not
None
:
self
.
_
replacem
ent_cache
[
cache_key
]
=
bound_
replacem
ent
self
.
_
cont
ent_cache
[
cache_key
]
=
bound_
cont
ent
return
bound_
replacem
ent
return
bound_
cont
ent
class
_TokenMatch
(
NamedTuple
):
class
_TokenMatch
(
NamedTuple
):
...
@@ -326,12 +425,12 @@ def iter_token_matches(
...
@@ -326,12 +425,12 @@ def iter_token_matches(
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tMatch
(
ABC
):
class
_Prompt
Targe
tMatch
(
ABC
):
prompt_repl
:
BoundPrompt
Replacement
_origin
:
BoundPrompt
Update
@
property
@
property
def
modality
(
self
)
->
str
:
def
modality
(
self
)
->
str
:
return
self
.
prompt_repl
.
modality
return
self
.
_origin
.
modality
@
property
@
property
@
abstractmethod
@
abstractmethod
...
@@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):
...
@@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tTokenMatch
(
_Prompt
Replacemen
tMatch
):
class
_Prompt
Targe
tTokenMatch
(
_Prompt
Targe
tMatch
):
match
:
_TokenMatch
match
:
_TokenMatch
@
property
@
property
...
@@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):
...
@@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tTextMatch
(
_Prompt
Replacemen
tMatch
):
class
_Prompt
Targe
tTextMatch
(
_Prompt
Targe
tMatch
):
match
:
re
.
Match
[
str
]
match
:
re
.
Match
[
str
]
@
property
@
property
...
@@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:
...
@@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:
def
find_token_matches
(
def
find_token_matches
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
prompt_
repl
s
:
Sequence
[
BoundPrompt
Replacement
],
prompt_
update
s
:
Sequence
[
BoundPrompt
Update
],
)
->
list
[
_PromptReplacementToken
Match
]:
)
->
Sequence
[
_PromptTarget
Match
]:
"""Return each target of :code:`prompt_
repl
s` found in :code:`prompt`."""
"""Return each target of :code:`prompt_
update
s` found in :code:`prompt`."""
return
[
return
[
_PromptReplacementTokenMatch
(
prompt_repl
,
match
)
_PromptTargetTokenMatch
(
update
,
match
)
for
update
in
prompt_updates
for
prompt_repl
in
prompt_repls
for
match
in
iter_token_matches
(
prompt
,
update
.
target
.
token_ids
)
for
match
in
iter_token_matches
(
prompt
,
prompt_repl
.
target
.
token_ids
)
]
]
def
find_text_matches
(
def
find_text_matches
(
prompt
:
str
,
prompt
:
str
,
prompt_
repl
s
:
Sequence
[
BoundPrompt
Replacement
],
prompt_
update
s
:
Sequence
[
BoundPrompt
Update
],
)
->
list
[
_PromptReplacementTex
tMatch
]:
)
->
Sequence
[
_PromptTarge
tMatch
]:
"""Return each target of :code:`prompt_
repl
s` found in :code:`prompt`."""
"""Return each target of :code:`prompt_
update
s` found in :code:`prompt`."""
return
[
return
[
_PromptReplacementTextMatch
(
prompt_repl
,
match
)
_PromptTargetTextMatch
(
update
,
match
)
for
update
in
prompt_updates
for
prompt_repl
in
prompt_repls
for
match
in
re
.
finditer
(
re
.
escape
(
update
.
target
.
text
),
prompt
)
for
match
in
re
.
finditer
(
re
.
escape
(
prompt_repl
.
target
.
text
),
prompt
)
]
]
def
_resolve_matches
(
def
_resolve_matches
(
prompt
:
PromptSeq
,
prompt
:
PromptSeq
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Replacemen
tMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
)
->
list
[
_Prompt
Replacemen
tMatch
]:
)
->
list
[
_Prompt
Targe
tMatch
]:
"""
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
and sort them such that earlier matches take priority over later ones.
"""
"""
matches
=
[
m
for
matches
in
mm_matches
.
values
()
for
m
in
matches
]
matches
=
[
m
for
matches
in
mm_matches
.
values
()
for
m
in
matches
]
seen_matches
:
list
[
Optional
[
_PromptReplacementMatch
]]
=
[
None
seen_matches
:
list
[
Optional
[
_PromptTargetMatch
]]
=
[
None
]
*
len
(
prompt
)
]
*
len
(
prompt
)
for
match
in
matches
:
for
match
in
matches
:
for
idx
in
range
(
match
.
start_idx
,
match
.
end_idx
):
for
idx
in
range
(
match
.
start_idx
,
match
.
end_idx
):
...
@@ -441,74 +537,91 @@ def _resolve_matches(
...
@@ -441,74 +537,91 @@ def _resolve_matches(
return
sorted
(
matches
,
key
=
lambda
x
:
x
.
start_idx
)
return
sorted
(
matches
,
key
=
lambda
x
:
x
.
start_idx
)
def
_
replace
_matches
(
def
_
apply
_matches
(
prompt
:
_S
,
prompt
:
_S
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Replacemen
tMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
_S
]:
)
->
list
[
_S
]:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
out_seqs
=
list
[
_S
]()
out_seqs
=
list
[
Union
[
str
,
list
[
int
]]
]()
prev_end_idx
=
0
prev_end_idx
=
0
next_idx_by_modality
=
defaultdict
[
str
,
int
](
lambda
:
0
)
next_idx_by_modality
=
defaultdict
[
str
,
int
](
lambda
:
0
)
for
match
in
_resolve_matches
(
prompt
,
mm_matches
):
for
(
start_idx
,
end_idx
),
group
in
groupby
(
modality
=
match
.
modality
_resolve_matches
(
prompt
,
mm_matches
),
key
=
lambda
x
:
(
x
.
start_idx
,
x
.
end_idx
),
):
matches
=
tuple
(
group
)
assert
len
(
matches
)
==
1
item_idx
=
next_idx_by_modality
[
modality
]
for
match
in
matches
:
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
modality
=
match
.
modality
continue
start_idx
=
match
.
start_idx
item_idx
=
next_idx_by_modality
[
modality
]
end_idx
=
match
.
end_idx
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
repl_info
=
match
.
prompt_repl
origin
=
match
.
_origin
replacement
=
repl_info
.
get_replacement
(
item_idx
)
content
=
origin
.
get_content
(
item_idx
)
mode
=
origin
.
mode
if
isinstance
(
prompt
,
str
):
if
mode
==
UpdateMode
.
INSERT
:
repl_seq
=
replacement
.
full
.
text
out_seqs
.
append
(
prompt
[
prev_end_idx
:
end_idx
])
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
]
+
repl_seq
)
num_inserts
=
mm_item_counts
.
get
(
modality
,
0
)
else
:
elif
mode
==
UpdateMode
.
REPLACE
:
repl_seq
=
replacement
.
full
.
token_ids
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
])
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
]
+
repl_seq
)
num_inserts
=
1
else
:
assert_never
(
mode
)
prev_end_idx
=
end_idx
for
_
in
range
(
num_inserts
):
next_idx_by_modality
[
modality
]
+=
1
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
if
isinstance
(
prompt
,
str
):
out_seqs
.
append
(
content
.
full
.
text
)
else
:
out_seqs
.
append
(
content
.
full
.
token_ids
)
next_idx_by_modality
[
modality
]
+=
1
prev_end_idx
=
end_idx
out_seqs
.
append
(
prompt
[
prev_end_idx
:])
out_seqs
.
append
(
prompt
[
prev_end_idx
:])
return
out_seqs
return
cast
(
list
[
_S
],
out_seqs
)
def
replace
_token_matches
(
def
apply
_token_matches
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
ReplacementToken
Match
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Target
Match
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
int
]:
)
->
list
[
int
]:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
if
not
mm_matches
:
if
not
mm_matches
:
return
prompt
return
prompt
token_id_seqs
=
_
replace
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
token_id_seqs
=
_
apply
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
return
flatten_2d_lists
(
token_id_seqs
)
return
flatten_2d_lists
(
token_id_seqs
)
def
replace
_text_matches
(
def
apply
_text_matches
(
prompt
:
str
,
prompt
:
str
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
ReplacementTex
tMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
str
:
)
->
str
:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
if
not
mm_matches
:
if
not
mm_matches
:
return
prompt
return
prompt
texts
=
_
replace
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
texts
=
_
apply
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
return
""
.
join
(
texts
)
return
""
.
join
(
texts
)
def
_iter_placeholders
(
def
_iter_placeholders
(
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Iterable
[
PlaceholderFeaturesInfo
]:
)
->
Iterable
[
PlaceholderFeaturesInfo
]:
...
@@ -517,7 +630,7 @@ def _iter_placeholders(
...
@@ -517,7 +630,7 @@ def _iter_placeholders(
Matches are exclusive even when multiple modalities share
Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_
repl
s` takes priority.
appears earlier in `mm_prompt_
update
s` takes priority.
Note that empty matches are ignored.
Note that empty matches are ignored.
"""
"""
...
@@ -528,37 +641,37 @@ def _iter_placeholders(
...
@@ -528,37 +641,37 @@ def _iter_placeholders(
while
start_idx
<
prompt_len
:
while
start_idx
<
prompt_len
:
found
=
False
found
=
False
for
modality
,
modality_
repl
s
in
mm_prompt_
repl
s
.
items
():
for
modality
,
modality_
update
s
in
mm_prompt_
update
s
.
items
():
item_idx
=
item_idx_by_modality
[
modality
]
item_idx
=
item_idx_by_modality
[
modality
]
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
continue
for
repl
_info
in
modality_
repl
s
:
for
update
_info
in
modality_
update
s
:
replacement
=
repl
_info
.
get_
replacem
ent
(
item_idx
)
content
=
update
_info
.
get_
cont
ent
(
item_idx
)
repl
_tokens_full
=
replacem
ent
.
full
.
token_ids
content
_tokens_full
=
cont
ent
.
full
.
token_ids
repl
_len_full
=
len
(
repl
_tokens_full
)
content
_len_full
=
len
(
content
_tokens_full
)
end_idx_full
=
start_idx
+
repl
_len_full
end_idx_full
=
start_idx
+
content
_len_full
if
repl
_len_full
==
0
or
end_idx_full
>
prompt_len
:
if
content
_len_full
==
0
or
end_idx_full
>
prompt_len
:
continue
continue
if
prompt
[
start_idx
:
end_idx_full
]
==
repl
_tokens_full
:
if
prompt
[
start_idx
:
end_idx_full
]
==
content
_tokens_full
:
repl
_tokens_feat
=
replacem
ent
.
features
.
token_ids
content
_tokens_feat
=
cont
ent
.
features
.
token_ids
try
:
try
:
match
=
next
(
match
=
next
(
iter_token_matches
(
repl
_tokens_full
,
iter_token_matches
(
content
_tokens_full
,
repl
_tokens_feat
))
content
_tokens_feat
))
yield
PlaceholderFeaturesInfo
(
yield
PlaceholderFeaturesInfo
(
modality
=
modality
,
modality
=
modality
,
item_idx
=
item_idx
,
item_idx
=
item_idx
,
start_idx
=
start_idx
+
match
.
start_idx
,
start_idx
=
start_idx
+
match
.
start_idx
,
tokens
=
repl
_tokens_feat
,
tokens
=
content
_tokens_feat
,
)
)
except
StopIteration
:
except
StopIteration
:
raise
AssertionError
(
raise
AssertionError
(
f
"
{
repl
_tokens_feat
=
}
should be a "
f
"
{
content
_tokens_feat
=
}
should be a "
f
"subsequence of
{
repl
_tokens_full
=
}
"
)
from
None
f
"subsequence of
{
content
_tokens_full
=
}
"
)
from
None
# Exclude overlapping matches
# Exclude overlapping matches
start_idx
=
end_idx_full
start_idx
=
end_idx_full
...
@@ -574,11 +687,11 @@ def _iter_placeholders(
...
@@ -574,11 +687,11 @@ def _iter_placeholders(
def
find_mm_placeholders
(
def
find_mm_placeholders
(
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
it
=
_iter_placeholders
(
mm_prompt_
repl
s
,
prompt
,
mm_item_counts
)
it
=
_iter_placeholders
(
mm_prompt_
update
s
,
prompt
,
mm_item_counts
)
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
...
@@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
enable_sanity_checks
:
bool
=
True
)
->
None
:
if
get_repls
:
=
getattr
(
self
,
"_get_prompt_replacements"
,
None
):
logger
.
warning_once
(
"`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release."
)
self
.
_get_prompt_updates
=
get_repls
# type: ignore[method-assign]
super
().
__init__
()
super
().
__init__
()
self
.
info
=
info
self
.
info
=
info
...
@@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
Prompt
Replacement
]:
)
->
list
[
Prompt
Update
]:
"""
"""
Given the original multi-modal items for this modality
Given the original multi-modal items for this modality
and HF-processed data, output the
replacement
s to perform.
and HF-processed data, output the
update
s to perform.
Notes:
Notes:
- You should not assume that HF processor always performs prompt
- You should not assume that HF processor always performs prompt
replacement
: in :meth:`_apply_hf_processor_missing`, this method
updates
: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
instead of passing them in the same call.
- The
replacement
information returned by this method is also used
- The
update
information returned by this method is also used
to
to
determine the placeholder token positions for each multi-modal
determine the placeholder token positions for each multi-modal
item.
item.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
_find_mm_placeholders
(
def
_find_mm_placeholders
(
self
,
self
,
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
new_token_ids
:
list
[
int
],
new_token_ids
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
return
find_mm_placeholders
(
mm_prompt_
repl
s
,
new_token_ids
,
return
find_mm_placeholders
(
mm_prompt_
update
s
,
new_token_ids
,
mm_item_counts
)
mm_item_counts
)
def
_get_hf_mm_data
(
def
_get_hf_mm_data
(
...
@@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
,
mm_kwargs
,
)
)
def
_hf_processor_applies_
repl
(
def
_hf_processor_applies_
updates
(
self
,
self
,
prompt_text
:
str
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
)
->
bool
:
"""
"""
Return whether the HF processor applies prompt
replacement
s.
Return whether the HF processor applies prompt
update
s.
For most HF processors, this should be :code:`True` when multi-modal
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
data items are passed, but :code:`False` when multi-modal embeddings
...
@@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Apply the HF processor on the prompt text and multi-modal data
Apply the HF processor on the prompt text and multi-modal data
together.
together.
In addition, return whether prompt
replacement
s have been applied.
In addition, return whether prompt
update
s have been applied.
"""
"""
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
...
@@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
),
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
),
)
)
is_
repl
_applied
=
self
.
_hf_processor_applies_
repl
(
is_
update
_applied
=
self
.
_hf_processor_applies_
updates
(
prompt_text
=
prompt_text
,
prompt_text
=
prompt_text
,
mm_items
=
mm_items
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
)
return
prompt_ids
,
mm_kwargs
,
is_
repl
_applied
return
prompt_ids
,
mm_kwargs
,
is_
update
_applied
def
_apply_hf_processor_text_only
(
self
,
prompt_text
:
str
)
->
list
[
int
]:
def
_apply_hf_processor_text_only
(
self
,
prompt_text
:
str
)
->
list
[
int
]:
"""
"""
...
@@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
*
,
*
,
enable_hf_prompt_
replacement
:
bool
,
enable_hf_prompt_
update
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
"""
"""
Apply the HF processor on the prompt text and multi-modal data.
Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt
replacement
s have been applied
In addition, return whether prompt
update
s have been applied
(for most HF processors, this should be :code:`True`).
(for most HF processors, this should be :code:`True`).
Note:
Note:
If :code:`enable_hf_prompt_
replacement
=False`, we use HF processor
If :code:`enable_hf_prompt_
update
=False`, we use HF processor
to perform prompt
replacement
if available; HF processor requires
to perform prompt
updates
if available; HF processor requires
that the prompt corresponds to multi-modal items.
that the prompt corresponds to multi-modal items.
"""
"""
if
isinstance
(
prompt
,
str
):
if
isinstance
(
prompt
,
str
):
if
enable_hf_prompt_
replacement
:
if
enable_hf_prompt_
update
:
return
self
.
_apply_hf_processor_text_mm
(
return
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
prompt
,
prompt_text
=
prompt
,
mm_items
=
mm_items
,
mm_items
=
mm_items
,
...
@@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_
replacement
=
True
,
enable_hf_prompt_
update
=
True
,
)
)
mm_maybe_cached_kw_items
=
{
mm_maybe_cached_kw_items
=
{
...
@@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt
replacement
s until the new multimodal
# so we can't apply prompt
update
s until the new multimodal
# items are combined with the cached multimodal items
# items are combined with the cached multimodal items
(
(
prompt_ids
,
prompt_ids
,
mm_missing_kwargs
,
mm_missing_kwargs
,
is_
repl
_applied
,
is_
update
_applied
,
)
=
self
.
_apply_hf_processor_main
(
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_missing_data_items
,
mm_items
=
mm_missing_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_
replacement
=
False
,
enable_hf_prompt_
update
=
False
,
)
)
mm_missing_next_idx
=
{
mm_missing_next_idx
=
{
...
@@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
return
prompt_ids
,
mm_kwargs
,
is_
repl
_applied
return
prompt_ids
,
mm_kwargs
,
is_
update
_applied
def
_bind_and_group_
repl
s
(
def
_bind_and_group_
update
s
(
self
,
self
,
prompt_
repl
s
:
list
[
Prompt
Replacement
],
prompt_
update
s
:
list
[
Prompt
Update
],
)
->
dict
[
str
,
list
[
BoundPrompt
Replacement
]]:
)
->
dict
[
str
,
list
[
BoundPrompt
Update
]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_
repl
s
)
it
=
(
update
.
bind
(
tokenizer
)
for
update
in
prompt_
update
s
)
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
def
_apply_prompt_
replacement
s
(
def
_apply_prompt_
update
s
(
self
,
self
,
token_ids
:
list
[
int
],
token_ids
:
list
[
int
],
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
mm_token_matches
=
{
mm_token_matches
=
{
modality
:
find_token_matches
(
token_ids
,
prompt_repl
s
)
modality
:
find_token_matches
(
token_ids
,
update
s
)
for
modality
,
prompt_repl
s
in
mm_prompt_
repl
s
.
items
()
for
modality
,
update
s
in
mm_prompt_
update
s
.
items
()
}
}
mm_match_counts
=
{
mm_match_counts
=
{
modality
:
len
(
matches
)
modality
:
len
(
matches
)
...
@@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# up a token, then the token ID of "foo" will not appear at all
# up a token, then the token ID of "foo" will not appear at all
# ----
# ----
# Since it is inefficient to search for all possible tokenizations
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
# of the search text in the prompt, we instead perform string
-based
#
replacement
on the decoded token IDs, then encode them back.
#
updates
on the decoded token IDs, then encode them back.
if
all
(
if
all
(
mm_match_counts
.
get
(
modality
,
0
)
>=
item_count
mm_match_counts
.
get
(
modality
,
0
)
>=
item_count
for
modality
,
item_count
in
mm_item_counts
.
items
()
for
modality
,
item_count
in
mm_item_counts
.
items
()
):
# yapf: disable
):
# yapf: disable
token_ids
=
replace
_token_matches
(
token_ids
=
apply
_token_matches
(
token_ids
,
token_ids
,
mm_token_matches
,
mm_token_matches
,
mm_item_counts
,
mm_item_counts
,
)
)
text
=
decode_tokens
(
tokenizer
,
token_ids
)
text
=
decode_tokens
(
tokenizer
,
token_ids
)
matched_
repl
s
=
{
matched_
update
s
=
{
modality
:
[
match
.
prompt_repl
for
match
in
token_matches
]
modality
:
[
match
.
_origin
for
match
in
token_matches
]
for
modality
,
token_matches
in
mm_token_matches
.
items
()
for
modality
,
token_matches
in
mm_token_matches
.
items
()
}
}
else
:
else
:
text
=
decode_tokens
(
tokenizer
,
token_ids
)
text
=
decode_tokens
(
tokenizer
,
token_ids
)
mm_text_matches
=
{
mm_text_matches
=
{
modality
:
find_text_matches
(
text
,
prompt_repl
s
)
modality
:
find_text_matches
(
text
,
update
s
)
for
modality
,
prompt_repl
s
in
mm_prompt_
repl
s
.
items
()
for
modality
,
update
s
in
mm_prompt_
update
s
.
items
()
}
}
text
=
replace
_text_matches
(
text
=
apply
_text_matches
(
text
,
text
,
mm_text_matches
,
mm_text_matches
,
mm_item_counts
,
mm_item_counts
,
...
@@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids
=
encode_tokens
(
tokenizer
,
token_ids
=
encode_tokens
(
tokenizer
,
text
,
text
,
add_special_tokens
=
False
)
add_special_tokens
=
False
)
matched_
repl
s
=
{
matched_
update
s
=
{
modality
:
[
match
.
prompt_repl
for
match
in
token_matches
]
modality
:
[
match
.
_origin
for
match
in
token_matches
]
for
modality
,
token_matches
in
mm_text_matches
.
items
()
for
modality
,
token_matches
in
mm_text_matches
.
items
()
}
}
placeholders
=
self
.
_find_mm_placeholders
(
placeholders
=
self
.
_find_mm_placeholders
(
matched_
repl
s
,
matched_
update
s
,
token_ids
,
token_ids
,
mm_item_counts
,
mm_item_counts
,
)
)
...
@@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if
len
(
placeholders
)
!=
item_count
:
if
len
(
placeholders
)
!=
item_count
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Expected there to be
{
item_count
}
prompt
replacement
s "
f
"Expected there to be
{
item_count
}
prompt
update
s "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"instead found
{
len
(
placeholders
)
}
prompt
replacement
s! "
f
"instead found
{
len
(
placeholders
)
}
prompt
update
s! "
"Either the prompt text has missing/incorrect tokens for "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_
replacement
s`)."
)
"`_call_hf_processor` and `_get_prompt_
update
s`)."
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
1. Apply HF Processor on prompt text and multi-modal data together,
1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
outputting token IDs and processed tensors.
2. Find and
replac
e sequences in the token IDs with placeholder tokens.
2. Find and
updat
e sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
3. Extract information about the placeholder tokens from the
...
@@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
(
prompt_ids
,
prompt_ids
,
mm_kwargs
,
mm_kwargs
,
is_
repl
_applied
,
is_
update
_applied
,
)
=
self
.
_cached_apply_hf_processor
(
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
prompt
,
mm_items
,
mm_items
,
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
)
)
unbound_prompt_
repl
s
=
self
.
_get_prompt_
replacement
s
(
unbound_prompt_
update
s
=
self
.
_get_prompt_
update
s
(
mm_items
,
mm_items
,
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
mm_kwargs
,
mm_kwargs
,
)
)
mm_prompt_repls
=
self
.
_bind_and_group_repls
(
unbound_prompt_repls
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_item_counts
=
mm_items
.
get_all_counts
()
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
if
is_
repl
_applied
:
if
is_
update
_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
mm_placeholders
=
self
.
_find_mm_placeholders
(
mm_prompt_
repl
s
,
mm_prompt_
update
s
,
prompt_ids
,
prompt_ids
,
mm_item_counts
,
mm_item_counts
,
)
)
...
@@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids
,
prompt_ids
,
prompt
,
prompt
,
mm_placeholders
,
mm_placeholders
,
)
=
self
.
_apply_prompt_
replacement
s
(
)
=
self
.
_apply_prompt_
update
s
(
prompt_ids
,
prompt_ids
,
mm_prompt_
repl
s
,
mm_prompt_
update
s
,
mm_item_counts
,
mm_item_counts
,
)
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment