Unverified commit f1579b22, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Generalized prompt updates for multi-modal processor (#13964)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 78648758
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA # Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Mapping, Optional from collections.abc import Mapping, Sequence
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptReplacementDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
...@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): ...@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs: if "image_num_patches" in out_mm_kwargs:
...@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): ...@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptReplacementDetails( return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size, full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n", num_patches) + "\n",
features=hf_processor.get_image_repl_features( features=hf_processor.get_image_repl_features(
......
...@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo, BoundPromptUpdate,
BoundPromptReplacement,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptReplacement, PromptUpdate,
PromptReplacementDetails) PromptUpdateDetails)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_tokens: list[str] = hf_processor.img_tokens # type: ignore
...@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptReplacementDetails( return PromptUpdateDetails(
full=image_tokens + [bos_token_id], full=image_tokens + [bos_token_id],
features=image_tokens, features=image_tokens,
) )
...@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) for image_token in image_tokens[:num_images] ) for image_token in image_tokens[:num_images]
] ]
def _apply_prompt_replacements( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids, token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls, mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts, mm_item_counts=mm_item_counts,
) )
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs) MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput) PoolingSequenceGroupOutput)
...@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): ...@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder( class PrithviGeoSpatialMAEInputBuilder(
...@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"), location_coords=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
pass return []
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
def apply( def apply(
self, self,
...@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder""" """ Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None: def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model # We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask": if config["task_args"]["task"] == "SemanticSegmentationTask":
...@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.") "by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data( def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]: self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, torch.Tensor):
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, from typing import Any, Optional, Set, Tuple, TypedDict, Union
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, ...@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor( ...@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask=MultiModalFieldConfig.batched("audio"), feature_attention_mask=MultiModalFieldConfig.batched("audio"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
...@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor( ...@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features audio_tokens = [audio_token_id] * num_features
return PromptReplacementDetails( return PromptUpdateDetails(
full=[audio_bos_id] + audio_tokens + [audio_eos_id], full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens, features=audio_tokens,
) )
......
...@@ -23,9 +23,10 @@ ...@@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Tuple, Type, TypedDict, Union) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ...@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems, ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module): ...@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self, self,
in_features: int, in_features: int,
hidden_features: int, hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module): ...@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim: int, dim: int,
num_heads: int, num_heads: int,
mlp_ratio: float, mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Optional[Callable[[int], nn.Module]] = None, norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self.info._get_image_processor_kwargs(**mm_kwargs), self.info._get_image_processor_kwargs(**mm_kwargs),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor( image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs) **hf_processor_mm_kwargs)
......
...@@ -9,9 +9,10 @@ import copy ...@@ -9,9 +9,10 @@ import copy
import math import math
import re import re
import unicodedata import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping, from typing import Callable, List, Literal, Optional, TypedDict, Union
Optional, TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
special_tokens: dict[str, special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore int] = tokenizer.special_tokens # type: ignore
...@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[img_start_id, img_end_id], target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails( replacement=PromptUpdateDetails(
full=[img_start_id] + image_tokens + [img_end_id], full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens, features=image_tokens,
), ),
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
...@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor( ...@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from collections.abc import Iterable, Mapping, Sequence
Union) from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, ...@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription from .interfaces import SupportsMultiModal, SupportsTranscription
...@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor( ...@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio")) return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens() num_tokens = self.info.get_max_audio_tokens()
return [ return [
PromptReplacement( PromptReplacement(
......
...@@ -6,11 +6,14 @@ from collections import defaultdict ...@@ -6,11 +6,14 @@ from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence) Sequence)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache from functools import lru_cache
from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union) TypeVar, Union, cast)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
...@@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]] ...@@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]
@dataclass @dataclass
class PromptReplacementDetails: class PromptUpdateDetails:
"""Details about the replacement token sequence or text.""" """Details about the token sequence or text that are part of the update."""
full: PromptSeq full: PromptSeq
"""The full replacement.""" """The full content."""
features: PromptSeq features: PromptSeq
""" """
The part of the replacement that corresponds to feature placeholders; The part of the content that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model this will be replaced by the output of the vision encoder during model
inference. inference.
""" """
@staticmethod @staticmethod
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails": def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
return PromptReplacementDetails(full=seq, features=seq) return PromptUpdateDetails(full=seq, features=seq)
PromptRepl = Union[PromptSeq, PromptReplacementDetails] PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
""" """
The replacement token sequence or text. The token sequence or text that are part of the update.
If only part of the replacement corresponds to feature placeholders, you can If only part of the content corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part. use :class:`PromptUpdateDetails` to specify which part.
""" """
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
class UpdateMode(str, Enum):
INSERT = "insert"
REPLACE = "replace"
@dataclass
class PromptUpdate:
"""
Defines how to update a prompt with placeholder tokens.
"""
modality: str
"""The modality for which the update is made."""
target: PromptSeq
"""The token sequence (or text) to update."""
@property
@abstractmethod
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
raise NotImplementedError
@property
@abstractmethod
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
raise NotImplementedError
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
return BoundPromptUpdate(
_origin=self,
tokenizer=tokenizer,
)
@dataclass @dataclass
class PromptReplacement: class PromptInsertion(PromptUpdate):
"""
Defines how to insert placeholder tokens into a prompt.
Example:
For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target="",
insertion="<image>" * image_feature_size,
)
As above, but insert after the ``<s>`` token:
.. code-block:: python
PromptInsertion(
modality="image",
target="<s>",
insertion="<image>" * image_feature_size,
)
"""
insertion: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the token sequence (or text) to insert right after :attr:`target`.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
@property
def content(self) -> PromptUpdateContent:
return self.insertion
@property
def mode(self) -> UpdateMode:
return UpdateMode.INSERT
@dataclass
class PromptReplacement(PromptUpdate):
""" """
Defines how to replace portions of an input prompt with placeholder tokens. Defines how to replace portions of an input prompt with placeholder tokens.
...@@ -93,7 +190,7 @@ class PromptReplacement: ...@@ -93,7 +190,7 @@ class PromptReplacement:
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target="<image>", target="<image>",
replacement=PromptReplacementDetails( replacement=PromptUpdateDetails(
full="".join([ full="".join([
"<image_bos>", "<image_bos>",
"<image>" * image_feature_size, "<image>" * image_feature_size,
...@@ -111,7 +208,7 @@ class PromptReplacement: ...@@ -111,7 +208,7 @@ class PromptReplacement:
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id], target=[image_token_id],
replacement=PromptReplacementDetails( replacement=PromptUpdateDetails(
full=([image_bos_id] + [image_token_id] * image_feature_size full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]), + [image_eos_id]),
features=[image_token_id] * image_feature_size, features=[image_token_id] * image_feature_size,
...@@ -119,29 +216,22 @@ class PromptReplacement: ...@@ -119,29 +216,22 @@ class PromptReplacement:
) )
""" """
modality: str replacement: PromptUpdateContent = field(repr=False)
"""The modality for which the replacement is made."""
target: PromptSeq
"""The token sequence (or text) to find and replace."""
replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
""" """
Given the index of the processed item within :attr:`modality`, Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text). output the token sequence (or text) to replace :attr:`target`.
For convenience, you can directly pass in the replacement token sequence For convenience, you can directly pass in the token sequence (or text)
(or text) instead of a function if it does not depend on the input. instead of a function if it does not depend on the input.
""" """
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement": @property
return BoundPromptReplacement( def content(self) -> PromptUpdateContent:
tokenizer=tokenizer, return self.replacement
modality=self.modality,
_target=self.target, @property
_replacement=self.replacement, def mode(self) -> UpdateMode:
) return UpdateMode.REPLACE
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
...@@ -232,64 +322,73 @@ class _BoundPromptSequence: ...@@ -232,64 +322,73 @@ class _BoundPromptSequence:
@dataclass @dataclass
class _BoundPromptReplacementGroup: class _BoundPromptContent:
full: _BoundPromptSequence full: _BoundPromptSequence
features: _BoundPromptSequence features: _BoundPromptSequence
@dataclass @dataclass
class BoundPromptReplacement: class BoundPromptUpdate:
""" """
A :class:`PromptReplacement` bound to a tokenizer to automatically A :class:`PromptUpdate` bound to a tokenizer to automatically convert
convert :attr:`target` and the result of :meth:`get_replacement` between :attr:`target` and the result of :meth:`get_content` between
token sequence and text representations. token sequence and text representations.
""" """
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False) tokenizer: AnyTokenizer = field(repr=False)
modality: str
_target: PromptSeq
_replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]() self._content_cache = dict[int, _BoundPromptContent]()
@property
def modality(self) -> str:
return self._origin.modality
@property @property
def target(self) -> _BoundPromptSequence: def target(self) -> _BoundPromptSequence:
"""The token sequence (or text) to find and replace.""" """The token sequence (or text) to update."""
return _BoundPromptSequence.from_seq(self.tokenizer, self._target) return _BoundPromptSequence.from_seq(self.tokenizer,
self._origin.target)
def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup: @property
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
return self._origin.content
@property
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
return self._origin.mode
def get_content(self, item_idx: int) -> _BoundPromptContent:
""" """
Given the index of the processed item within :attr:`modality`, Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text). output the token sequence (or text) to update.
""" """
replacement = self._replacement content = self.content
if callable(replacement): if callable(content):
cache_key = item_idx cache_key = item_idx
if cache_key in self._replacement_cache: if cache_key in self._content_cache:
return self._replacement_cache[cache_key] return self._content_cache[cache_key]
replacement = replacement(item_idx) content = content(item_idx)
else: else:
cache_key = None cache_key = None
if not isinstance(replacement, PromptReplacementDetails): if not isinstance(content, PromptUpdateDetails):
replacement = PromptReplacementDetails.from_seq(replacement) content = PromptUpdateDetails.from_seq(content)
bound_full = _BoundPromptSequence.from_seq(self.tokenizer, bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.full) content.full)
bound_features = _BoundPromptSequence.from_seq(self.tokenizer, bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.features) content.features)
bound_replacement = _BoundPromptReplacementGroup( bound_content = _BoundPromptContent(full=bound_full,
full=bound_full, features=bound_features)
features=bound_features,
)
if cache_key is not None: if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement self._content_cache[cache_key] = bound_content
return bound_replacement return bound_content
class _TokenMatch(NamedTuple): class _TokenMatch(NamedTuple):
...@@ -326,12 +425,12 @@ def iter_token_matches( ...@@ -326,12 +425,12 @@ def iter_token_matches(
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementMatch(ABC): class _PromptTargetMatch(ABC):
prompt_repl: BoundPromptReplacement _origin: BoundPromptUpdate
@property @property
def modality(self) -> str: def modality(self) -> str:
return self.prompt_repl.modality return self._origin.modality
@property @property
@abstractmethod @abstractmethod
...@@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC): ...@@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch): class _PromptTargetTokenMatch(_PromptTargetMatch):
match: _TokenMatch match: _TokenMatch
@property @property
...@@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch): ...@@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch): class _PromptTargetTextMatch(_PromptTargetMatch):
match: re.Match[str] match: re.Match[str]
@property @property
...@@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo: ...@@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:
def find_token_matches( def find_token_matches(
prompt: list[int], prompt: list[int],
prompt_repls: Sequence[BoundPromptReplacement], prompt_updates: Sequence[BoundPromptUpdate],
) -> list[_PromptReplacementTokenMatch]: ) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [ return [
_PromptReplacementTokenMatch(prompt_repl, match) _PromptTargetTokenMatch(update, match) for update in prompt_updates
for prompt_repl in prompt_repls for match in iter_token_matches(prompt, update.target.token_ids)
for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
] ]
def find_text_matches( def find_text_matches(
prompt: str, prompt: str,
prompt_repls: Sequence[BoundPromptReplacement], prompt_updates: Sequence[BoundPromptUpdate],
) -> list[_PromptReplacementTextMatch]: ) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [ return [
_PromptReplacementTextMatch(prompt_repl, match) _PromptTargetTextMatch(update, match) for update in prompt_updates
for prompt_repl in prompt_repls for match in re.finditer(re.escape(update.target.text), prompt)
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
] ]
def _resolve_matches( def _resolve_matches(
prompt: PromptSeq, prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptReplacementMatch]: ) -> list[_PromptTargetMatch]:
""" """
Resolve :code:`mm_matches` to ensure that there are no overlapping matches, Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones. and sort them such that earlier matches take priority over later ones.
""" """
matches = [m for matches in mm_matches.values() for m in matches] matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptReplacementMatch]] = [None seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
] * len(prompt)
for match in matches: for match in matches:
for idx in range(match.start_idx, match.end_idx): for idx in range(match.start_idx, match.end_idx):
...@@ -441,74 +537,91 @@ def _resolve_matches( ...@@ -441,74 +537,91 @@ def _resolve_matches(
return sorted(matches, key=lambda x: x.start_idx) return sorted(matches, key=lambda x: x.start_idx)
def _replace_matches( def _apply_matches(
prompt: _S, prompt: _S,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[_S]: ) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]() out_seqs = list[Union[str, list[int]]]()
prev_end_idx = 0 prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0) next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, mm_matches): for (start_idx, end_idx), group in groupby(
modality = match.modality _resolve_matches(prompt, mm_matches),
key=lambda x: (x.start_idx, x.end_idx),
):
matches = tuple(group)
assert len(matches) == 1
item_idx = next_idx_by_modality[modality] for match in matches:
if item_idx >= mm_item_counts.get(modality, 0): modality = match.modality
continue
start_idx = match.start_idx item_idx = next_idx_by_modality[modality]
end_idx = match.end_idx if item_idx >= mm_item_counts.get(modality, 0):
continue
repl_info = match.prompt_repl origin = match._origin
replacement = repl_info.get_replacement(item_idx) content = origin.get_content(item_idx)
mode = origin.mode
if isinstance(prompt, str): if mode == UpdateMode.INSERT:
repl_seq = replacement.full.text out_seqs.append(prompt[prev_end_idx:end_idx])
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) num_inserts = mm_item_counts.get(modality, 0)
else: elif mode == UpdateMode.REPLACE:
repl_seq = replacement.full.token_ids out_seqs.append(prompt[prev_end_idx:start_idx])
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq) num_inserts = 1
else:
assert_never(mode)
prev_end_idx = end_idx for _ in range(num_inserts):
next_idx_by_modality[modality] += 1 if item_idx >= mm_item_counts.get(modality, 0):
continue
if isinstance(prompt, str):
out_seqs.append(content.full.text)
else:
out_seqs.append(content.full.token_ids)
next_idx_by_modality[modality] += 1
prev_end_idx = end_idx
out_seqs.append(prompt[prev_end_idx:]) out_seqs.append(prompt[prev_end_idx:])
return out_seqs return cast(list[_S], out_seqs)
def replace_token_matches( def apply_token_matches(
prompt: list[int], prompt: list[int],
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]], mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> list[int]: ) -> list[int]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches: if not mm_matches:
return prompt return prompt
token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts) token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)
return flatten_2d_lists(token_id_seqs) return flatten_2d_lists(token_id_seqs)
def replace_text_matches( def apply_text_matches(
prompt: str, prompt: str,
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]], mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> str: ) -> str:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches: if not mm_matches:
return prompt return prompt
texts = _replace_matches(prompt, mm_matches, mm_item_counts) texts = _apply_matches(prompt, mm_matches, mm_item_counts)
return "".join(texts) return "".join(texts)
def _iter_placeholders( def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]: ) -> Iterable[PlaceholderFeaturesInfo]:
...@@ -517,7 +630,7 @@ def _iter_placeholders( ...@@ -517,7 +630,7 @@ def _iter_placeholders(
Matches are exclusive even when multiple modalities share Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority. appears earlier in `mm_prompt_updates` takes priority.
Note that empty matches are ignored. Note that empty matches are ignored.
""" """
...@@ -528,37 +641,37 @@ def _iter_placeholders( ...@@ -528,37 +641,37 @@ def _iter_placeholders(
while start_idx < prompt_len: while start_idx < prompt_len:
found = False found = False
for modality, modality_repls in mm_prompt_repls.items(): for modality, modality_updates in mm_prompt_updates.items():
item_idx = item_idx_by_modality[modality] item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0): if item_idx >= mm_item_counts.get(modality, 0):
continue continue
for repl_info in modality_repls: for update_info in modality_updates:
replacement = repl_info.get_replacement(item_idx) content = update_info.get_content(item_idx)
repl_tokens_full = replacement.full.token_ids content_tokens_full = content.full.token_ids
repl_len_full = len(repl_tokens_full) content_len_full = len(content_tokens_full)
end_idx_full = start_idx + repl_len_full end_idx_full = start_idx + content_len_full
if repl_len_full == 0 or end_idx_full > prompt_len: if content_len_full == 0 or end_idx_full > prompt_len:
continue continue
if prompt[start_idx:end_idx_full] == repl_tokens_full: if prompt[start_idx:end_idx_full] == content_tokens_full:
repl_tokens_feat = replacement.features.token_ids content_tokens_feat = content.features.token_ids
try: try:
match = next( match = next(
iter_token_matches(repl_tokens_full, iter_token_matches(content_tokens_full,
repl_tokens_feat)) content_tokens_feat))
yield PlaceholderFeaturesInfo( yield PlaceholderFeaturesInfo(
modality=modality, modality=modality,
item_idx=item_idx, item_idx=item_idx,
start_idx=start_idx + match.start_idx, start_idx=start_idx + match.start_idx,
tokens=repl_tokens_feat, tokens=content_tokens_feat,
) )
except StopIteration: except StopIteration:
raise AssertionError( raise AssertionError(
f"{repl_tokens_feat=} should be a " f"{content_tokens_feat=} should be a "
f"subsequence of {repl_tokens_full=}") from None f"subsequence of {content_tokens_full=}") from None
# Exclude overlapping matches # Exclude overlapping matches
start_idx = end_idx_full start_idx = end_idx_full
...@@ -574,11 +687,11 @@ def _iter_placeholders( ...@@ -574,11 +687,11 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
...@@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None: enable_sanity_checks: bool = True) -> None:
if get_repls := getattr(self, "_get_prompt_replacements", None):
logger.warning_once("`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release.")
self._get_prompt_updates = get_repls # type: ignore[method-assign]
super().__init__() super().__init__()
self.info = info self.info = info
...@@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptUpdate]:
""" """
Given the original multi-modal items for this modality Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform. and HF-processed data, output the updates to perform.
Notes: Notes:
- You should not assume that HF processor always performs prompt - You should not assume that HF processor always performs prompt
replacement: in :meth:`_apply_hf_processor_missing`, this method updates: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately, is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call. instead of passing them in the same call.
- The replacement information returned by this method is also used - The update information returned by this method is also used to
to determine the placeholder token positions for each multi-modal determine the placeholder token positions for each multi-modal
item. item.
""" """
raise NotImplementedError raise NotImplementedError
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids, return find_mm_placeholders(mm_prompt_updates, new_token_ids,
mm_item_counts) mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
...@@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs, mm_kwargs,
) )
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> bool: ) -> bool:
""" """
Return whether the HF processor applies prompt replacements. Return whether the HF processor applies prompt updates.
For most HF processors, this should be :code:`True` when multi-modal For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings data items are passed, but :code:`False` when multi-modal embeddings
...@@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Apply the HF processor on the prompt text and multi-modal data Apply the HF processor on the prompt text and multi-modal data
together. together.
In addition, return whether prompt replacements have been applied. In addition, return whether prompt updates have been applied.
""" """
processor_data, passthrough_data = self._get_hf_mm_data(mm_items) processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
...@@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
) )
is_repl_applied = self._hf_processor_applies_repl( is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
return prompt_ids, mm_kwargs, is_repl_applied return prompt_ids, mm_kwargs, is_update_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
""" """
...@@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_replacement: bool, enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], MultiModalKwargs, bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data. Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt replacements have been applied In addition, return whether prompt updates have been applied
(for most HF processors, this should be :code:`True`). (for most HF processors, this should be :code:`True`).
Note: Note:
If :code:`enable_hf_prompt_replacement=False`, we use HF processor If :code:`enable_hf_prompt_update=False`, we use HF processor
to perform prompt replacement if available; HF processor requires to perform prompt updates if available; HF processor requires
that the prompt corresponds to multi-modal items. that the prompt corresponds to multi-modal items.
""" """
if isinstance(prompt, str): if isinstance(prompt, str):
if enable_hf_prompt_replacement: if enable_hf_prompt_update:
return self._apply_hf_processor_text_mm( return self._apply_hf_processor_text_mm(
prompt_text=prompt, prompt_text=prompt,
mm_items=mm_items, mm_items=mm_items,
...@@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt, prompt=prompt,
mm_items=mm_data_items, mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=True, enable_hf_prompt_update=True,
) )
mm_maybe_cached_kw_items = { mm_maybe_cached_kw_items = {
...@@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items = self._to_mm_items(mm_missing_data) mm_missing_data_items = self._to_mm_items(mm_missing_data)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt replacements until the new multimodal # so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items # items are combined with the cached multimodal items
( (
prompt_ids, prompt_ids,
mm_missing_kwargs, mm_missing_kwargs,
is_repl_applied, is_update_applied,
) = self._apply_hf_processor_main( ) = self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=mm_missing_data_items, mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=False, enable_hf_prompt_update=False,
) )
mm_missing_next_idx = { mm_missing_next_idx = {
...@@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
return prompt_ids, mm_kwargs, is_repl_applied return prompt_ids, mm_kwargs, is_update_applied
def _bind_and_group_repls( def _bind_and_group_updates(
self, self,
prompt_repls: list[PromptReplacement], prompt_updates: list[PromptUpdate],
) -> dict[str, list[BoundPromptReplacement]]: ) -> dict[str, list[BoundPromptUpdate]]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
def _apply_prompt_replacements( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
mm_token_matches = { mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls) modality: find_token_matches(token_ids, updates)
for modality, prompt_repls in mm_prompt_repls.items() for modality, updates in mm_prompt_updates.items()
} }
mm_match_counts = { mm_match_counts = {
modality: len(matches) modality: len(matches)
...@@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# up a token, then the token ID of "foo" will not appear at all # up a token, then the token ID of "foo" will not appear at all
# ---- # ----
# Since it is inefficient to search for all possible tokenizations # Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string # of the search text in the prompt, we instead perform string-based
# replacement on the decoded token IDs, then encode them back. # updates on the decoded token IDs, then encode them back.
if all( if all(
mm_match_counts.get(modality, 0) >= item_count mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items() for modality, item_count in mm_item_counts.items()
): # yapf: disable ): # yapf: disable
token_ids = replace_token_matches( token_ids = apply_token_matches(
token_ids, token_ids,
mm_token_matches, mm_token_matches,
mm_item_counts, mm_item_counts,
) )
text = decode_tokens(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
matched_repls = { matched_updates = {
modality: [match.prompt_repl for match in token_matches] modality: [match._origin for match in token_matches]
for modality, token_matches in mm_token_matches.items() for modality, token_matches in mm_token_matches.items()
} }
else: else:
text = decode_tokens(tokenizer, token_ids) text = decode_tokens(tokenizer, token_ids)
mm_text_matches = { mm_text_matches = {
modality: find_text_matches(text, prompt_repls) modality: find_text_matches(text, updates)
for modality, prompt_repls in mm_prompt_repls.items() for modality, updates in mm_prompt_updates.items()
} }
text = replace_text_matches( text = apply_text_matches(
text, text,
mm_text_matches, mm_text_matches,
mm_item_counts, mm_item_counts,
...@@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids = encode_tokens(tokenizer, token_ids = encode_tokens(tokenizer,
text, text,
add_special_tokens=False) add_special_tokens=False)
matched_repls = { matched_updates = {
modality: [match.prompt_repl for match in token_matches] modality: [match._origin for match in token_matches]
for modality, token_matches in mm_text_matches.items() for modality, token_matches in mm_text_matches.items()
} }
placeholders = self._find_mm_placeholders( placeholders = self._find_mm_placeholders(
matched_repls, matched_updates,
token_ids, token_ids,
mm_item_counts, mm_item_counts,
) )
...@@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if len(placeholders) != item_count: if len(placeholders) != item_count:
raise RuntimeError( raise RuntimeError(
f"Expected there to be {item_count} prompt replacements " f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but " f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt replacements! " f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for " "Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your " "multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this " "implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between " "model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).") "`_call_hf_processor` and `_get_prompt_updates`).")
def apply( def apply(
self, self,
...@@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
1. Apply HF Processor on prompt text and multi-modal data together, 1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors. outputting token IDs and processed tensors.
2. Find and replace sequences in the token IDs with placeholder tokens. 2. Find and update sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder. multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the 3. Extract information about the placeholder tokens from the
...@@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
( (
prompt_ids, prompt_ids,
mm_kwargs, mm_kwargs,
is_repl_applied, is_update_applied,
) = self._cached_apply_hf_processor( ) = self._cached_apply_hf_processor(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
) )
unbound_prompt_repls = self._get_prompt_replacements( unbound_prompt_updates = self._get_prompt_updates(
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts) self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
if is_repl_applied: if is_update_applied:
mm_placeholders = self._find_mm_placeholders( mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls, mm_prompt_updates,
prompt_ids, prompt_ids,
mm_item_counts, mm_item_counts,
) )
...@@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids, prompt_ids,
prompt, prompt,
mm_placeholders, mm_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_updates(
prompt_ids, prompt_ids,
mm_prompt_repls, mm_prompt_updates,
mm_item_counts, mm_item_counts,
) )
self._validate_mm_placeholders(mm_placeholders, mm_item_counts) self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment