Unverified Commit 7b623fca authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Check required fields before initializing field config in `DictEmbeddingItems` (#13380)

parent 238dfc8a
......@@ -184,8 +184,8 @@ llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={
mm_data = {
"image": {
"image_embeds": image_embeds,
# image_size_list is needed to calculate details of the sliced image.
"image_size_list": [image.size for image in images], # list of image sizes
# image_sizes is needed to calculate details of the sliced image.
"image_sizes": [image.size for image in images], # list of image sizes
}
}
......
......@@ -23,8 +23,8 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
import torch
from torch import nn
......@@ -122,13 +122,16 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(
data,
modality="image",
fields_config=fields_config,
required_fields={"audio_embeds"},
fields_factory=fields_factory,
)
......@@ -141,7 +144,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(
data,
fields_config=_minicpmo_field_config(data),
fields_factory=_minicpmo_field_config,
)
return super()._parse_audio_data(data)
......
......@@ -255,13 +255,16 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(
data,
modality="image",
fields_config=fields_config,
required_fields={"image_embeds", "image_sizes"},
fields_factory=fields_factory,
)
def get_image_size(self, index: int) -> ImageSize:
......@@ -274,13 +277,16 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
def __init__(
self,
data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(
data,
modality="video",
fields_config=fields_config,
required_fields={"video_embeds", "video_image_sizes"},
fields_factory=fields_factory,
)
def get_frame_size(self, index: int) -> ImageSize:
......@@ -300,7 +306,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
fields_factory=_minicpmv_field_config,
)
return super()._parse_image_data(data)
......@@ -312,7 +318,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(
data,
fields_config=_minicpmv_field_config(data),
fields_factory=_minicpmv_field_config,
)
return super()._parse_video_data(data)
......
......@@ -691,8 +691,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return DictEmbeddingItems(
data,
modality="image",
fields_config=_qwen2vl_field_config(data),
required_fields={"image_embeds", "image_grid_thw"},
fields_factory=_qwen2vl_field_config,
)
return super()._parse_image_data(data)
......@@ -705,8 +705,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return DictEmbeddingItems(
data,
modality="video",
fields_config=_qwen2vl_field_config(data),
required_fields={"video_embeds", "video_grid_thw"},
fields_factory=_qwen2vl_field_config,
)
return super()._parse_video_data(data)
......
......@@ -125,17 +125,14 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
self,
data: Mapping[str, torch.Tensor],
modality: str,
fields_config: Mapping[str, MultiModalFieldConfig],
required_fields: set[str],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None:
super().__init__(data, modality)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
missing_required_data_keys = required_fields - data.keys()
if missing_required_data_keys:
data_keys = set(data.keys())
......@@ -143,6 +140,13 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
f"but only found the following keys: {data_keys}")
raise ValueError(msg)
fields_config = fields_factory(data)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
self.fields_config = fields_config
self.required_fields = required_fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment