Unverified Commit 11b55687 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent ee484b3f
......@@ -20,67 +20,6 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][].
### Stable UUIDs for Caching (multi_modal_uuids)
When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_a = Image.open("/path/to/a.jpg")
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [img_a, img_b]},
# Provide stable IDs for caching.
# Requirements (matched by this example):
# - Include every modality present in multi_modal_data.
# - For lists, provide the same number of entries.
# - Use None to fall back to content hashing for that item.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
Using UUIDs, you can also skip sending media data entirely if you expect cache hits for respective items. Note that the request will fail if the skipped media doesn't have a corresponding UUID, or if the UUID fails to hit the cache.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [None, img_b]},
# Since img_a is expected to be cached, we can skip sending the actual
# image entirely.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
!!! warning
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
### Image Inputs
You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples:
......@@ -397,7 +336,8 @@ No manual conversion is needed - vLLM handles the channel normalization automati
### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
pass a tensor of shape `(..., hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
The exact shape depends on the model being used.
You must enable this feature via `enable_mm_embeds=True`.
......@@ -418,8 +358,7 @@ You must enable this feature via `enable_mm_embeds=True`.
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Embeddings for single image
# torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
# For most models, `image_embeds` has shape: (num_images, image_feature_size, hidden_size)
image_embeds = torch.load(...)
outputs = llm.generate({
......@@ -430,21 +369,8 @@ You must enable this feature via `enable_mm_embeds=True`.
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
```
For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings:
??? code
```python
# Construct the prompt based on your model
prompt = ...
# Embeddings for multiple images
# torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)
# Qwen2-VL
# Additional examples for models that require extra fields
llm = LLM(
"Qwen/Qwen2-VL-2B-Instruct",
limit_mm_per_prompt={"image": 4},
......@@ -452,13 +378,15 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
)
mm_data = {
"image": {
"image_embeds": image_embeds,
# Shape: (total_feature_size, hidden_size)
# total_feature_size = sum(image_feature_size for image in images)
"image_embeds": torch.load(...),
# Shape: (num_images, 3)
# image_grid_thw is needed to calculate positional encoding.
"image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3),
"image_grid_thw": torch.load(...),
}
}
# MiniCPM-V
llm = LLM(
"openbmb/MiniCPM-V-2_6",
trust_remote_code=True,
......@@ -467,20 +395,14 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
)
mm_data = {
"image": {
"image_embeds": image_embeds,
# Shape: (num_images, num_slices, hidden_size)
# num_slices can differ for each image
"image_embeds": [torch.load(...) for image in images],
# Shape: (num_images, 2)
# image_sizes is needed to calculate details of the sliced image.
"image_sizes": [image.size for image in images], # list of image sizes
"image_sizes": [image.size for image in images],
}
}
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": mm_data,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
```
For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
......@@ -501,8 +423,8 @@ You can pass pre-computed audio embeddings similar to image embeddings:
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"
# Load pre-computed audio embeddings
# torch.Tensor of shape (1, audio_feature_size, hidden_size of LM)
# Load pre-computed audio embeddings, usually with shape:
# (num_audios, audio_feature_size, hidden_size of LM)
audio_embeds = torch.load(...)
outputs = llm.generate({
......@@ -515,6 +437,67 @@ You can pass pre-computed audio embeddings similar to image embeddings:
print(generated_text)
```
### Cached Inputs
When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_a = Image.open("/path/to/a.jpg")
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [img_a, img_b]},
# Provide stable IDs for caching.
# Requirements (matched by this example):
# - Include every modality present in multi_modal_data.
# - For lists, provide the same number of entries.
# - Use None to fall back to content hashing for that item.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
Using UUIDs, you can also skip sending media data entirely if you expect cache hits for respective items. Note that the request will fail if the skipped media doesn't have a corresponding UUID, or if the UUID fails to hit the cache.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [None, img_b]},
# Since img_a is expected to be cached, we can skip sending the actual
# image entirely.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
!!! warning
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
## Online Serving
Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs users can provide to uniquely identify each media, which is used to cache the media results across requests.
......@@ -879,7 +862,11 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo
### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
pass a tensor of shape `(..., hidden_size of LM)` for each item to the corresponding field of the multi-modal dictionary.
!!! important
Unlike offline inference, the embeddings for each item must be passed separately
in order for placeholder tokens to be applied correctly by the chat template.
You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`.
......@@ -897,11 +884,6 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
```python
from vllm.utils.serial_utils import tensor2base64
image_embedding = torch.load(...)
grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct
base64_image_embedding = tensor2base64(image_embedding)
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
......@@ -912,29 +894,33 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
"type": "image_embeds",
"image_embeds": f"{base64_image_embedding}",
"image_embeds": tensor2base64(torch.load(...)), # Shape: (image_feature_size, hidden_size)
"uuid": image_url, # Optional
}
# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
# Additional examples for models that require extra fields
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
"type": "image_embeds",
"image_embeds": {
"image_embeds": f"{base64_image_embedding}", # Required
"image_grid_thw": f"{base64_image_grid_thw}", # Required by Qwen/Qwen2-VL-2B-Instruct
"image_embeds": tensor2base64(torch.load(...)), # Shape: (image_feature_size, hidden_size)
"image_grid_thw": tensor2base64(torch.load(...)), # Shape: (3,)
},
"uuid": image_url, # Optional
}
model = "openbmb/MiniCPM-V-2_6"
embeds = {
"type": "image_embeds",
"image_embeds": {
"image_embeds": f"{base64_image_embedding}", # Required
"image_sizes": f"{base64_image_sizes}", # Required by openbmb/MiniCPM-V-2_6
"image_embeds": tensor2base64(torch.load(...)), # Shape: (num_slices, hidden_size)
"image_sizes": tensor2base64(torch.load(...)), # Shape: (2,)
},
"uuid": image_url, # Optional
}
# Single image input
chat_completion = client.chat.completions.create(
messages=[
{
......@@ -954,9 +940,55 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
],
model=model,
)
# Multi image input
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
embeds,
embeds,
],
},
],
model=model,
)
# Multi image input (interleaved)
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": [
embeds,
{
"type": "text",
"text": "What's in this image?",
},
embeds,
],
},
],
model=model,
)
```
For Online Serving, you can also skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this:
### Cached Inputs
Just like with offline inference, you can skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this:
??? code
......@@ -990,13 +1022,3 @@ For Online Serving, you can also skip sending media if you expect cache hits wit
},
```
!!! note
Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`.
**Important**: The embedding shape format differs based on the number of embeddings:
- **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)`
- **Multiple embeddings**: List of 2D tensors, each of shape `(feature_size, hidden_size)`
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
......@@ -59,8 +59,10 @@ class PrithviMAE:
input_data = input_data[0]
mm_data = {
"pixel_values": input_data,
"location_coords": location_coords,
"image": {
"pixel_values": input_data,
"location_coords": location_coords,
}
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
......
......@@ -13,31 +13,10 @@ from vllm.utils.serial_utils import tensor2base64
from ...utils import RemoteOpenAIServer
def _terratorch_dummy_messages():
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
return [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": tensor2base64(pixel_values),
"location_coords": tensor2base64(location_coords),
},
}
],
}
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
)
def test_single_request(model_name: str):
def test_single_content(model_name: str):
args = [
"--runner",
"pooling",
......@@ -59,7 +38,24 @@ def test_single_request(model_name: str):
server.url_for("pooling"),
json={
"model": model_name,
"messages": _terratorch_dummy_messages(),
"messages": [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": tensor2base64(
torch.ones((6, 512, 512), dtype=torch.float16)
),
"location_coords": tensor2base64(
torch.ones((1, 2), dtype=torch.float16)
),
},
},
],
}
],
"encoding_format": "base64",
},
)
......@@ -69,3 +65,87 @@ def test_single_request(model_name: str):
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
assert len(np_response) == 524288
@pytest.mark.parametrize("model_name", ["Qwen/Qwen3-VL-2B-Instruct"])
def test_multi_content(model_name: str):
args = [
"--enforce-eager",
"--max-num-seqs",
"32",
"--max-model-len",
"8192",
"--enable-mm-embeds",
]
with RemoteOpenAIServer(model_name, args) as server:
client = server.get_client()
# Image only
chat_completion = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"image_embeds": tensor2base64(torch.zeros(220, 8192)),
"image_grid_thw": tensor2base64(
torch.tensor([1, 22, 40])
),
},
},
{
"type": "image_embeds",
"image_embeds": {
"image_embeds": tensor2base64(torch.zeros(220, 8192)),
"image_grid_thw": tensor2base64(
torch.tensor([1, 22, 40])
),
},
},
],
}
],
max_tokens=5,
)
assert chat_completion.id is not None
assert len(chat_completion.choices) == 1
# Interleaved text and image
chat_completion = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"image_embeds": tensor2base64(torch.zeros(220, 8192)),
"image_grid_thw": tensor2base64(
torch.tensor([1, 22, 40])
),
},
},
{"type": "text", "text": "OCR:"},
{
"type": "image_embeds",
"image_embeds": {
"image_embeds": tensor2base64(torch.zeros(220, 8192)),
"image_grid_thw": tensor2base64(
torch.tensor([1, 22, 40])
),
},
},
],
}
],
max_tokens=5,
)
assert chat_completion.id is not None
assert len(chat_completion.choices) == 1
......@@ -68,6 +68,16 @@ def phi3v_model_config_image_embeds():
)
@pytest.fixture(scope="function")
def qwen25omni_model_config_image_embeds():
return ModelConfig(
QWEN25OMNI_MODEL_ID,
runner="generate",
limit_mm_per_prompt={"image": 2},
enable_mm_embeds=True,
)
@pytest.fixture(scope="function")
def qwen2_audio_model_config():
return ModelConfig(
......@@ -823,7 +833,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
import torch
# Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768)
hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
audio_embedding = torch.randn(1, 128, hidden_size)
# Encode it as base64
base64_audio_embedding = tensor2base64(audio_embedding)
......@@ -865,7 +876,8 @@ async def test_parse_chat_messages_audio_embeds_async(
import torch
# Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768)
hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
audio_embedding = torch.randn(1, 128, hidden_size)
# Encode it as base64
base64_audio_embedding = tensor2base64(audio_embedding)
......@@ -908,8 +920,9 @@ def test_parse_chat_messages_multiple_image_embeds(
can be provided in a single request, similar to regular images.
"""
# Create two sample image embedding tensors
image_embedding_1 = torch.randn(256, 1024)
image_embedding_2 = torch.randn(128, 1024)
hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
image_embedding_1 = torch.randn(256, hidden_size)
image_embedding_2 = torch.randn(128, hidden_size)
# Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1)
......@@ -1022,8 +1035,9 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
"""
# Create two sample image embedding tensors
image_embedding_1 = torch.randn(200, 768)
image_embedding_2 = torch.randn(150, 768)
hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
image_embedding_1 = torch.randn(200, hidden_size)
image_embedding_2 = torch.randn(150, hidden_size)
# Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1)
......@@ -1145,13 +1159,14 @@ def test_parse_chat_messages_empty_dict_image_embeds(
def test_parse_chat_messages_multiple_dict_image_embeds(
phi3v_model_config_image_embeds,
qwen25omni_model_config_image_embeds,
):
"""Test that multiple dictionaries for image_embeds is handled without errors."""
# Create two sample image embedding tensors
batch_size = 2
image_embedding_1 = torch.randn(batch_size, 256, 1024)
image_embedding_2 = torch.randn(batch_size, 3)
hidden_size = qwen25omni_model_config_image_embeds.get_inputs_embeds_size()
image_embeds = torch.randn(batch_size * 220, hidden_size)
image_grid_thw = torch.tensor([[1, 22, 40] for _ in range(batch_size)])
conversation, mm_data, mm_uuids = parse_chat_messages(
[
......@@ -1161,18 +1176,20 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
{
"type": "image_embeds",
"image_embeds": {
"image_embedding_1": tensor2base64(p),
"image_embedding_2": tensor2base64(i),
"image_embeds": tensor2base64(embeds),
"image_grid_thw": tensor2base64(grid_thw),
},
}
for p, i in zip(image_embedding_1, image_embedding_2)
for embeds, grid_thw in zip(
image_embeds.chunk(batch_size), image_grid_thw
)
]
+ [
{"type": "text", "text": "Describe these two images."},
],
}
],
phi3v_model_config_image_embeds,
qwen25omni_model_config_image_embeds,
content_format="string",
)
......@@ -1180,7 +1197,8 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
assert conversation == [
{
"role": "user",
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
"content": "<|vision_start|><|IMAGE|><|vision_end|>\n"
"<|vision_start|><|IMAGE|><|vision_end|>\nDescribe these two images.",
}
]
......@@ -1191,10 +1209,10 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
assert len(mm_data["image"]) == batch_size
# Verify each embedding has the correct shape
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
assert isinstance(mm_data["image"]["image_embeds"], torch.Tensor)
assert mm_data["image"]["image_embeds"].shape == image_embeds.shape
assert isinstance(mm_data["image"]["image_grid_thw"], torch.Tensor)
assert mm_data["image"]["image_grid_thw"].shape == image_grid_thw.shape
# Verify UUIDs (None since we didn't provide any)
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
......
......@@ -7,14 +7,6 @@ import torch
from ....conftest import VllmRunner
def generate_test_mm_data():
mm_data = {
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return mm_data
def _run_test(
vllm_runner: type[VllmRunner],
model: str,
......@@ -23,7 +15,12 @@ def _run_test(
{
# This model deals with no text input
"prompt_token_ids": [1],
"multi_modal_data": generate_test_mm_data(),
"multi_modal_data": {
"image": {
"pixel_values": torch.ones((6, 512, 512), dtype=torch.float16),
"location_coords": torch.ones((1, 2), dtype=torch.float16),
}
},
}
for _ in range(10)
]
......
......@@ -349,8 +349,10 @@ class PrithviMultimodalDataProcessor(IOProcessor):
{
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
"image": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
}
},
}
)
......
......@@ -5,14 +5,8 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
MultiModalConfig,
VllmConfig,
)
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor
......@@ -21,55 +15,26 @@ stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _mock_input_processor(
monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
def _build_input_processor(
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> InputProcessor:
"""
Create a Processor instance with minimal configuration suitable for unit
tests without accessing external resources.
"""
monkeypatch.setattr(
ModelConfig, "try_get_generation_config", lambda self: {}, raising=True
)
monkeypatch.setattr(
ModelConfig, "__post_init__", lambda self, *args: None, raising=True
)
monkeypatch.setattr(
ModelConfig,
"verify_with_parallel_config",
lambda self, parallel_config: None,
raising=True,
)
monkeypatch.setattr(
MultiModalRegistry,
"processor_cache_from_config",
lambda self, vllm_config: None,
raising=True,
)
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
model_config = ModelConfig(
tokenizer="dummy",
model="Qwen/Qwen2.5-VL-3B-Instruct",
skip_tokenizer_init=True,
max_model_len=128,
mm_processor_cache_gb=mm_cache_gb,
generation_config="vllm",
)
model_config.runner_type = "generate"
model_config.multimodal_config = MultiModalConfig(mm_processor_cache_gb=mm_cache_gb)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
device_config=DeviceConfig(device="cpu"),
)
return InputProcessor(vllm_config)
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
input_processor = _mock_input_processor(monkeypatch)
def test_multi_modal_uuids_length_mismatch_raises():
input_processor = _build_input_processor()
prompt = {
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
......@@ -78,7 +43,7 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError, match="must have same length as data"):
with pytest.raises(ValueError, match="must have same length as"):
input_processor.process_inputs(
request_id="req-1",
prompt=prompt, # type: ignore[arg-type]
......@@ -86,21 +51,21 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
)
def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
input_processor = _mock_input_processor(monkeypatch)
def test_multi_modal_uuids_missing_modality_raises():
input_processor = _build_input_processor()
prompt = {
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
# Two modalities provided in data
"multi_modal_data": {
"image": [cherry_pil_image],
"video": [baby_reading_np_ndarrays],
"video": None,
},
# Only image uuids provided; video missing should raise
"multi_modal_uuids": {"image": ["hash_cherry"]},
}
with pytest.raises(ValueError, match="must be provided if multi_modal_data"):
with pytest.raises(ValueError, match="is empty but .* is missing"):
input_processor.process_inputs(
request_id="req-2",
prompt=prompt, # type: ignore[arg-type]
......@@ -119,8 +84,7 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
input_processor = _mock_input_processor(
monkeypatch,
input_processor = _build_input_processor(
mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching,
)
......@@ -163,8 +127,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs.
input_processor = _mock_input_processor(
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
input_processor = _build_input_processor(
mm_cache_gb=0.0, enable_prefix_caching=False
)
captured: dict[str, MultiModalUUIDDict] = {}
......@@ -180,12 +144,12 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
prompt = {
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
"multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays,
"video": [baby_reading_np_ndarrays],
},
"multi_modal_uuids": mm_uuids,
}
......@@ -197,16 +161,15 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
# Expect request-id-based overrides are passed through
mm_uuids = captured["mm_uuids"]
assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][0].endswith("-0")
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[
"image"
][1].endswith("-1")
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[
"video"
][0].endswith("-0")
assert captured["mm_uuids"]["image"][0].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][0].endswith("-0")
assert captured["mm_uuids"]["image"][1].startswith(
f"{request_id}-image-"
) and captured["mm_uuids"]["image"][1].endswith("-1")
assert captured["mm_uuids"]["video"][0].startswith(
f"{request_id}-video-"
) and captured["mm_uuids"]["video"][0].endswith("-0")
This diff is collapsed.
......@@ -1574,10 +1574,6 @@ class LLM:
try:
for i, prompt in enumerate(it):
if isinstance(prompt, dict):
self._validate_mm_data_and_uuids(
prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
)
request_id = self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
......@@ -1593,54 +1589,6 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e
def _validate_mm_data_and_uuids(
self,
multi_modal_data: Any | None, # MultiModalDataDict
multi_modal_uuids: Any | None, # MultiModalUUIDDict
):
"""
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
"""
if multi_modal_data is None:
return
for modality, data in multi_modal_data.items():
if isinstance(data, list):
for i, d in enumerate(data):
if d is None:
if (
multi_modal_uuids is None
or modality not in multi_modal_uuids
or multi_modal_uuids[ # noqa: E501
modality
]
is None
):
raise ValueError(
f"Multi-modal data for {modality} is None "
f"but UUID is not provided"
)
else:
if (
len(multi_modal_uuids[modality]) <= i
or multi_modal_uuids[modality][i] is None
):
raise ValueError(
f"Multi-modal data for {modality} is None "
f"but UUID is not provided"
)
else:
if data is None and (
multi_modal_uuids is None
or modality not in multi_modal_uuids
or multi_modal_uuids[modality] is None
):
raise ValueError(
f"Multi-modal data for {modality} is None"
f" but UUID is not provided"
)
def _process_inputs(
self,
request_id: str,
......
......@@ -19,7 +19,7 @@ from vllm.entrypoints.chat_utils import (
)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
......@@ -95,7 +95,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None]:
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
mm_tracker = MultiModalItemTracker(model_config)
content_1 = _parse_score_content(data_1, mm_tracker)
......@@ -109,8 +109,9 @@ def parse_score_data(
prompt_1 = ensure_str(content_1)
prompt_2 = ensure_str(content_2)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt_1, prompt_2, mm_tracker.all_mm_data()
return prompt_1, prompt_2, mm_items, mm_uuids
def _parse_score_content(
......@@ -187,7 +188,7 @@ def get_score_prompt(
data_2: str | ScoreContentPartParam,
score_template: str | None = None,
) -> tuple[str, TokensPrompt]:
prompt_1, prompt_2, mm_data = parse_score_data(
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
data_1,
data_2,
model_config,
......@@ -248,6 +249,9 @@ def get_score_prompt(
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
return full_prompt, engine_prompt
......
......@@ -113,7 +113,11 @@ from .qwen2_5_vl import (
Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs,
)
from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo
from .qwen2_vl import (
Qwen2VLMultiModalDataParser,
Qwen2VLProcessingInfo,
_create_qwen2vl_field_factory,
)
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import (
AutoWeightsLoader,
......@@ -985,28 +989,9 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
return _create_qwen2vl_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
def _get_prompt_updates(
self,
......
......@@ -18,7 +18,7 @@
"""Wrapper around `Terratorch` models"""
from collections import OrderedDict
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Any
import torch
......@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import attn_type
......@@ -69,28 +70,21 @@ from .interfaces_base import attn_type
logger = init_logger(__name__)
def _terratorch_field_names(pretrained_cfg: dict):
input_definition = InputDefinition(**pretrained_cfg["input"])
def _terratorch_field_names(input_definition: InputDefinition):
return set(input_definition.data.keys())
def _terratorch_field_factory(
pretrained_cfg: dict,
) -> Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
]:
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
input_definition = InputDefinition(**pretrained_cfg["input"])
fields = {}
for input_name, input in input_definition.data.items():
def _terratorch_field_factory(input_definition: InputDefinition):
def _terratorch_field_config(
hf_inputs: Mapping[str, torch.Tensor],
) -> Mapping[str, MultiModalFieldConfig]:
fields = dict[str, MultiModalFieldConfig]()
for name, input in input_definition.data.items():
modality = "image"
if input.type == InputTypeEnum.tensor:
fields[input_name] = "image"
fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
return {
field_name: MultiModalFieldConfig.batched(modality=field_modality)
for field_name, field_modality in fields.items()
}
return fields
return _terratorch_field_config
......@@ -130,26 +124,31 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, pretrained_cfg: dict, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
def __init__(self, input_definition: InputDefinition, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_definition = input_definition
def _parse_image_data(
self,
data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
return DictEmbeddingItems(
data,
modality="image",
required_fields=terratorch_fields,
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
required_fields=_terratorch_field_names(self.input_definition),
fields_factory=_terratorch_field_factory(self.input_definition),
)
return super()._parse_image_data(data)
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
if "image" not in mm_data:
mm_data = {"image": mm_data}
return super().parse_mm_data(mm_data)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__(
......@@ -159,18 +158,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
*,
cache: MultiModalProcessorOnlyCache | None = None,
) -> None:
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
self._input_definition = InputDefinition(**pretrained_cfg["input"])
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser:
return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)
return TerratorchMultiModalDataParser(self._input_definition)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
return _terratorch_field_factory(self._input_definition)(hf_inputs)
def _get_prompt_updates(
self,
......@@ -188,23 +189,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if "image" in mm_data:
image_data = mm_data["image"]
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
else:
image_data = mm_data
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
mm_data = {"image": image_data}
mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {}
mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data)
mm_processed_data = BatchFeature(
mm_data.get("image", mm_data), tensor_type="pt"
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
......@@ -272,9 +266,15 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
model_output = self.inference_runner.forward(**kwargs)
input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
model_output = self.inference_runner.forward(**batched_kwargs).output
return model_output.output
# The leading dimension of hidden states needs to equal input length
return model_output.expand(
input_len, *(-1 for _ in range(model_output.ndim - 1))
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_list = []
......
......@@ -13,7 +13,7 @@ def random_uuid() -> str:
def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: list[int] | None,
prompt_token_ids: list[int] | torch.Tensor | None,
prompt_embeds: torch.Tensor | None,
) -> int:
"""Calculate the request length (in number of tokens) give either
......
......@@ -8,14 +8,24 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
......@@ -188,7 +198,66 @@ class InputProcessor:
self._validate_sampling_params(params)
self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
return mm_processor.data_parser.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str):
prompt = TextPrompt(prompt=prompt)
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
if not mm_data and not mm_uuids:
return
mm_data_parsed = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
mm_uuids_parsed = {
k: [v] if isinstance(v, str) else v
for k, v in mm_uuids.items()
if v is not None
}
# NOTE: Include the keys corresponding to `None`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = cast(
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
)
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
else:
if len(uuid_items) == 0:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
......@@ -196,55 +265,13 @@ class InputProcessor:
auto-hashed downstream.
"""
def _validate_single_prompt(single_prompt: dict | str) -> None:
if not isinstance(single_prompt, dict):
return
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
import torch
def _get_len(items: object):
if isinstance(items, dict): # Embedding inputs
return _get_len(next(iter(items.values()))) if items else 1
if isinstance(items, list):
return len(items)
if isinstance(items, torch.Tensor):
# To keep backwards compatibility for single item embedding input
return 1 if getattr(items, "_is_single_item", False) else len(items)
return 1
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = _get_len(items)
uuid_len = _get_len(mm_uuids[modality])
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} "
"must have same length as data: got "
f"{uuid_len} uuids vs {data_len} items."
)
else:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} must "
"be provided if multi_modal_data is provided."
)
if is_explicit_encoder_decoder_prompt(prompt):
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
# Handle explicit encoder/decoder prompts or singleton prompt
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt")
if enc is not None:
_validate_single_prompt(cast(dict | str, enc))
if dec is not None:
_validate_single_prompt(cast(dict | str, dec))
if (dec_prompt := prompt["decoder_prompt"]) is not None:
self._validate_singleton_mm_uuids(dec_prompt)
else:
_validate_single_prompt(prompt) # type: ignore[arg-type]
self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None:
......@@ -379,6 +406,20 @@ class InputProcessor:
# roundtrip serialization/deserialization won't fail.
params.structured_outputs.__post_init__()
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, str):
return None
return prompt.get("multi_modal_data") # type: ignore[return-value]
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
if is_explicit_encoder_decoder_prompt(prompt):
return self._extract_singleton_mm_data(prompt["encoder_prompt"])
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
......@@ -391,31 +432,18 @@ class InputProcessor:
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
mm_data = self._extract_mm_data(prompt)
if not mm_data:
return None
mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items():
# Hash each item for embedding inputs.
n = (
len(data)
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
else 1
)
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
mm_items = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
return {
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
def _get_mm_identifier(
self,
......@@ -494,7 +522,7 @@ class InputProcessor:
else:
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_multi_modal_uuids(prompt)
self._validate_mm_uuids(prompt)
if isinstance(prompt, dict):
mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment