"vllm/vscode:/vscode.git/clone" did not exist on "90d74ebaa47fcecdcd8ef72338dda47b7cb6fbf0"
Commit f1579b22 (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Generalized prompt updates for multi-modal processor (#13964)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 78648758
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA # Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Mapping, Optional from collections.abc import Mapping, Sequence
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptReplacementDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
...@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): ...@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs: if "image_num_patches" in out_mm_kwargs:
...@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): ...@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptReplacementDetails( return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size, full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n", num_patches) + "\n",
features=hf_processor.get_image_repl_features( features=hf_processor.get_image_repl_features(
......
...@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo, BoundPromptUpdate,
BoundPromptReplacement,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptReplacement, PromptUpdate,
PromptReplacementDetails) PromptUpdateDetails)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_tokens: list[str] = hf_processor.img_tokens # type: ignore
...@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptReplacementDetails( return PromptUpdateDetails(
full=image_tokens + [bos_token_id], full=image_tokens + [bos_token_id],
features=image_tokens, features=image_tokens,
) )
...@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) for image_token in image_tokens[:num_images] ) for image_token in image_tokens[:num_images]
] ]
def _apply_prompt_replacements( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]], mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids, token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls, mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts, mm_item_counts=mm_item_counts,
) )
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs) MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput, from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput) PoolingSequenceGroupOutput)
...@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): ...@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder( class PrithviGeoSpatialMAEInputBuilder(
...@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"), location_coords=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
pass return []
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
def apply( def apply(
self, self,
...@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder""" """ Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None: def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model # We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask": if config["task_args"]["task"] == "SemanticSegmentationTask":
...@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.") "by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data( def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]: self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, torch.Tensor):
......
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict, from typing import Any, Optional, Set, Tuple, TypedDict, Union
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, ...@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor( ...@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask=MultiModalFieldConfig.batched("audio"), feature_attention_mask=MultiModalFieldConfig.batched("audio"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
...@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor( ...@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features audio_tokens = [audio_token_id] * num_features
return PromptReplacementDetails( return PromptUpdateDetails(
full=[audio_bos_id] + audio_tokens + [audio_eos_id], full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens, features=audio_tokens,
) )
......
...@@ -23,9 +23,10 @@ ...@@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Tuple, Type, TypedDict, Union) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, ...@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems, ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module): ...@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self, self,
in_features: int, in_features: int,
hidden_features: int, hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
...@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module): ...@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim: int, dim: int,
num_heads: int, num_heads: int,
mlp_ratio: float, mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU, act_layer: type[nn.Module] = QuickGELU,
norm_layer: Optional[Callable[[int], nn.Module]] = None, norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self.info._get_image_processor_kwargs(**mm_kwargs), self.info._get_image_processor_kwargs(**mm_kwargs),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor( image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs) **hf_processor_mm_kwargs)
......
...@@ -9,9 +9,10 @@ import copy ...@@ -9,9 +9,10 @@ import copy
import math import math
import re import re
import unicodedata import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping, from typing import Callable, List, Literal, Optional, TypedDict, Union
Optional, TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
special_tokens: dict[str, special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore int] = tokenizer.special_tokens # type: ignore
...@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[img_start_id, img_end_id], target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails( replacement=PromptUpdateDetails(
full=[img_start_id] + image_tokens + [img_end_id], full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens, features=image_tokens,
), ),
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple, from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
...@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor( ...@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from collections.abc import Iterable, Mapping, Sequence
Union) from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, ...@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription from .interfaces import SupportsMultiModal, SupportsTranscription
...@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor( ...@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio")) return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens() num_tokens = self.info.get_max_audio_tokens()
return [ return [
PromptReplacement( PromptReplacement(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment