Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1579b22
Unverified
Commit
f1579b22
authored
Feb 28, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 27, 2025
Browse files
[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
78648758
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
363 additions
and
245 deletions
+363
-245
vllm/model_executor/models/nvlm_d.py
vllm/model_executor/models/nvlm_d.py
+7
-6
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+10
-11
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+9
-15
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+6
-6
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+9
-7
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+8
-7
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+6
-5
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+5
-5
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+303
-183
No files found.
vllm/model_executor/models/nvlm_d.py
View file @
f1579b22
...
...
@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from
typing
import
Mapping
,
Optional
from
collections.abc
import
Mapping
,
Sequence
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
Prompt
Replacement
Details
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
Prompt
Update
Details
)
from
vllm.multimodal.profiling
import
ProcessorInputs
from
.intern_vit
import
InternVisionModel
...
...
@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
if
"image_num_patches"
in
out_mm_kwargs
:
...
...
@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if
num_patches
is
not
None
:
assert
isinstance
(
num_patches
,
int
)
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
hf_processor
.
get_image_repl_full
(
feature_size
,
num_patches
)
+
"
\n
"
,
features
=
hf_processor
.
get_image_repl_features
(
...
...
vllm/model_executor/models/phi3v.py
View file @
f1579b22
...
...
@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BoundPromptReplacement
,
BaseProcessingInfo
,
BoundPromptUpdate
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
Prompt
Replacement
Details
)
PromptReplacement
,
PromptUpdate
,
Prompt
Update
Details
)
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_tokens
:
list
[
str
]
=
hf_processor
.
img_tokens
# type: ignore
...
...
@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens
=
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
image_tokens
+
[
bos_token_id
],
features
=
image_tokens
,
)
...
...
@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
)
for
image_token
in
image_tokens
[:
num_images
]
]
def
_apply_prompt_
replacement
s
(
def
_apply_prompt_
update
s
(
self
,
token_ids
:
list
[
int
],
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
token_ids
,
text
,
placeholders
=
super
().
_apply_prompt_
replacement
s
(
token_ids
,
text
,
placeholders
=
super
().
_apply_prompt_
update
s
(
token_ids
=
token_ids
,
mm_prompt_
repl
s
=
mm_prompt_
repl
s
,
mm_prompt_
update
s
=
mm_prompt_
update
s
,
mm_item_counts
=
mm_item_counts
,
)
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
f1579b22
...
...
@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from
typing
import
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
Union
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
Prompt
Replacement
)
BaseProcessingInfo
,
Prompt
Update
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
PoolingSequenceGroupOutput
)
...
...
@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
pass
return
{
"image"
:
0
}
class
PrithviGeoSpatialMAEInputBuilder
(
...
...
@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
pass
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
pass
)
->
Sequence
[
PromptUpdate
]:
return
[]
def
apply
(
self
,
...
...
@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
):
""" Prithvi Masked Autoencoder"""
def
_instantiate_model
(
self
,
config
:
dict
)
->
nn
.
Module
|
None
:
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]
:
# We might be able/need to support different tasks with this same model
if
config
[
"task_args"
][
"task"
]
==
"SemanticSegmentationTask"
:
...
...
@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE."
)
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
self
,
**
kwargs
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
f1579b22
...
...
@@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
Any
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
Prompt
Replacement
Details
)
Prompt
Update
,
PromptUpdate
Details
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
...
...
@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens
=
[
audio_token_id
]
*
num_features
return
Prompt
Replacement
Details
(
return
Prompt
Update
Details
(
full
=
[
audio_bos_id
]
+
audio_tokens
+
[
audio_eos_id
],
features
=
audio_tokens
,
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
f1579b22
...
...
@@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
,
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Se
t
,
Tuple
,
Type
,
TypedDict
,
Union
)
from
typing
import
(
Any
,
Callable
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDic
t
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self
,
in_features
:
int
,
hidden_features
:
int
,
act_layer
:
T
ype
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
t
ype
[
nn
.
Module
]
=
QuickGELU
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
...
...
@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim
:
int
,
num_heads
:
int
,
mlp_ratio
:
float
,
act_layer
:
T
ype
[
nn
.
Module
]
=
QuickGELU
,
act_layer
:
t
ype
[
nn
.
Module
]
=
QuickGELU
,
norm_layer
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
...
...
@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self
.
info
.
_get_image_processor_kwargs
(
**
mm_kwargs
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_processor
=
self
.
info
.
get_image_processor
(
**
hf_processor_mm_kwargs
)
...
...
vllm/model_executor/models/qwen_vl.py
View file @
f1579b22
...
...
@@ -9,9 +9,10 @@ import copy
import
math
import
re
import
unicodedata
from
collections.abc
import
Collection
,
Mapping
,
Sequence
from
collections.abc
import
Set
as
AbstractSet
from
functools
import
lru_cache
,
partial
from
typing
import
(
AbstractSet
,
Callable
,
Collection
,
List
,
Literal
,
Mapping
,
Optional
,
TypedDict
,
Union
)
from
typing
import
Callable
,
List
,
Literal
,
Optional
,
TypedDict
,
Union
import
torch
from
torch
import
nn
...
...
@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
Prompt
Replacement
Details
)
Prompt
Update
,
PromptUpdate
Details
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs
=
mm_kwargs
,
)
def
_hf_processor_applies_
repl
(
def
_hf_processor_applies_
updates
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
...
...
@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
tokenizer
=
self
.
info
.
get_tokenizer
()
special_tokens
:
dict
[
str
,
int
]
=
tokenizer
.
special_tokens
# type: ignore
...
...
@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement
(
modality
=
"image"
,
target
=
[
img_start_id
,
img_end_id
],
replacement
=
Prompt
Replacement
Details
(
replacement
=
Prompt
Update
Details
(
full
=
[
img_start_id
]
+
image_tokens
+
[
img_end_id
],
features
=
image_tokens
,
),
...
...
vllm/model_executor/models/ultravox.py
View file @
f1579b22
...
...
@@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
cached_property
from
typing
import
(
Any
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch.utils.checkpoint
...
...
@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors
)
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
...
...
@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
...
...
vllm/model_executor/models/whisper.py
View file @
f1579b22
# SPDX-License-Identifier: Apache-2.0
import
math
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
from
torch
import
nn
...
...
@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
...
...
@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
))
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
Sequence
[
PromptUpdate
]:
num_tokens
=
self
.
info
.
get_max_audio_tokens
()
return
[
PromptReplacement
(
...
...
vllm/multimodal/processing.py
View file @
f1579b22
...
...
@@ -6,11 +6,14 @@ from collections import defaultdict
from
collections.abc
import
(
Callable
,
Generator
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
)
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
functools
import
lru_cache
from
itertools
import
groupby
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
)
TypeVar
,
Union
,
cast
)
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.inputs
import
InputProcessingContext
...
...
@@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]
@
dataclass
class
Prompt
Replacement
Details
:
"""Details about the
replacement
token sequence or text."""
class
Prompt
Update
Details
:
"""Details about the token sequence or text
that are part of the update
."""
full
:
PromptSeq
"""The full
replacem
ent."""
"""The full
cont
ent."""
features
:
PromptSeq
"""
The part of the
replacem
ent that corresponds to feature placeholders;
The part of the
cont
ent that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
inference.
"""
@
staticmethod
def
from_seq
(
seq
:
PromptSeq
)
->
"Prompt
Replacement
Details"
:
return
Prompt
Replacement
Details
(
full
=
seq
,
features
=
seq
)
def
from_seq
(
seq
:
PromptSeq
)
->
"Prompt
Update
Details"
:
return
Prompt
Update
Details
(
full
=
seq
,
features
=
seq
)
Prompt
Repl
=
Union
[
PromptSeq
,
Prompt
Replacement
Details
]
Prompt
UpdateInfo
=
Union
[
PromptSeq
,
Prompt
Update
Details
]
"""
The
replacement
token sequence or text.
The token sequence or text
that are part of the update
.
If only part of the
replacem
ent corresponds to feature placeholders, you can
use :class:`Prompt
Replacement
Details` to specify which part.
If only part of the
cont
ent corresponds to feature placeholders, you can
use :class:`Prompt
Update
Details` to specify which part.
"""
PromptUpdateContent
=
Union
[
Callable
[[
int
],
PromptUpdateInfo
],
PromptUpdateInfo
]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
class
UpdateMode
(
str
,
Enum
):
INSERT
=
"insert"
REPLACE
=
"replace"
@
dataclass
class
PromptUpdate
:
"""
Defines how to update a prompt with placeholder tokens.
"""
modality
:
str
"""The modality for which the update is made."""
target
:
PromptSeq
"""The token sequence (or text) to update."""
@
property
@
abstractmethod
def
content
(
self
)
->
PromptUpdateContent
:
"""The placeholder tokens that are part of the update."""
raise
NotImplementedError
@
property
@
abstractmethod
def
mode
(
self
)
->
UpdateMode
:
"""Defines how to update the prompt."""
raise
NotImplementedError
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptUpdate"
:
return
BoundPromptUpdate
(
_origin
=
self
,
tokenizer
=
tokenizer
,
)
@
dataclass
class
PromptReplacement
:
class
PromptInsertion
(
PromptUpdate
):
"""
Defines how to insert placeholder tokens into a prompt.
Example:
For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target="",
insertion="<image>" * image_feature_size,
)
As above, but insert after the ``<s>`` token:
.. code-block:: python
PromptInsertion(
modality="image",
target="<s>",
insertion="<image>" * image_feature_size,
)
"""
insertion
:
PromptUpdateContent
=
field
(
repr
=
False
)
"""
Given the index of the processed item within :attr:`modality`,
output the token sequence (or text) to insert right after :attr:`target`.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
@
property
def
content
(
self
)
->
PromptUpdateContent
:
return
self
.
insertion
@
property
def
mode
(
self
)
->
UpdateMode
:
return
UpdateMode
.
INSERT
@
dataclass
class
PromptReplacement
(
PromptUpdate
):
"""
Defines how to replace portions of an input prompt with placeholder tokens.
...
...
@@ -93,7 +190,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target="<image>",
replacement=Prompt
Replacement
Details(
replacement=Prompt
Update
Details(
full="".join([
"<image_bos>",
"<image>" * image_feature_size,
...
...
@@ -111,7 +208,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=Prompt
Replacement
Details(
replacement=Prompt
Update
Details(
full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]),
features=[image_token_id] * image_feature_size,
...
...
@@ -119,29 +216,22 @@ class PromptReplacement:
)
"""
modality
:
str
"""The modality for which the replacement is made."""
target
:
PromptSeq
"""The token sequence (or text) to find and replace."""
replacement
:
Union
[
Callable
[[
int
],
PromptRepl
],
PromptRepl
]
=
field
(
repr
=
False
)
replacement
:
PromptUpdateContent
=
field
(
repr
=
False
)
"""
Given the index of the processed item within :attr:`modality`,
output the
replacement
token sequence (or text).
output the token sequence (or text)
to replace :attr:`target`
.
For convenience, you can directly pass in the
replacement
token sequence
(or text)
instead of a function if it does not depend on the input.
For convenience, you can directly pass in the token sequence
(or text)
instead of a function if it does not depend on the input.
"""
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptReplacement"
:
return
BoundPromptReplacem
ent
(
tokenizer
=
tokenizer
,
modality
=
self
.
modality
,
_target
=
self
.
target
,
_replacement
=
self
.
replacement
,
)
@
property
def
content
(
self
)
->
PromptUpdateCont
ent
:
return
self
.
replacement
@
property
def
mode
(
self
)
->
UpdateMode
:
return
UpdateMode
.
REPLACE
@
lru_cache
(
maxsize
=
2048
)
...
...
@@ -232,64 +322,73 @@ class _BoundPromptSequence:
@
dataclass
class
_BoundPrompt
ReplacementGroup
:
class
_BoundPrompt
Content
:
full
:
_BoundPromptSequence
features
:
_BoundPromptSequence
@
dataclass
class
BoundPrompt
Replacement
:
class
BoundPrompt
Update
:
"""
A :class:`Prompt
Replacement
` bound to a tokenizer to automatically
convert
:attr:`target` and the result of :meth:`get_
replacem
ent` between
A :class:`Prompt
Update
` bound to a tokenizer to automatically
convert
:attr:`target` and the result of :meth:`get_
cont
ent` between
token sequence and text representations.
"""
_origin
:
PromptUpdate
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
modality
:
str
_target
:
PromptSeq
_replacement
:
Union
[
Callable
[[
int
],
PromptRepl
],
PromptRepl
]
=
field
(
repr
=
False
)
def
__post_init__
(
self
)
->
None
:
self
.
_replacement_cache
=
dict
[
int
,
_BoundPromptReplacementGroup
]()
self
.
_content_cache
=
dict
[
int
,
_BoundPromptContent
]()
@
property
def
modality
(
self
)
->
str
:
return
self
.
_origin
.
modality
@
property
def
target
(
self
)
->
_BoundPromptSequence
:
"""The token sequence (or text) to find and replace."""
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
self
.
_target
)
"""The token sequence (or text) to update."""
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
self
.
_origin
.
target
)
def
get_replacement
(
self
,
item_idx
:
int
)
->
_BoundPromptReplacementGroup
:
@
property
def
content
(
self
)
->
PromptUpdateContent
:
"""The placeholder tokens that are part of the update."""
return
self
.
_origin
.
content
@
property
def
mode
(
self
)
->
UpdateMode
:
"""Defines how to update the prompt."""
return
self
.
_origin
.
mode
def
get_content
(
self
,
item_idx
:
int
)
->
_BoundPromptContent
:
"""
Given the index of the processed item within :attr:`modality`,
output the
replacement
token sequence (or text).
output the token sequence (or text)
to update
.
"""
replacem
ent
=
self
.
_replacem
ent
if
callable
(
replacem
ent
):
cont
ent
=
self
.
cont
ent
if
callable
(
cont
ent
):
cache_key
=
item_idx
if
cache_key
in
self
.
_
replacem
ent_cache
:
return
self
.
_
replacem
ent_cache
[
cache_key
]
if
cache_key
in
self
.
_
cont
ent_cache
:
return
self
.
_
cont
ent_cache
[
cache_key
]
replacement
=
replacem
ent
(
item_idx
)
content
=
cont
ent
(
item_idx
)
else
:
cache_key
=
None
if
not
isinstance
(
replacem
ent
,
Prompt
Replacement
Details
):
replacem
ent
=
Prompt
Replacement
Details
.
from_seq
(
replacem
ent
)
if
not
isinstance
(
cont
ent
,
Prompt
Update
Details
):
cont
ent
=
Prompt
Update
Details
.
from_seq
(
cont
ent
)
bound_full
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
replacem
ent
.
full
)
cont
ent
.
full
)
bound_features
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
replacement
.
features
)
bound_replacement
=
_BoundPromptReplacementGroup
(
full
=
bound_full
,
features
=
bound_features
,
)
content
.
features
)
bound_content
=
_BoundPromptContent
(
full
=
bound_full
,
features
=
bound_features
)
if
cache_key
is
not
None
:
self
.
_
replacem
ent_cache
[
cache_key
]
=
bound_
replacem
ent
self
.
_
cont
ent_cache
[
cache_key
]
=
bound_
cont
ent
return
bound_
replacem
ent
return
bound_
cont
ent
class
_TokenMatch
(
NamedTuple
):
...
...
@@ -326,12 +425,12 @@ def iter_token_matches(
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tMatch
(
ABC
):
prompt_repl
:
BoundPrompt
Replacement
class
_Prompt
Targe
tMatch
(
ABC
):
_origin
:
BoundPrompt
Update
@
property
def
modality
(
self
)
->
str
:
return
self
.
prompt_repl
.
modality
return
self
.
_origin
.
modality
@
property
@
abstractmethod
...
...
@@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tTokenMatch
(
_Prompt
Replacemen
tMatch
):
class
_Prompt
Targe
tTokenMatch
(
_Prompt
Targe
tMatch
):
match
:
_TokenMatch
@
property
...
...
@@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):
@
dataclass
(
repr
=
False
)
class
_Prompt
Replacemen
tTextMatch
(
_Prompt
Replacemen
tMatch
):
class
_Prompt
Targe
tTextMatch
(
_Prompt
Targe
tMatch
):
match
:
re
.
Match
[
str
]
@
property
...
...
@@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:
def
find_token_matches
(
prompt
:
list
[
int
],
prompt_
repl
s
:
Sequence
[
BoundPrompt
Replacement
],
)
->
list
[
_PromptReplacementToken
Match
]:
"""Return each target of :code:`prompt_
repl
s` found in :code:`prompt`."""
prompt_
update
s
:
Sequence
[
BoundPrompt
Update
],
)
->
Sequence
[
_PromptTarget
Match
]:
"""Return each target of :code:`prompt_
update
s` found in :code:`prompt`."""
return
[
_PromptReplacementTokenMatch
(
prompt_repl
,
match
)
for
prompt_repl
in
prompt_repls
for
match
in
iter_token_matches
(
prompt
,
prompt_repl
.
target
.
token_ids
)
_PromptTargetTokenMatch
(
update
,
match
)
for
update
in
prompt_updates
for
match
in
iter_token_matches
(
prompt
,
update
.
target
.
token_ids
)
]
def
find_text_matches
(
prompt
:
str
,
prompt_
repl
s
:
Sequence
[
BoundPrompt
Replacement
],
)
->
list
[
_PromptReplacementTex
tMatch
]:
"""Return each target of :code:`prompt_
repl
s` found in :code:`prompt`."""
prompt_
update
s
:
Sequence
[
BoundPrompt
Update
],
)
->
Sequence
[
_PromptTarge
tMatch
]:
"""Return each target of :code:`prompt_
update
s` found in :code:`prompt`."""
return
[
_PromptReplacementTextMatch
(
prompt_repl
,
match
)
for
prompt_repl
in
prompt_repls
for
match
in
re
.
finditer
(
re
.
escape
(
prompt_repl
.
target
.
text
),
prompt
)
_PromptTargetTextMatch
(
update
,
match
)
for
update
in
prompt_updates
for
match
in
re
.
finditer
(
re
.
escape
(
update
.
target
.
text
),
prompt
)
]
def
_resolve_matches
(
prompt
:
PromptSeq
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Replacemen
tMatch
]],
)
->
list
[
_Prompt
Replacemen
tMatch
]:
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
)
->
list
[
_Prompt
Targe
tMatch
]:
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches
=
[
m
for
matches
in
mm_matches
.
values
()
for
m
in
matches
]
seen_matches
:
list
[
Optional
[
_PromptReplacementMatch
]]
=
[
None
]
*
len
(
prompt
)
seen_matches
:
list
[
Optional
[
_PromptTargetMatch
]]
=
[
None
]
*
len
(
prompt
)
for
match
in
matches
:
for
idx
in
range
(
match
.
start_idx
,
match
.
end_idx
):
...
...
@@ -441,74 +537,91 @@ def _resolve_matches(
return
sorted
(
matches
,
key
=
lambda
x
:
x
.
start_idx
)
def
_
replace
_matches
(
def
_
apply
_matches
(
prompt
:
_S
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Replacemen
tMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
_S
]:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
out_seqs
=
list
[
_S
]()
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
out_seqs
=
list
[
Union
[
str
,
list
[
int
]]
]()
prev_end_idx
=
0
next_idx_by_modality
=
defaultdict
[
str
,
int
](
lambda
:
0
)
for
match
in
_resolve_matches
(
prompt
,
mm_matches
):
modality
=
match
.
modality
for
(
start_idx
,
end_idx
),
group
in
groupby
(
_resolve_matches
(
prompt
,
mm_matches
),
key
=
lambda
x
:
(
x
.
start_idx
,
x
.
end_idx
),
):
matches
=
tuple
(
group
)
assert
len
(
matches
)
==
1
item_idx
=
next_idx_by_modality
[
modality
]
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
for
match
in
matches
:
modality
=
match
.
modality
start_idx
=
match
.
start_idx
end_idx
=
match
.
end_idx
item_idx
=
next_idx_by_modality
[
modality
]
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
repl_info
=
match
.
prompt_repl
replacement
=
repl_info
.
get_replacement
(
item_idx
)
origin
=
match
.
_origin
content
=
origin
.
get_content
(
item_idx
)
mode
=
origin
.
mode
if
isinstance
(
prompt
,
str
):
repl_seq
=
replacement
.
full
.
text
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
]
+
repl_seq
)
else
:
repl_seq
=
replacement
.
full
.
token_ids
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
]
+
repl_seq
)
if
mode
==
UpdateMode
.
INSERT
:
out_seqs
.
append
(
prompt
[
prev_end_idx
:
end_idx
])
num_inserts
=
mm_item_counts
.
get
(
modality
,
0
)
elif
mode
==
UpdateMode
.
REPLACE
:
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
])
num_inserts
=
1
else
:
assert_never
(
mode
)
prev_end_idx
=
end_idx
next_idx_by_modality
[
modality
]
+=
1
for
_
in
range
(
num_inserts
):
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
if
isinstance
(
prompt
,
str
):
out_seqs
.
append
(
content
.
full
.
text
)
else
:
out_seqs
.
append
(
content
.
full
.
token_ids
)
next_idx_by_modality
[
modality
]
+=
1
prev_end_idx
=
end_idx
out_seqs
.
append
(
prompt
[
prev_end_idx
:])
return
out_seqs
return
cast
(
list
[
_S
],
out_seqs
)
def
replace
_token_matches
(
def
apply
_token_matches
(
prompt
:
list
[
int
],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
ReplacementToken
Match
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Target
Match
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
int
]:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
if
not
mm_matches
:
return
prompt
token_id_seqs
=
_
replace
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
token_id_seqs
=
_
apply
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
return
flatten_2d_lists
(
token_id_seqs
)
def
replace
_text_matches
(
def
apply
_text_matches
(
prompt
:
str
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
ReplacementTex
tMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
_Prompt
Targe
tMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
str
:
"""Apply the
replacement
s in :code:`mm_matches` to :code:`prompt`."""
"""Apply the
update
s in :code:`mm_matches` to :code:`prompt`."""
if
not
mm_matches
:
return
prompt
texts
=
_
replace
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
texts
=
_
apply
_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
return
""
.
join
(
texts
)
def
_iter_placeholders
(
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Iterable
[
PlaceholderFeaturesInfo
]:
...
...
@@ -517,7 +630,7 @@ def _iter_placeholders(
Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_
repl
s` takes priority.
appears earlier in `mm_prompt_
update
s` takes priority.
Note that empty matches are ignored.
"""
...
...
@@ -528,37 +641,37 @@ def _iter_placeholders(
while
start_idx
<
prompt_len
:
found
=
False
for
modality
,
modality_
repl
s
in
mm_prompt_
repl
s
.
items
():
for
modality
,
modality_
update
s
in
mm_prompt_
update
s
.
items
():
item_idx
=
item_idx_by_modality
[
modality
]
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
for
repl
_info
in
modality_
repl
s
:
replacement
=
repl
_info
.
get_
replacem
ent
(
item_idx
)
repl
_tokens_full
=
replacem
ent
.
full
.
token_ids
repl
_len_full
=
len
(
repl
_tokens_full
)
end_idx_full
=
start_idx
+
repl
_len_full
for
update
_info
in
modality_
update
s
:
content
=
update
_info
.
get_
cont
ent
(
item_idx
)
content
_tokens_full
=
cont
ent
.
full
.
token_ids
content
_len_full
=
len
(
content
_tokens_full
)
end_idx_full
=
start_idx
+
content
_len_full
if
repl
_len_full
==
0
or
end_idx_full
>
prompt_len
:
if
content
_len_full
==
0
or
end_idx_full
>
prompt_len
:
continue
if
prompt
[
start_idx
:
end_idx_full
]
==
repl
_tokens_full
:
repl
_tokens_feat
=
replacem
ent
.
features
.
token_ids
if
prompt
[
start_idx
:
end_idx_full
]
==
content
_tokens_full
:
content
_tokens_feat
=
cont
ent
.
features
.
token_ids
try
:
match
=
next
(
iter_token_matches
(
repl
_tokens_full
,
repl
_tokens_feat
))
iter_token_matches
(
content
_tokens_full
,
content
_tokens_feat
))
yield
PlaceholderFeaturesInfo
(
modality
=
modality
,
item_idx
=
item_idx
,
start_idx
=
start_idx
+
match
.
start_idx
,
tokens
=
repl
_tokens_feat
,
tokens
=
content
_tokens_feat
,
)
except
StopIteration
:
raise
AssertionError
(
f
"
{
repl
_tokens_feat
=
}
should be a "
f
"subsequence of
{
repl
_tokens_full
=
}
"
)
from
None
f
"
{
content
_tokens_feat
=
}
should be a "
f
"subsequence of
{
content
_tokens_full
=
}
"
)
from
None
# Exclude overlapping matches
start_idx
=
end_idx_full
...
...
@@ -574,11 +687,11 @@ def _iter_placeholders(
def
find_mm_placeholders
(
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
it
=
_iter_placeholders
(
mm_prompt_
repl
s
,
prompt
,
mm_item_counts
)
it
=
_iter_placeholders
(
mm_prompt_
update
s
,
prompt
,
mm_item_counts
)
return
dict
(
full_groupby_modality
(
it
))
...
...
@@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
if
get_repls
:
=
getattr
(
self
,
"_get_prompt_replacements"
,
None
):
logger
.
warning_once
(
"`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release."
)
self
.
_get_prompt_updates
=
get_repls
# type: ignore[method-assign]
super
().
__init__
()
self
.
info
=
info
...
...
@@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
raise
NotImplementedError
@
abstractmethod
def
_get_prompt_
replacement
s
(
def
_get_prompt_
update
s
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
Prompt
Replacement
]:
)
->
list
[
Prompt
Update
]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the
replacement
s to perform.
and HF-processed data, output the
update
s to perform.
Notes:
- You should not assume that HF processor always performs prompt
replacement
: in :meth:`_apply_hf_processor_missing`, this method
updates
: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
- The
replacement
information returned by this method is also used
to
determine the placeholder token positions for each multi-modal
- The
update
information returned by this method is also used
to
determine the placeholder token positions for each multi-modal
item.
"""
raise
NotImplementedError
def
_find_mm_placeholders
(
self
,
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
new_token_ids
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
return
find_mm_placeholders
(
mm_prompt_
repl
s
,
new_token_ids
,
return
find_mm_placeholders
(
mm_prompt_
update
s
,
new_token_ids
,
mm_item_counts
)
def
_get_hf_mm_data
(
...
...
@@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
,
)
def
_hf_processor_applies_
repl
(
def
_hf_processor_applies_
updates
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
"""
Return whether the HF processor applies prompt
replacement
s.
Return whether the HF processor applies prompt
update
s.
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
...
...
@@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Apply the HF processor on the prompt text and multi-modal data
together.
In addition, return whether prompt
replacement
s have been applied.
In addition, return whether prompt
update
s have been applied.
"""
processor_data
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_items
)
...
...
@@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
),
)
is_
repl
_applied
=
self
.
_hf_processor_applies_
repl
(
is_
update
_applied
=
self
.
_hf_processor_applies_
updates
(
prompt_text
=
prompt_text
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
return
prompt_ids
,
mm_kwargs
,
is_
repl
_applied
return
prompt_ids
,
mm_kwargs
,
is_
update
_applied
def
_apply_hf_processor_text_only
(
self
,
prompt_text
:
str
)
->
list
[
int
]:
"""
...
...
@@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
*
,
enable_hf_prompt_
replacement
:
bool
,
enable_hf_prompt_
update
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
"""
Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt
replacement
s have been applied
In addition, return whether prompt
update
s have been applied
(for most HF processors, this should be :code:`True`).
Note:
If :code:`enable_hf_prompt_
replacement
=False`, we use HF processor
to perform prompt
replacement
if available; HF processor requires
If :code:`enable_hf_prompt_
update
=False`, we use HF processor
to perform prompt
updates
if available; HF processor requires
that the prompt corresponds to multi-modal items.
"""
if
isinstance
(
prompt
,
str
):
if
enable_hf_prompt_
replacement
:
if
enable_hf_prompt_
update
:
return
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
prompt
,
mm_items
=
mm_items
,
...
...
@@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt
=
prompt
,
mm_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_
replacement
=
True
,
enable_hf_prompt_
update
=
True
,
)
mm_maybe_cached_kw_items
=
{
...
...
@@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt
replacement
s until the new multimodal
# so we can't apply prompt
update
s until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids
,
mm_missing_kwargs
,
is_
repl
_applied
,
is_
update
_applied
,
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
mm_items
=
mm_missing_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_
replacement
=
False
,
enable_hf_prompt_
update
=
False
,
)
mm_missing_next_idx
=
{
...
...
@@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
return
prompt_ids
,
mm_kwargs
,
is_
repl
_applied
return
prompt_ids
,
mm_kwargs
,
is_
update
_applied
def
_bind_and_group_
repl
s
(
def
_bind_and_group_
update
s
(
self
,
prompt_
repl
s
:
list
[
Prompt
Replacement
],
)
->
dict
[
str
,
list
[
BoundPrompt
Replacement
]]:
prompt_
update
s
:
list
[
Prompt
Update
],
)
->
dict
[
str
,
list
[
BoundPrompt
Update
]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_
repl
s
)
it
=
(
update
.
bind
(
tokenizer
)
for
update
in
prompt_
update
s
)
return
dict
(
full_groupby_modality
(
it
))
def
_apply_prompt_
replacement
s
(
def
_apply_prompt_
update
s
(
self
,
token_ids
:
list
[
int
],
mm_prompt_
repl
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Replacement
]],
mm_prompt_
update
s
:
Mapping
[
str
,
Sequence
[
BoundPrompt
Update
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
mm_token_matches
=
{
modality
:
find_token_matches
(
token_ids
,
prompt_repl
s
)
for
modality
,
prompt_repl
s
in
mm_prompt_
repl
s
.
items
()
modality
:
find_token_matches
(
token_ids
,
update
s
)
for
modality
,
update
s
in
mm_prompt_
update
s
.
items
()
}
mm_match_counts
=
{
modality
:
len
(
matches
)
...
...
@@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# up a token, then the token ID of "foo" will not appear at all
# ----
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
#
replacement
on the decoded token IDs, then encode them back.
# of the search text in the prompt, we instead perform string
-based
#
updates
on the decoded token IDs, then encode them back.
if
all
(
mm_match_counts
.
get
(
modality
,
0
)
>=
item_count
for
modality
,
item_count
in
mm_item_counts
.
items
()
):
# yapf: disable
token_ids
=
replace
_token_matches
(
token_ids
=
apply
_token_matches
(
token_ids
,
mm_token_matches
,
mm_item_counts
,
)
text
=
decode_tokens
(
tokenizer
,
token_ids
)
matched_
repl
s
=
{
modality
:
[
match
.
prompt_repl
for
match
in
token_matches
]
matched_
update
s
=
{
modality
:
[
match
.
_origin
for
match
in
token_matches
]
for
modality
,
token_matches
in
mm_token_matches
.
items
()
}
else
:
text
=
decode_tokens
(
tokenizer
,
token_ids
)
mm_text_matches
=
{
modality
:
find_text_matches
(
text
,
prompt_repl
s
)
for
modality
,
prompt_repl
s
in
mm_prompt_
repl
s
.
items
()
modality
:
find_text_matches
(
text
,
update
s
)
for
modality
,
update
s
in
mm_prompt_
update
s
.
items
()
}
text
=
replace
_text_matches
(
text
=
apply
_text_matches
(
text
,
mm_text_matches
,
mm_item_counts
,
...
...
@@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids
=
encode_tokens
(
tokenizer
,
text
,
add_special_tokens
=
False
)
matched_
repl
s
=
{
modality
:
[
match
.
prompt_repl
for
match
in
token_matches
]
matched_
update
s
=
{
modality
:
[
match
.
_origin
for
match
in
token_matches
]
for
modality
,
token_matches
in
mm_text_matches
.
items
()
}
placeholders
=
self
.
_find_mm_placeholders
(
matched_
repl
s
,
matched_
update
s
,
token_ids
,
mm_item_counts
,
)
...
...
@@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if
len
(
placeholders
)
!=
item_count
:
raise
RuntimeError
(
f
"Expected there to be
{
item_count
}
prompt
replacement
s "
f
"Expected there to be
{
item_count
}
prompt
update
s "
f
"corresponding to
{
item_count
}
{
modality
}
items, but "
f
"instead found
{
len
(
placeholders
)
}
prompt
replacement
s! "
f
"instead found
{
len
(
placeholders
)
}
prompt
update
s! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_
replacement
s`)."
)
"`_call_hf_processor` and `_get_prompt_
update
s`)."
)
def
apply
(
self
,
...
...
@@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and
replac
e sequences in the token IDs with placeholder tokens.
2. Find and
updat
e sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
...
...
@@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids
,
mm_kwargs
,
is_
repl
_applied
,
is_
update
_applied
,
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
mm_items
,
hf_processor_mm_kwargs
,
)
unbound_prompt_
repl
s
=
self
.
_get_prompt_
replacement
s
(
unbound_prompt_
update
s
=
self
.
_get_prompt_
update
s
(
mm_items
,
hf_processor_mm_kwargs
,
mm_kwargs
,
)
mm_prompt_repls
=
self
.
_bind_and_group_repls
(
unbound_prompt_repls
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
if
is_
repl
_applied
:
if
is_
update
_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
mm_prompt_
repl
s
,
mm_prompt_
update
s
,
prompt_ids
,
mm_item_counts
,
)
...
...
@@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids
,
prompt
,
mm_placeholders
,
)
=
self
.
_apply_prompt_
replacement
s
(
)
=
self
.
_apply_prompt_
update
s
(
prompt_ids
,
mm_prompt_
repl
s
,
mm_prompt_
update
s
,
mm_item_counts
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment