Unverified Commit a4ce74c1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Use shared field to pass token ids to model

parent 3b2005e1
......@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = [image_token_id
] * len(image_data)
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
......@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
......
......@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
......@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
@dataclass(frozen=True)
class MultiModalFieldElem:
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
field: "BaseMultiModalField"
"""
Represents a keyword argument corresponding to a multi-modal item
in :class:`MultiModalKwargs`.
"""
modality: str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key: str
"""
The key of this field in :class:`MultiModalKwargs`,
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
"""
The tensor data of this field in :class:`MultiModalKwargs`,
i.e. the value of the keyword argument to be passed to the model.
"""
field: "BaseMultiModalField"
"""
Defines how to combine the tensor data of this field with others
in order to batch multi-modal items together for model inference.
"""
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return (self.field == other.field
and nested_tensors_equal(self.data, other.data))
return ((self.modality, self.key) == (other.modality, other.key)
and nested_tensors_equal(self.data, other.data)
and type(self.field) == type(other.field)) # noqa: E721
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
key: str
modality: str
"""
Defines how to interpret tensor data belonging to a keyword argument in
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
"""
def _field_factory(self, *, modality: str, key: str):
f = partial(
MultiModalFieldElem,
modality=modality,
key=key,
field=self,
)
# Allow passing data as positional argument
def factory(data: NestedTensors) -> MultiModalFieldElem:
return f(data=data)
return factory
@abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
"""
Construct :class:`MultiModalFieldElem` instances to represent
the provided data.
This is the inverse of :meth:`reduce_data`.
"""
raise NotImplementedError
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
return MultiModalFieldElem(self, data)
@abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields = [item.field for item in batch]
if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}")
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
"""
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
data = self._reduce_data([item.data for item in batch])
This is the inverse of :meth:`build_elems`.
"""
field_types = [type(item.field) for item in elems]
if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}")
return self._build_elem(data)
return self._reduce_data([item.data for item in elems])
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by indexing into the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.batched`
"""
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
return [self._build_elem(item) for item in batch]
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(item) for item in data]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
......@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by slicing along the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.flat`
:func:`MultiModalFieldConfig.flat_from_sizes`
"""
slices: Sequence[slice]
def build_elems(
self,
batch: NestedTensors,
slices: Sequence[slice],
) -> list[MultiModalFieldElem]:
return [self._build_elem(batch[slice_]) for slice_ in slices]
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data[s]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
......@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
return [e for elem in batch for e in elem]
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
"""
See also:
:func:`MultiModalFieldConfig.shared`
"""
batch_size: int
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data)] * self.batch_size
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
return batch[0]
class MultiModalFieldConfig:
@staticmethod
def batched(modality: str):
"""
Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
Example:
.. code-block::
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
"""
return MultiModalFieldConfig(
field_cls=MultiModalBatchedField,
field=MultiModalBatchedField(),
modality=modality,
)
@staticmethod
def flat(modality: str, slices: Sequence[slice]):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, a slice that is used to extract
the data corresponding to it.
Example:
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
"""
return MultiModalFieldConfig(
field_cls=MultiModalFlatField,
field=MultiModalFlatField(slices=slices),
modality=modality,
slices=slices,
)
@staticmethod
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, the size of the slice that
is used to extract the data corresponding to it.
Example:
.. code-block::
Given:
size_per_item: [3, 4, 2]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
See also:
:func:`MultiModalFieldConfig.flat`
"""
slice_idxs = [0, *accumulate(size_per_item)]
slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
......@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
return MultiModalFieldConfig.flat(modality, slices)
def __init__(
self,
field_cls: type[BaseMultiModalField],
modality: str,
**field_config: Any,
) -> None:
@staticmethod
def shared(modality: str, batch_size: int):
"""
Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data.
This means that the data is the same for each element in the batch.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
batch_size: The number of multi-modal items which share this data.
Example:
.. code-block::
Given:
batch_size: 4
Input:
Data: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
"""
return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size),
modality=modality,
)
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
super().__init__()
self.field_cls = field_cls
self.field = field
self.modality = modality
self.field_config = field_config
def build_elems(
self,
key: str,
batch: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field = self.field_cls(key=key, modality=self.modality)
return field.build_elems(batch, **self.field_config) # type: ignore
return self.field.build_elems(self.modality, key, batch)
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
......@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@staticmethod
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
return MultiModalKwargsItem({elem.key: elem for elem in elems})
@property
def modality(self) -> str:
modalities = {elem.field.modality for elem in self.data.values()}
modalities = {elem.modality for elem in self.data.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
return next(iter(modalities))
......@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce(elems).data
key: elems[0].field.reduce_data(elems)
for key, elems in elems_by_key.items() if len(elems) > 0
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment