Unverified commit 8c9da6be, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Simplify mm processing cache (#22457)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 399d2a10
...@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_update: bool, enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], BatchFeature, bool]:
""" """
Qwen2.5-Omni reimplements this function to handle text only. Qwen2.5-Omni reimplements this function to handle text only.
""" """
...@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
else: else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only( mm_processed_data = self._apply_hf_processor_mm_only(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, False return prompt_ids, mm_processed_data, False
def _apply_hf_processor_mm_only( def _apply_hf_processor_mm_only(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs: ) -> BatchFeature:
""" """
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
""" """
...@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
assert "audio" in mm_counts assert "audio" in mm_counts
mm_counts["audio"] -= mm_counts["video"] mm_counts["audio"] -= mm_counts["video"]
_, mm_kwargs, _ = self._apply_hf_processor_text_mm( _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return mm_kwargs return mm_processed_data
def _validate_mm_placeholders( def _validate_mm_placeholders(
self, self,
......
...@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union ...@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
import regex as re import regex as re
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel from transformers import (AutoModel, BatchFeature, PretrainedConfig,
PreTrainedModel)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention from vllm.attention import Attention
...@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
): ) -> tuple[list[int], BatchFeature, bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data Apply the HF processor on the prompt text and multi-modal data
together. together.
......
...@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext ...@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache from .cache import MultiModalCache
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
...@@ -887,120 +887,19 @@ def find_mm_placeholders( ...@@ -887,120 +887,19 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]
class ProcessingCacheItem(NamedTuple):
key: str
value: MultiModalKwargsItem
class ProcessingCache(MultiModalCache): class ProcessingCache(MultiModalCache):
def __init__( def __init__(self, capacity_gb: float) -> None:
self,
capacity_gb: float,
*,
debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
super().__init__() super().__init__()
self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self.debug_cache_hits = 0
self.debug_cache_total = 0
self._cache = self.get_lru_cache(
capacity_gb,
MultiModalKwargsItem,
debug=bool(debug_cache_hit_ratio_steps),
)
def _maybe_log_cache_stats(self) -> None: self.get = self._cache.get
steps = self.debug_cache_hit_ratio_steps self.put = self._cache.put
if not steps: self.reset = self._cache.clear
return
total = self.debug_cache_total
if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
self.debug_cache_hits / total)
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
self._cache.currsize / GiB_bytes,
self._cache.maxsize / GiB_bytes)
def get(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> Optional[MultiModalKwargsItem]:
"""
Get a processed multi-modal item from the cache
according to its dependencies, including:
- The model ID
- The modality of the item
- The original data item passed to the HF processor
- The configuration options of the HF processor
"""
self._maybe_log_cache_stats()
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
if self.debug_cache_hit_ratio_steps: _CacheItemOrHash = Union[MultiModalKwargsItem, str]
if cache_key in self._cache:
self.debug_cache_hits += 1
self.debug_cache_total += 1
return self._cache.get(cache_key)
def get_item(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> ProcessingCacheOptionalItem:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return ProcessingCacheOptionalItem(
key=cache_key,
value=self._cache.get(cache_key),
)
def put(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
output_kwargs: MultiModalKwargsItem,
) -> None:
"""
Put a processed multi-modal item into the cache
according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
"""
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache[cache_key] = output_kwargs
def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value
def reset(self) -> bool:
self._cache.clear()
return True
class BaseProcessingInfo: class BaseProcessingInfo:
...@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], "BatchFeature", bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data Apply the HF processor on the prompt text and multi-modal data
together. together.
...@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids, = processed_data.pop("input_ids").tolist() prompt_ids, = processed_data.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
is_update_applied = self._hf_processor_applies_updates( is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=mm_items, mm_items=mm_items,
...@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, is_update_applied return prompt_ids, processed_data, is_update_applied
def _apply_hf_processor_text_only( def _apply_hf_processor_text_only(
self, prompt_text: str, self,
tokenization_kwargs: Mapping[str, object]) -> list[int]: prompt_text: str,
tokenization_kwargs: Mapping[str, object],
) -> list[int]:
""" """
Apply the HF processor on the prompt text only. Apply the HF processor on the prompt text only.
...@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs: ) -> "BatchFeature":
""" """
Apply the HF processor on the multi-modal data only. Apply the HF processor on the multi-modal data only.
...@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
mm_counts = mm_items.get_all_counts() mm_counts = mm_items.get_all_counts()
_, mm_kwargs, _ = self._apply_hf_processor_text_mm( _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return mm_kwargs return mm_processed_data
def _apply_hf_processor_main( def _apply_hf_processor_main(
self, self,
...@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_update: bool, enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], "BatchFeature", bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data. Apply the HF processor on the prompt text and multi-modal data.
...@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else: else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only( mm_processed_data = self._apply_hf_processor_mm_only(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, False return prompt_ids, mm_processed_data, False
def _get_cache_missing_items( def _get_cache_missing_items(
self, self,
cache: ProcessingCache, cache: ProcessingCache,
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_hashes: MultiModalHashes,
tokenization_kwargs: Mapping[str, object], ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[ mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
str, list[object]]]: modality: [(h if (v := cache.get(h)) is None else v)
model_id = self.info.model_id for h in hashes]
for modality, hashes in mm_hashes.items()
mm_cache_items = {
modality: [
cache.get_item(
model_id, modality, item,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
for item in items
]
for modality, items in mm_data_items.items()
} }
mm_missing_idxs = { mm_missing_idxs = {
modality: [ modality: [
idx for idx, item in enumerate(cache_items) idx for idx, item_or_hash in enumerate(items_or_hashes)
if item.value is None if isinstance(item_or_hash, str)
] ]
for modality, cache_items in mm_cache_items.items() for modality, items_or_hashes in mm_cache_items_or_hashes.items()
} }
mm_missing_data = { mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs] modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items() for modality, idxs in mm_missing_idxs.items()
} }
return mm_cache_items, mm_missing_data return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)
def _hash_mm_items( def _hash_mm_items(
self, mm_items: MultiModalDataItems, self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes: tokenization_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1).""" """Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id model_id = self.info.model_id
...@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _merge_mm_kwargs( def _merge_mm_kwargs(
self, self,
cache: ProcessingCache, cache: ProcessingCache,
mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]], mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_data: dict[str, list[object]],
mm_missing_kwargs: MultiModalKwargs, mm_missing_kwargs: MultiModalKwargs,
) -> dict[str, list[ProcessingCacheItem]]: ) -> dict[str, list[MultiModalKwargsItem]]:
mm_missing_next_idx = {modality: 0 for modality in mm_missing_data} mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[ProcessingCacheItem]](list) merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
for modality, cache_items in mm_cache_items.items(): for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for cache_item in cache_items: for item_or_hash in items_or_hashes:
if cache_item.value is None: if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs.get_item( kw_item = mm_missing_kwargs.get_item(
modality, modality,
mm_missing_next_idx[modality], mm_missing_next_idx[modality],
) )
cache_item_new = ProcessingCacheItem( cache.put(item_or_hash, kw_item)
key=cache_item.key,
value=kw_item,
)
cache.put_item(cache_item_new)
mm_missing_next_idx[modality] += 1 mm_missing_next_idx[modality] += 1
else: else:
cache_item_new = ProcessingCacheItem( kw_item = item_or_hash
key=cache_item.key,
value=cache_item.value,
)
merged_items[modality].append(cache_item_new) merged_items[modality].append(kw_item)
return dict(merged_items) return dict(merged_items)
...@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
( (
prompt_ids, prompt_ids,
mm_kwargs, mm_processed_data,
is_update_applied, is_update_applied,
) = self._apply_hf_processor_main( ) = self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
...@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=True, enable_hf_prompt_update=True,
) )
mm_kwargs = MultiModalKwargs.from_hf_inputs(
mm_processed_data,
self._get_mm_fields_config(mm_processed_data,
hf_processor_mm_kwargs),
)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs) tokenization_kwargs)
if return_mm_hashes else None) if return_mm_hashes else None)
...@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
( (
mm_cache_items, mm_cache_items_or_hashes,
mm_missing_data, mm_missing_data_items,
) = self._get_cache_missing_items( ) = self._get_cache_missing_items(
cache=cache, cache=cache,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, mm_hashes=mm_hashes,
tokenization_kwargs=tokenization_kwargs,
) )
mm_hashes_to_return = mm_hashes if return_mm_hashes else None
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal # so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items # items are combined with the cached multimodal items
( (
prompt_ids, prompt_ids,
mm_missing_kwargs, mm_missing_processed_data,
is_update_applied, is_update_applied,
) = self._apply_hf_processor_main( ) = self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=self._to_mm_items(mm_missing_data), mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False, enable_hf_prompt_update=False,
) )
mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
mm_missing_processed_data,
self._get_mm_fields_config(mm_missing_processed_data,
hf_processor_mm_kwargs),
)
mm_cache_items_merged = self._merge_mm_kwargs( mm_cache_items_merged = self._merge_mm_kwargs(
cache, cache,
mm_cache_items=mm_cache_items, mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_data=mm_missing_data,
mm_missing_kwargs=mm_missing_kwargs, mm_missing_kwargs=mm_missing_kwargs,
) )
mm_kwargs = MultiModalKwargs.from_items([ mm_kwargs = MultiModalKwargs.from_items([
item.value for cache_items in mm_cache_items_merged.values() item for cache_items in mm_cache_items_merged.values()
for item in cache_items for item in cache_items
]) ])
mm_hashes = { return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
modality: [item.key for item in cache_items]
for modality, cache_items in mm_cache_items_merged.items()
} if return_mm_hashes else None
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
def _bind_and_group_updates( def _bind_and_group_updates(
self, self,
......
...@@ -312,25 +312,25 @@ class MsgpackDecoder: ...@@ -312,25 +312,25 @@ class MsgpackDecoder:
return arr.view(torch_dtype).view(shape) return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = [] return [self._decode_mm_item(v) for v in obj]
for item in obj:
elems = [] def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
for v in item: return MultiModalKwargsItem.from_elems(
v["data"] = self._decode_nested_tensors(v["data"]) [self._decode_mm_field_elem(v) for v in obj])
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig # Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = v["field"] factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
factory_meth_name)
# Special case: decode the union "slices" field of # Special case: decode the union "slices" field of
# MultiModalFlatField # MultiModalFlatField
if factory_meth_name == "flat": if factory_meth_name == "flat":
field_args[0] = self._decode_nested_slices(field_args[0]) field_args[0] = self._decode_nested_slices(field_args[0])
v["field"] = factory_meth(None, *field_args).field obj["field"] = factory_meth(None, *field_args).field
elems.append(MultiModalFieldElem(**v)) return MultiModalFieldElem(**obj)
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
return decoded_items
def _decode_nested_tensors(self, obj: Any) -> NestedTensors: def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)): if isinstance(obj, (int, float)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment