Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
......@@ -9,13 +9,15 @@ from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
import numpy as np
import torch
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict, VideoItem)
ImageItem, ModalityData, MultiModalDataDict,
MultiModalFieldConfig, MultiModalKwargs, VideoItem)
_T = TypeVar("_T")
_I = TypeVar("_I")
......@@ -111,6 +113,64 @@ class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
return len(self.get(item_idx))
class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
Mapping[str, torch.Tensor]]):
"""
Base class for data items that are expressed as a dictionary of tensors.
Usually, the dictionary keys correspond to the outputs of HF processor.
"""
def __init__(
self,
data: Mapping[str, torch.Tensor],
modality: str,
required_fields: set[str],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(data, modality)
missing_required_data_keys = required_fields - data.keys()
if missing_required_data_keys:
data_keys = set(data.keys())
msg = (f"The data should contain the fields: {required_fields}, "
f"but only found the following keys: {data_keys}")
raise ValueError(msg)
fields_config = fields_factory(data)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
self.fields_config = fields_config
self.required_fields = required_fields
self._kwargs = MultiModalKwargs.from_hf_inputs(
BatchFeature(dict(data)),
fields_config,
)
def get_count(self) -> int:
return self._kwargs.get_item_count(self.modality)
def get(self, index: int) -> Mapping[str, torch.Tensor]:
return {
k: v.data
for k, v in self._kwargs.get_item(self.modality, index).items()
}
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
def __init__(self, data: Sequence[HfAudioItem]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment