Unverified Commit 7b623fca authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Check required fields before initializing field config in `DictEmbeddingItems` (#13380)

parent 238dfc8a
...@@ -184,8 +184,8 @@ llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={ ...@@ -184,8 +184,8 @@ llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={
mm_data = { mm_data = {
"image": { "image": {
"image_embeds": image_embeds, "image_embeds": image_embeds,
# image_size_list is needed to calculate details of the sliced image. # image_sizes is needed to calculate details of the sliced image.
"image_size_list": [image.size for image in images], # list of image sizes "image_sizes": [image.size for image in images], # list of image sizes
} }
} }
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial from functools import partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Tuple, TypedDict, Union) Optional, Set, Tuple, TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -122,13 +122,16 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): ...@@ -122,13 +122,16 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
def __init__( def __init__(
self, self,
data: Mapping[str, torch.Tensor], data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig], fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None: ) -> None:
super().__init__( super().__init__(
data, data,
modality="image", modality="image",
fields_config=fields_config,
required_fields={"audio_embeds"}, required_fields={"audio_embeds"},
fields_factory=fields_factory,
) )
...@@ -141,7 +144,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): ...@@ -141,7 +144,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems( return MiniCPMOAudioEmbeddingItems(
data, data,
fields_config=_minicpmo_field_config(data), fields_factory=_minicpmo_field_config,
) )
return super()._parse_audio_data(data) return super()._parse_audio_data(data)
......
...@@ -255,13 +255,16 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): ...@@ -255,13 +255,16 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
def __init__( def __init__(
self, self,
data: Mapping[str, torch.Tensor], data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig], fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None: ) -> None:
super().__init__( super().__init__(
data, data,
modality="image", modality="image",
fields_config=fields_config,
required_fields={"image_embeds", "image_sizes"}, required_fields={"image_embeds", "image_sizes"},
fields_factory=fields_factory,
) )
def get_image_size(self, index: int) -> ImageSize: def get_image_size(self, index: int) -> ImageSize:
...@@ -274,13 +277,16 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): ...@@ -274,13 +277,16 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
def __init__( def __init__(
self, self,
data: Mapping[str, torch.Tensor], data: Mapping[str, torch.Tensor],
fields_config: Mapping[str, MultiModalFieldConfig], fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None: ) -> None:
super().__init__( super().__init__(
data, data,
modality="video", modality="video",
fields_config=fields_config,
required_fields={"video_embeds", "video_image_sizes"}, required_fields={"video_embeds", "video_image_sizes"},
fields_factory=fields_factory,
) )
def get_frame_size(self, index: int) -> ImageSize: def get_frame_size(self, index: int) -> ImageSize:
...@@ -300,7 +306,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -300,7 +306,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems( return MiniCPMVImageEmbeddingItems(
data, data,
fields_config=_minicpmv_field_config(data), fields_factory=_minicpmv_field_config,
) )
return super()._parse_image_data(data) return super()._parse_image_data(data)
...@@ -312,7 +318,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -312,7 +318,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems( return MiniCPMVVideoEmbeddingItems(
data, data,
fields_config=_minicpmv_field_config(data), fields_factory=_minicpmv_field_config,
) )
return super()._parse_video_data(data) return super()._parse_video_data(data)
......
...@@ -691,8 +691,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -691,8 +691,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return DictEmbeddingItems( return DictEmbeddingItems(
data, data,
modality="image", modality="image",
fields_config=_qwen2vl_field_config(data),
required_fields={"image_embeds", "image_grid_thw"}, required_fields={"image_embeds", "image_grid_thw"},
fields_factory=_qwen2vl_field_config,
) )
return super()._parse_image_data(data) return super()._parse_image_data(data)
...@@ -705,8 +705,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser): ...@@ -705,8 +705,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
return DictEmbeddingItems( return DictEmbeddingItems(
data, data,
modality="video", modality="video",
fields_config=_qwen2vl_field_config(data),
required_fields={"video_embeds", "video_grid_thw"}, required_fields={"video_embeds", "video_grid_thw"},
fields_factory=_qwen2vl_field_config,
) )
return super()._parse_video_data(data) return super()._parse_video_data(data)
......
...@@ -125,17 +125,14 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], ...@@ -125,17 +125,14 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
self, self,
data: Mapping[str, torch.Tensor], data: Mapping[str, torch.Tensor],
modality: str, modality: str,
fields_config: Mapping[str, MultiModalFieldConfig],
required_fields: set[str], required_fields: set[str],
fields_factory: Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
],
) -> None: ) -> None:
super().__init__(data, modality) super().__init__(data, modality)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
missing_required_data_keys = required_fields - data.keys() missing_required_data_keys = required_fields - data.keys()
if missing_required_data_keys: if missing_required_data_keys:
data_keys = set(data.keys()) data_keys = set(data.keys())
...@@ -143,6 +140,13 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], ...@@ -143,6 +140,13 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
f"but only found the following keys: {data_keys}") f"but only found the following keys: {data_keys}")
raise ValueError(msg) raise ValueError(msg)
fields_config = fields_factory(data)
missing_required_fields = required_fields - fields_config.keys()
if missing_required_fields:
fields = set(fields_config.keys())
msg = f"{required_fields=} should be a subset of {fields=}"
raise ValueError(msg)
self.fields_config = fields_config self.fields_config = fields_config
self.required_fields = required_fields self.required_fields = required_fields
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment