Unverified commit f1579b22, authored by Cyrus Leung and committed by GitHub
Browse files

[VLM] Generalized prompt updates for multi-modal processor (#13964)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 78648758
......@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Mapping, Optional
from collections.abc import Mapping, Sequence
from typing import Optional
import torch
import torch.nn as nn
......@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
......@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
......@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n",
features=hf_processor.get_image_repl_features(
......
......@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
BoundPromptReplacement,
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptReplacementDetails)
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
......@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
......@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) for image_token in image_tokens[:num_images]
]
def _apply_prompt_replacements(
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
)
......
......@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch
import torch.nn as nn
......@@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
......@@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder(
......@@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
pass
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
) -> Sequence[PromptUpdate]:
return []
def apply(
self,
......@@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None:
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
......@@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
......
......@@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
......@@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features
return PromptReplacementDetails(
return PromptUpdateDetails(
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens,
)
......
......@@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
import torch
import torch.nn as nn
......@@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
......@@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self,
in_features: int,
hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
......@@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
......@@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self.info._get_image_processor_kwargs(**mm_kwargs),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
......
......@@ -9,9 +9,10 @@ import copy
import math
import re
import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping,
Optional, TypedDict, Union)
from typing import Callable, List, Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs,
)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
......@@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
......@@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement(
modality="image",
target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens,
),
......
......@@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
......@@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
......@@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
......
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
......@@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription
......@@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment