Unverified commit 11b55687, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ee484b3f
...@@ -20,67 +20,6 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: ...@@ -20,67 +20,6 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace. - `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. - `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][].
### Stable UUIDs for Caching (multi_modal_uuids)
When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_a = Image.open("/path/to/a.jpg")
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [img_a, img_b]},
# Provide stable IDs for caching.
# Requirements (matched by this example):
# - Include every modality present in multi_modal_data.
# - For lists, provide the same number of entries.
# - Use None to fall back to content hashing for that item.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
Using UUIDs, you can also skip sending media data entirely if you expect cache hits for the respective items. Note that the request will fail if a skipped media item doesn't have a corresponding UUID, or if its UUID fails to hit the cache.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [None, img_b]},
# Since img_a is expected to be cached, we can skip sending the actual
# image entirely.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
!!! warning
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
### Image Inputs ### Image Inputs
You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples:
...@@ -397,7 +336,8 @@ No manual conversion is needed - vLLM handles the channel normalization automati ...@@ -397,7 +336,8 @@ No manual conversion is needed - vLLM handles the channel normalization automati
### Embedding Inputs ### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. pass a tensor of shape `(..., hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
The exact shape depends on the model being used.
You must enable this feature via `enable_mm_embeds=True`. You must enable this feature via `enable_mm_embeds=True`.
...@@ -418,8 +358,7 @@ You must enable this feature via `enable_mm_embeds=True`. ...@@ -418,8 +358,7 @@ You must enable this feature via `enable_mm_embeds=True`.
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:" prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Embeddings for single image # For most models, `image_embeds` has shape: (num_images, image_feature_size, hidden_size)
# torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...) image_embeds = torch.load(...)
outputs = llm.generate({ outputs = llm.generate({
...@@ -430,21 +369,8 @@ You must enable this feature via `enable_mm_embeds=True`. ...@@ -430,21 +369,8 @@ You must enable this feature via `enable_mm_embeds=True`.
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
print(generated_text) print(generated_text)
```
For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings:
??? code
```python
# Construct the prompt based on your model
prompt = ...
# Embeddings for multiple images # Additional examples for models that require extra fields
# torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)
# Qwen2-VL
llm = LLM( llm = LLM(
"Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-2B-Instruct",
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
...@@ -452,13 +378,15 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ...@@ -452,13 +378,15 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
) )
mm_data = { mm_data = {
"image": { "image": {
"image_embeds": image_embeds, # Shape: (total_feature_size, hidden_size)
# total_feature_size = sum(image_feature_size for image in images)
"image_embeds": torch.load(...),
# Shape: (num_images, 3)
# image_grid_thw is needed to calculate positional encoding. # image_grid_thw is needed to calculate positional encoding.
"image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3), "image_grid_thw": torch.load(...),
} }
} }
# MiniCPM-V
llm = LLM( llm = LLM(
"openbmb/MiniCPM-V-2_6", "openbmb/MiniCPM-V-2_6",
trust_remote_code=True, trust_remote_code=True,
...@@ -467,20 +395,14 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd ...@@ -467,20 +395,14 @@ For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embedd
) )
mm_data = { mm_data = {
"image": { "image": {
"image_embeds": image_embeds, # Shape: (num_images, num_slices, hidden_size)
# num_slices can differ for each image
"image_embeds": [torch.load(...) for image in images],
# Shape: (num_images, 2)
# image_sizes is needed to calculate details of the sliced image. # image_sizes is needed to calculate details of the sliced image.
"image_sizes": [image.size for image in images], # list of image sizes "image_sizes": [image.size for image in images],
} }
} }
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": mm_data,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
``` ```
For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features. For Qwen3-VL, the `image_embeds` should contain both the base image embedding and deepstack features.
...@@ -501,8 +423,8 @@ You can pass pre-computed audio embeddings similar to image embeddings: ...@@ -501,8 +423,8 @@ You can pass pre-computed audio embeddings similar to image embeddings:
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:" prompt = "USER: <audio>\nWhat is in this audio?\nASSISTANT:"
# Load pre-computed audio embeddings # Load pre-computed audio embeddings, usually with shape:
# torch.Tensor of shape (1, audio_feature_size, hidden_size of LM) # (num_audios, audio_feature_size, hidden_size of LM)
audio_embeds = torch.load(...) audio_embeds = torch.load(...)
outputs = llm.generate({ outputs = llm.generate({
...@@ -515,6 +437,67 @@ You can pass pre-computed audio embeddings similar to image embeddings: ...@@ -515,6 +437,67 @@ You can pass pre-computed audio embeddings similar to image embeddings:
print(generated_text) print(generated_text)
``` ```
### Cached Inputs
When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_a = Image.open("/path/to/a.jpg")
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [img_a, img_b]},
# Provide stable IDs for caching.
# Requirements (matched by this example):
# - Include every modality present in multi_modal_data.
# - For lists, provide the same number of entries.
# - Use None to fall back to content hashing for that item.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
Using UUIDs, you can also skip sending media data entirely if you expect cache hits for the respective items. Note that the request will fail if a skipped media item doesn't have a corresponding UUID, or if its UUID fails to hit the cache.
??? code
```python
from vllm import LLM
from PIL import Image
# Qwen2.5-VL example with two images
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")
prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
img_b = Image.open("/path/to/b.jpg")
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": [None, img_b]},
# Since img_a is expected to be cached, we can skip sending the actual
# image entirely.
"multi_modal_uuids": {"image": ["sku-1234-a", None]},
})
for o in outputs:
print(o.outputs[0].text)
```
!!! warning
If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
## Online Serving ## Online Serving
Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs that users can provide to uniquely identify each media item; these UUIDs are used to cache the media results across requests. Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions API](https://platform.openai.com/docs/api-reference/chat). Media inputs also support optional UUIDs that users can provide to uniquely identify each media item; these UUIDs are used to cache the media results across requests.
...@@ -879,7 +862,11 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo ...@@ -879,7 +862,11 @@ Full example: [examples/online_serving/openai_chat_completion_client_for_multimo
### Embedding Inputs ### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. pass a tensor of shape `(..., hidden_size of LM)` for each item to the corresponding field of the multi-modal dictionary.
!!! important
Unlike offline inference, the embeddings for each item must be passed separately
in order for placeholder tokens to be applied correctly by the chat template.
You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`. You must enable this feature via the `--enable-mm-embeds` flag in `vllm serve`.
...@@ -897,11 +884,6 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -897,11 +884,6 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
```python ```python
from vllm.utils.serial_utils import tensor2base64 from vllm.utils.serial_utils import tensor2base64
image_embedding = torch.load(...)
grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct
base64_image_embedding = tensor2base64(image_embedding)
client = OpenAI( client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY") # defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key, api_key=openai_api_key,
...@@ -912,29 +894,33 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -912,29 +894,33 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
model = "llava-hf/llava-1.5-7b-hf" model = "llava-hf/llava-1.5-7b-hf"
embeds = { embeds = {
"type": "image_embeds", "type": "image_embeds",
"image_embeds": f"{base64_image_embedding}", "image_embeds": tensor2base64(torch.load(...)), # Shape: (image_feature_size, hidden_size)
"uuid": image_url, # Optional "uuid": image_url, # Optional
} }
# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
# Additional examples for models that require extra fields
model = "Qwen/Qwen2-VL-2B-Instruct" model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = { embeds = {
"type": "image_embeds", "type": "image_embeds",
"image_embeds": { "image_embeds": {
"image_embeds": f"{base64_image_embedding}", # Required "image_embeds": tensor2base64(torch.load(...)), # Shape: (image_feature_size, hidden_size)
"image_grid_thw": f"{base64_image_grid_thw}", # Required by Qwen/Qwen2-VL-2B-Instruct "image_grid_thw": tensor2base64(torch.load(...)), # Shape: (3,)
}, },
"uuid": image_url, # Optional "uuid": image_url, # Optional
} }
model = "openbmb/MiniCPM-V-2_6" model = "openbmb/MiniCPM-V-2_6"
embeds = { embeds = {
"type": "image_embeds", "type": "image_embeds",
"image_embeds": { "image_embeds": {
"image_embeds": f"{base64_image_embedding}", # Required "image_embeds": tensor2base64(torch.load(...)), # Shape: (num_slices, hidden_size)
"image_sizes": f"{base64_image_sizes}", # Required by openbmb/MiniCPM-V-2_6 "image_sizes": tensor2base64(torch.load(...)), # Shape: (2,)
}, },
"uuid": image_url, # Optional "uuid": image_url, # Optional
} }
# Single image input
chat_completion = client.chat.completions.create( chat_completion = client.chat.completions.create(
messages=[ messages=[
{ {
...@@ -954,9 +940,55 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ...@@ -954,9 +940,55 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
], ],
model=model, model=model,
) )
# Multi image input
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": "What's in this image?",
},
embeds,
embeds,
],
},
],
model=model,
)
# Multi image input (interleaved)
chat_completion = client.chat.completions.create(
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": [
embeds,
{
"type": "text",
"text": "What's in this image?",
},
embeds,
],
},
],
model=model,
)
``` ```
For Online Serving, you can also skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this: ### Cached Inputs
Just like with offline inference, you can skip sending media if you expect cache hits with provided UUIDs. You can do so by sending media like this:
??? code ??? code
...@@ -990,13 +1022,3 @@ For Online Serving, you can also skip sending media if you expect cache hits wit ...@@ -990,13 +1022,3 @@ For Online Serving, you can also skip sending media if you expect cache hits wit
}, },
``` ```
!!! note
Multiple messages can now contain `{"type": "image_embeds"}`, enabling you to pass multiple image embeddings in a single request (similar to regular images). The number of embeddings is limited by `--limit-mm-per-prompt`.
**Important**: The embedding shape format differs based on the number of embeddings:
- **Single embedding**: 3D tensor of shape `(1, feature_size, hidden_size)`
- **Multiple embeddings**: List of 2D tensors, each of shape `(feature_size, hidden_size)`
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
...@@ -59,8 +59,10 @@ class PrithviMAE: ...@@ -59,8 +59,10 @@ class PrithviMAE:
input_data = input_data[0] input_data = input_data[0]
mm_data = { mm_data = {
"pixel_values": input_data, "image": {
"location_coords": location_coords, "pixel_values": input_data,
"location_coords": location_coords,
}
} }
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
......
...@@ -13,31 +13,10 @@ from vllm.utils.serial_utils import tensor2base64 ...@@ -13,31 +13,10 @@ from vllm.utils.serial_utils import tensor2base64
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
def _terratorch_dummy_messages():
    """Build a one-message chat payload carrying dummy tensor inputs.

    Returns:
        A list with a single ``user`` message whose only content part is of
        type ``image_embeds``. That part holds two base64-encoded float16
        tensors filled with ones: ``pixel_values`` of shape (6, 512, 512)
        and ``location_coords`` of shape (1, 2).
    """
    # Encode in the same order as the keys appear in the payload.
    encoded_inputs = {
        "pixel_values": tensor2base64(
            torch.ones((6, 512, 512), dtype=torch.float16)
        ),
        "location_coords": tensor2base64(
            torch.ones((1, 2), dtype=torch.float16)
        ),
    }
    content_part = {"type": "image_embeds", "image_embeds": encoded_inputs}
    return [{"role": "user", "content": [content_part]}]
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
) )
def test_single_request(model_name: str): def test_single_content(model_name: str):
args = [ args = [
"--runner", "--runner",
"pooling", "pooling",
...@@ -59,7 +38,24 @@ def test_single_request(model_name: str): ...@@ -59,7 +38,24 @@ def test_single_request(model_name: str):
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"messages": _terratorch_dummy_messages(), "messages": [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": tensor2base64(
torch.ones((6, 512, 512), dtype=torch.float16)
),
"location_coords": tensor2base64(
torch.ones((1, 2), dtype=torch.float16)
),
},
},
],
}
],
"encoding_format": "base64", "encoding_format": "base64",
}, },
) )
...@@ -69,3 +65,87 @@ def test_single_request(model_name: str): ...@@ -69,3 +65,87 @@ def test_single_request(model_name: str):
np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32)
assert len(np_response) == 524288 assert len(np_response) == 524288
@pytest.mark.parametrize("model_name", ["Qwen/Qwen3-VL-2B-Instruct"])
def test_multi_content(model_name: str):
    """Send chat requests whose user message holds two ``image_embeds`` parts.

    Two layouts are exercised against a server started with
    ``--enable-mm-embeds``: two image-embedding parts back to back, and the
    same two parts with a text part in between. Each response must have an
    id and exactly one choice.
    """
    server_args = [
        "--enforce-eager",
        "--max-num-seqs",
        "32",
        "--max-model-len",
        "8192",
        "--enable-mm-embeds",
    ]

    def make_image_part():
        # A fresh dict per call, so every content entry is its own object.
        return {
            "type": "image_embeds",
            "image_embeds": {
                "image_embeds": tensor2base64(torch.zeros(220, 8192)),
                "image_grid_thw": tensor2base64(torch.tensor([1, 22, 40])),
            },
        }

    with RemoteOpenAIServer(model_name, server_args) as server:
        client = server.get_client()

        content_layouts = (
            # Image only
            lambda: [make_image_part(), make_image_part()],
            # Interleaved text and image
            lambda: [
                make_image_part(),
                {"type": "text", "text": "OCR:"},
                make_image_part(),
            ],
        )

        for build_content in content_layouts:
            chat_completion = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": build_content()}],
                max_tokens=5,
            )

            assert chat_completion.id is not None
            assert len(chat_completion.choices) == 1
...@@ -68,6 +68,16 @@ def phi3v_model_config_image_embeds(): ...@@ -68,6 +68,16 @@ def phi3v_model_config_image_embeds():
) )
@pytest.fixture(scope="function")
def qwen25omni_model_config_image_embeds():
    """Provide a ``ModelConfig`` for ``QWEN25OMNI_MODEL_ID`` with mm embeds on.

    The config uses the ``generate`` runner, allows up to two images per
    prompt, and sets ``enable_mm_embeds=True``.
    """
    config_options = {
        "runner": "generate",
        "limit_mm_per_prompt": {"image": 2},
        "enable_mm_embeds": True,
    }
    return ModelConfig(QWEN25OMNI_MODEL_ID, **config_options)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def qwen2_audio_model_config(): def qwen2_audio_model_config():
return ModelConfig( return ModelConfig(
...@@ -823,7 +833,8 @@ def test_parse_chat_messages_audio_embeds_with_string( ...@@ -823,7 +833,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
import torch import torch
# Create a sample audio embedding tensor # Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768) hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
audio_embedding = torch.randn(1, 128, hidden_size)
# Encode it as base64 # Encode it as base64
base64_audio_embedding = tensor2base64(audio_embedding) base64_audio_embedding = tensor2base64(audio_embedding)
...@@ -865,7 +876,8 @@ async def test_parse_chat_messages_audio_embeds_async( ...@@ -865,7 +876,8 @@ async def test_parse_chat_messages_audio_embeds_async(
import torch import torch
# Create a sample audio embedding tensor # Create a sample audio embedding tensor
audio_embedding = torch.randn(1, 128, 768) hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
audio_embedding = torch.randn(1, 128, hidden_size)
# Encode it as base64 # Encode it as base64
base64_audio_embedding = tensor2base64(audio_embedding) base64_audio_embedding = tensor2base64(audio_embedding)
...@@ -908,8 +920,9 @@ def test_parse_chat_messages_multiple_image_embeds( ...@@ -908,8 +920,9 @@ def test_parse_chat_messages_multiple_image_embeds(
can be provided in a single request, similar to regular images. can be provided in a single request, similar to regular images.
""" """
# Create two sample image embedding tensors # Create two sample image embedding tensors
image_embedding_1 = torch.randn(256, 1024) hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
image_embedding_2 = torch.randn(128, 1024) image_embedding_1 = torch.randn(256, hidden_size)
image_embedding_2 = torch.randn(128, hidden_size)
# Encode them as base64 using the convenience function # Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1) base64_image_embedding_1 = tensor2base64(image_embedding_1)
...@@ -1022,8 +1035,9 @@ async def test_parse_chat_messages_multiple_image_embeds_async( ...@@ -1022,8 +1035,9 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
This validates the AsyncMultiModalItemTracker also supports multiple embeddings. This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
""" """
# Create two sample image embedding tensors # Create two sample image embedding tensors
image_embedding_1 = torch.randn(200, 768) hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
image_embedding_2 = torch.randn(150, 768) image_embedding_1 = torch.randn(200, hidden_size)
image_embedding_2 = torch.randn(150, hidden_size)
# Encode them as base64 using the convenience function # Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1) base64_image_embedding_1 = tensor2base64(image_embedding_1)
...@@ -1145,13 +1159,14 @@ def test_parse_chat_messages_empty_dict_image_embeds( ...@@ -1145,13 +1159,14 @@ def test_parse_chat_messages_empty_dict_image_embeds(
def test_parse_chat_messages_multiple_dict_image_embeds( def test_parse_chat_messages_multiple_dict_image_embeds(
phi3v_model_config_image_embeds, qwen25omni_model_config_image_embeds,
): ):
"""Test that multiple dictionaries for image_embeds is handled without errors.""" """Test that multiple dictionaries for image_embeds is handled without errors."""
# Create two sample image embedding tensors # Create two sample image embedding tensors
batch_size = 2 batch_size = 2
image_embedding_1 = torch.randn(batch_size, 256, 1024) hidden_size = qwen25omni_model_config_image_embeds.get_inputs_embeds_size()
image_embedding_2 = torch.randn(batch_size, 3) image_embeds = torch.randn(batch_size * 220, hidden_size)
image_grid_thw = torch.tensor([[1, 22, 40] for _ in range(batch_size)])
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
[ [
...@@ -1161,18 +1176,20 @@ def test_parse_chat_messages_multiple_dict_image_embeds( ...@@ -1161,18 +1176,20 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
{ {
"type": "image_embeds", "type": "image_embeds",
"image_embeds": { "image_embeds": {
"image_embedding_1": tensor2base64(p), "image_embeds": tensor2base64(embeds),
"image_embedding_2": tensor2base64(i), "image_grid_thw": tensor2base64(grid_thw),
}, },
} }
for p, i in zip(image_embedding_1, image_embedding_2) for embeds, grid_thw in zip(
image_embeds.chunk(batch_size), image_grid_thw
)
] ]
+ [ + [
{"type": "text", "text": "Describe these two images."}, {"type": "text", "text": "Describe these two images."},
], ],
} }
], ],
phi3v_model_config_image_embeds, qwen25omni_model_config_image_embeds,
content_format="string", content_format="string",
) )
...@@ -1180,7 +1197,8 @@ def test_parse_chat_messages_multiple_dict_image_embeds( ...@@ -1180,7 +1197,8 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
assert conversation == [ assert conversation == [
{ {
"role": "user", "role": "user",
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.", "content": "<|vision_start|><|IMAGE|><|vision_end|>\n"
"<|vision_start|><|IMAGE|><|vision_end|>\nDescribe these two images.",
} }
] ]
...@@ -1191,10 +1209,10 @@ def test_parse_chat_messages_multiple_dict_image_embeds( ...@@ -1191,10 +1209,10 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
assert len(mm_data["image"]) == batch_size assert len(mm_data["image"]) == batch_size
# Verify each embedding has the correct shape # Verify each embedding has the correct shape
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor) assert isinstance(mm_data["image"]["image_embeds"], torch.Tensor)
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape assert mm_data["image"]["image_embeds"].shape == image_embeds.shape
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor) assert isinstance(mm_data["image"]["image_grid_thw"], torch.Tensor)
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape assert mm_data["image"]["image_grid_thw"].shape == image_grid_thw.shape
# Verify UUIDs (None since we didn't provide any) # Verify UUIDs (None since we didn't provide any)
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None]) _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
......
...@@ -7,14 +7,6 @@ import torch ...@@ -7,14 +7,6 @@ import torch
from ....conftest import VllmRunner from ....conftest import VllmRunner
def generate_test_mm_data():
    """Return dummy multi-modal inputs made of all-ones float16 tensors.

    Returns:
        A dict with ``pixel_values`` of shape (6, 512, 512) and
        ``location_coords`` of shape (1, 2).
    """
    half = torch.float16
    return {
        "pixel_values": torch.ones((6, 512, 512), dtype=half),
        "location_coords": torch.ones((1, 2), dtype=half),
    }
def _run_test( def _run_test(
vllm_runner: type[VllmRunner], vllm_runner: type[VllmRunner],
model: str, model: str,
...@@ -23,7 +15,12 @@ def _run_test( ...@@ -23,7 +15,12 @@ def _run_test(
{ {
# This model deals with no text input # This model deals with no text input
"prompt_token_ids": [1], "prompt_token_ids": [1],
"multi_modal_data": generate_test_mm_data(), "multi_modal_data": {
"image": {
"pixel_values": torch.ones((6, 512, 512), dtype=torch.float16),
"location_coords": torch.ones((1, 2), dtype=torch.float16),
}
},
} }
for _ in range(10) for _ in range(10)
] ]
......
...@@ -349,8 +349,10 @@ class PrithviMultimodalDataProcessor(IOProcessor): ...@@ -349,8 +349,10 @@ class PrithviMultimodalDataProcessor(IOProcessor):
{ {
"prompt_token_ids": [1], "prompt_token_ids": [1],
"multi_modal_data": { "multi_modal_data": {
"pixel_values": window.to(torch.float16)[0], "image": {
"location_coords": location_coords.to(torch.float16), "pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
}
}, },
} }
) )
......
...@@ -5,14 +5,8 @@ import pytest ...@@ -5,14 +5,8 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import ( from vllm.config import CacheConfig, ModelConfig, VllmConfig
CacheConfig, from vllm.multimodal import MultiModalUUIDDict
DeviceConfig,
ModelConfig,
MultiModalConfig,
VllmConfig,
)
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
...@@ -21,55 +15,26 @@ stop_pil_image = ImageAsset("stop_sign").pil_image ...@@ -21,55 +15,26 @@ stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
def _mock_input_processor( def _build_input_processor(
monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> InputProcessor: ) -> InputProcessor:
"""
Create a Processor instance with minimal configuration suitable for unit
tests without accessing external resources.
"""
monkeypatch.setattr(
ModelConfig, "try_get_generation_config", lambda self: {}, raising=True
)
monkeypatch.setattr(
ModelConfig, "__post_init__", lambda self, *args: None, raising=True
)
monkeypatch.setattr(
ModelConfig,
"verify_with_parallel_config",
lambda self, parallel_config: None,
raising=True,
)
monkeypatch.setattr(
MultiModalRegistry,
"processor_cache_from_config",
lambda self, vllm_config: None,
raising=True,
)
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
model_config = ModelConfig( model_config = ModelConfig(
tokenizer="dummy", model="Qwen/Qwen2.5-VL-3B-Instruct",
skip_tokenizer_init=True, skip_tokenizer_init=True,
max_model_len=128, max_model_len=128,
mm_processor_cache_gb=mm_cache_gb, mm_processor_cache_gb=mm_cache_gb,
generation_config="vllm",
) )
model_config.runner_type = "generate"
model_config.multimodal_config = MultiModalConfig(mm_processor_cache_gb=mm_cache_gb)
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
device_config=DeviceConfig(device="cpu"),
) )
return InputProcessor(vllm_config) return InputProcessor(vllm_config)
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): def test_multi_modal_uuids_length_mismatch_raises():
input_processor = _mock_input_processor(monkeypatch) input_processor = _build_input_processor()
prompt = { prompt = {
"prompt": "USER: <image>\nDescribe\nASSISTANT:", "prompt": "USER: <image>\nDescribe\nASSISTANT:",
...@@ -78,7 +43,7 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): ...@@ -78,7 +43,7 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
"multi_modal_uuids": {"image": ["hash_cherry"]}, "multi_modal_uuids": {"image": ["hash_cherry"]},
} }
with pytest.raises(ValueError, match="must have same length as data"): with pytest.raises(ValueError, match="must have same length as"):
input_processor.process_inputs( input_processor.process_inputs(
request_id="req-1", request_id="req-1",
prompt=prompt, # type: ignore[arg-type] prompt=prompt, # type: ignore[arg-type]
...@@ -86,21 +51,21 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): ...@@ -86,21 +51,21 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
) )
def test_multi_modal_uuids_missing_modality_raises(monkeypatch): def test_multi_modal_uuids_missing_modality_raises():
input_processor = _mock_input_processor(monkeypatch) input_processor = _build_input_processor()
prompt = { prompt = {
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:", "prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
# Two modalities provided in data # Two modalities provided in data
"multi_modal_data": { "multi_modal_data": {
"image": [cherry_pil_image], "image": [cherry_pil_image],
"video": [baby_reading_np_ndarrays], "video": None,
}, },
# Only image uuids provided; video missing should raise # Only image uuids provided; video missing should raise
"multi_modal_uuids": {"image": ["hash_cherry"]}, "multi_modal_uuids": {"image": ["hash_cherry"]},
} }
with pytest.raises(ValueError, match="must be provided if multi_modal_data"): with pytest.raises(ValueError, match="is empty but .* is missing"):
input_processor.process_inputs( input_processor.process_inputs(
request_id="req-2", request_id="req-2",
prompt=prompt, # type: ignore[arg-type] prompt=prompt, # type: ignore[arg-type]
...@@ -119,8 +84,7 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch): ...@@ -119,8 +84,7 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
def test_multi_modal_uuids_accepts_none_and_passes_through( def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
): ):
input_processor = _mock_input_processor( input_processor = _build_input_processor(
monkeypatch,
mm_cache_gb=mm_cache_gb, mm_cache_gb=mm_cache_gb,
enable_prefix_caching=enable_prefix_caching, enable_prefix_caching=enable_prefix_caching,
) )
...@@ -163,8 +127,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -163,8 +127,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
# When both processor cache is 0 and prefix caching disabled, the # When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs. # processor builds overrides from request id instead of using user UUIDs.
input_processor = _mock_input_processor( input_processor = _build_input_processor(
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False mm_cache_gb=0.0, enable_prefix_caching=False
) )
captured: dict[str, MultiModalUUIDDict] = {} captured: dict[str, MultiModalUUIDDict] = {}
...@@ -180,12 +144,12 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -180,12 +144,12 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
) )
request_id = "req-42" request_id = "req-42"
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"} mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
prompt = { prompt = {
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:", "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
"multi_modal_data": { "multi_modal_data": {
"image": [cherry_pil_image, stop_pil_image], "image": [cherry_pil_image, stop_pil_image],
"video": baby_reading_np_ndarrays, "video": [baby_reading_np_ndarrays],
}, },
"multi_modal_uuids": mm_uuids, "multi_modal_uuids": mm_uuids,
} }
...@@ -197,16 +161,15 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -197,16 +161,15 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
mm_uuids = captured["mm_uuids"]
assert set(mm_uuids.keys()) == {"image", "video"} assert set(mm_uuids.keys()) == {"image", "video"}
assert len(mm_uuids["image"]) == 2 assert len(mm_uuids["image"]) == 2
assert len(mm_uuids["video"]) == 1 assert len(mm_uuids["video"]) == 1
assert mm_uuids["image"][0].startswith(f"{request_id}-image-") and mm_uuids[ assert captured["mm_uuids"]["image"][0].startswith(
"image" f"{request_id}-image-"
][0].endswith("-0") ) and captured["mm_uuids"]["image"][0].endswith("-0")
assert mm_uuids["image"][1].startswith(f"{request_id}-image-") and mm_uuids[ assert captured["mm_uuids"]["image"][1].startswith(
"image" f"{request_id}-image-"
][1].endswith("-1") ) and captured["mm_uuids"]["image"][1].endswith("-1")
assert mm_uuids["video"][0].startswith(f"{request_id}-video-") and mm_uuids[ assert captured["mm_uuids"]["video"][0].startswith(
"video" f"{request_id}-video-"
][0].endswith("-0") ) and captured["mm_uuids"]["video"][0].endswith("-0")
This diff is collapsed.
...@@ -1574,10 +1574,6 @@ class LLM: ...@@ -1574,10 +1574,6 @@ class LLM:
try: try:
for i, prompt in enumerate(it): for i, prompt in enumerate(it):
if isinstance(prompt, dict):
self._validate_mm_data_and_uuids(
prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
)
request_id = self._add_request( request_id = self._add_request(
prompt, prompt,
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
...@@ -1593,54 +1589,6 @@ class LLM: ...@@ -1593,54 +1589,6 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True) self.llm_engine.abort_request(added_request_ids, internal=True)
raise e raise e
def _validate_mm_data_and_uuids(
self,
multi_modal_data: Any | None, # MultiModalDataDict
multi_modal_uuids: Any | None, # MultiModalUUIDDict
):
"""
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
"""
if multi_modal_data is None:
return
for modality, data in multi_modal_data.items():
if isinstance(data, list):
for i, d in enumerate(data):
if d is None:
if (
multi_modal_uuids is None
or modality not in multi_modal_uuids
or multi_modal_uuids[ # noqa: E501
modality
]
is None
):
raise ValueError(
f"Multi-modal data for {modality} is None "
f"but UUID is not provided"
)
else:
if (
len(multi_modal_uuids[modality]) <= i
or multi_modal_uuids[modality][i] is None
):
raise ValueError(
f"Multi-modal data for {modality} is None "
f"but UUID is not provided"
)
else:
if data is None and (
multi_modal_uuids is None
or modality not in multi_modal_uuids
or multi_modal_uuids[modality] is None
):
raise ValueError(
f"Multi-modal data for {modality} is None"
f" but UUID is not provided"
)
def _process_inputs( def _process_inputs(
self, self,
request_id: str, request_id: str,
......
...@@ -19,7 +19,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -19,7 +19,7 @@ from vllm.entrypoints.chat_utils import (
) )
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -95,7 +95,7 @@ def parse_score_data( ...@@ -95,7 +95,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
model_config: ModelConfig, model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None]: ) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
mm_tracker = MultiModalItemTracker(model_config) mm_tracker = MultiModalItemTracker(model_config)
content_1 = _parse_score_content(data_1, mm_tracker) content_1 = _parse_score_content(data_1, mm_tracker)
...@@ -109,8 +109,9 @@ def parse_score_data( ...@@ -109,8 +109,9 @@ def parse_score_data(
prompt_1 = ensure_str(content_1) prompt_1 = ensure_str(content_1)
prompt_2 = ensure_str(content_2) prompt_2 = ensure_str(content_2)
mm_items, mm_uuids = mm_tracker.resolve_items()
return prompt_1, prompt_2, mm_tracker.all_mm_data() return prompt_1, prompt_2, mm_items, mm_uuids
def _parse_score_content( def _parse_score_content(
...@@ -187,7 +188,7 @@ def get_score_prompt( ...@@ -187,7 +188,7 @@ def get_score_prompt(
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
score_template: str | None = None, score_template: str | None = None,
) -> tuple[str, TokensPrompt]: ) -> tuple[str, TokensPrompt]:
prompt_1, prompt_2, mm_data = parse_score_data( prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
data_1, data_1,
data_2, data_2,
model_config, model_config,
...@@ -248,6 +249,9 @@ def get_score_prompt( ...@@ -248,6 +249,9 @@ def get_score_prompt(
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
return full_prompt, engine_prompt return full_prompt, engine_prompt
......
...@@ -113,7 +113,11 @@ from .qwen2_5_vl import ( ...@@ -113,7 +113,11 @@ from .qwen2_5_vl import (
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoPixelInputs,
) )
from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo from .qwen2_vl import (
Qwen2VLMultiModalDataParser,
Qwen2VLProcessingInfo,
_create_qwen2vl_field_factory,
)
from .qwen3 import Qwen3ForCausalLM, Qwen3Model from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
...@@ -985,28 +989,9 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]) ...@@ -985,28 +989,9 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) return _create_qwen2vl_field_factory(
image_grid_sizes = image_grid_thw.prod(-1) self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
"""Wrapper around `Terratorch` models""" """Wrapper around `Terratorch` models"""
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any from typing import Any
import torch import torch
...@@ -62,6 +62,7 @@ from vllm.multimodal.processing import ( ...@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import attn_type from .interfaces_base import attn_type
...@@ -69,28 +70,21 @@ from .interfaces_base import attn_type ...@@ -69,28 +70,21 @@ from .interfaces_base import attn_type
logger = init_logger(__name__) logger = init_logger(__name__)
def _terratorch_field_names(pretrained_cfg: dict): def _terratorch_field_names(input_definition: InputDefinition):
input_definition = InputDefinition(**pretrained_cfg["input"])
return set(input_definition.data.keys()) return set(input_definition.data.keys())
def _terratorch_field_factory( def _terratorch_field_factory(input_definition: InputDefinition):
pretrained_cfg: dict, def _terratorch_field_config(
) -> Callable[ hf_inputs: Mapping[str, torch.Tensor],
[Mapping[str, torch.Tensor]], ) -> Mapping[str, MultiModalFieldConfig]:
Mapping[str, MultiModalFieldConfig], fields = dict[str, MultiModalFieldConfig]()
]: for name, input in input_definition.data.items():
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): modality = "image"
input_definition = InputDefinition(**pretrained_cfg["input"])
fields = {}
for input_name, input in input_definition.data.items():
if input.type == InputTypeEnum.tensor: if input.type == InputTypeEnum.tensor:
fields[input_name] = "image" fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
return { return fields
field_name: MultiModalFieldConfig.batched(modality=field_modality)
for field_name, field_modality in fields.items()
}
return _terratorch_field_config return _terratorch_field_config
...@@ -130,26 +124,31 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): ...@@ -130,26 +124,31 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
class TerratorchMultiModalDataParser(MultiModalDataParser): class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, pretrained_cfg: dict, *args, **kwargs): def __init__(self, input_definition: InputDefinition, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.input_definition = input_definition
def _parse_image_data( def _parse_image_data(
self, self,
data: dict[str, torch.Tensor] | ModalityData[ImageItem], data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None: ) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict): if isinstance(data, dict):
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
return DictEmbeddingItems( return DictEmbeddingItems(
data, data,
modality="image", modality="image",
required_fields=terratorch_fields, required_fields=_terratorch_field_names(self.input_definition),
fields_factory=_terratorch_field_factory(self._pretrained_cfg), fields_factory=_terratorch_field_factory(self.input_definition),
) )
return super()._parse_image_data(data) return super()._parse_image_data(data)
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
    """Parse `mm_data`, wrapping it under the `"image"` key when absent.

    If the incoming mapping has no `"image"` key, the whole mapping is
    treated as the image entry before delegating to the parent parser.
    """
    wrapped = mm_data if "image" in mm_data else {"image": mm_data}
    return super().parse_mm_data(wrapped)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor): class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__( def __init__(
...@@ -159,18 +158,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ...@@ -159,18 +158,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
*, *,
cache: MultiModalProcessorOnlyCache | None = None, cache: MultiModalProcessorOnlyCache | None = None,
) -> None: ) -> None:
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
self._input_definition = InputDefinition(**pretrained_cfg["input"])
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg) return TerratorchMultiModalDataParser(self._input_definition)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs) return _terratorch_field_factory(self._input_definition)(hf_inputs)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -188,23 +189,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ...@@ -188,23 +189,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if "image" in mm_data:
image_data = mm_data["image"]
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
else:
image_data = mm_data
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
mm_data = {"image": image_data}
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {} tokenization_kwargs = tokenization_kwargs or {}
mm_hashes = self._hash_mm_items( mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
) )
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data) mm_processed_data = BatchFeature(
mm_data.get("image", mm_data), tensor_type="pt"
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data, mm_processed_data,
...@@ -272,9 +266,15 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): ...@@ -272,9 +266,15 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs: object, **kwargs: object,
): ):
model_output = self.inference_runner.forward(**kwargs) input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
model_output = self.inference_runner.forward(**batched_kwargs).output
return model_output.output # The leading dimension of hidden states needs to equal input length
return model_output.expand(
input_len, *(-1 for _ in range(model_output.ndim - 1))
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_list = [] params_list = []
......
...@@ -13,7 +13,7 @@ def random_uuid() -> str: ...@@ -13,7 +13,7 @@ def random_uuid() -> str:
def length_from_prompt_token_ids_or_embeds( def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: list[int] | None, prompt_token_ids: list[int] | torch.Tensor | None,
prompt_embeds: torch.Tensor | None, prompt_embeds: torch.Tensor | None,
) -> int: ) -> int:
"""Calculate the request length (in number of tokens) give either """Calculate the request length (in number of tokens) give either
......
...@@ -8,14 +8,24 @@ from typing import Any, Literal, cast ...@@ -8,14 +8,24 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import (
from vllm.inputs.parse import split_enc_dec_inputs ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import (
from vllm.multimodal.parse import MultiModalDataParser MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -188,7 +198,66 @@ class InputProcessor: ...@@ -188,7 +198,66 @@ class InputProcessor:
self._validate_sampling_params(params) self._validate_sampling_params(params)
self._validate_supported_sampling_params(params) self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None: def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
return mm_processor.data_parser.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
    """Check `multi_modal_uuids` against `multi_modal_data` for one prompt.

    A plain-string prompt is first wrapped in a `TextPrompt`. Nothing is
    checked when both the data and the UUID mappings are empty or absent.

    Raises:
        ValueError: If a modality has both data and UUIDs but their lengths
            differ, if a data item is `None` without a UUID at the same
            position, or if a modality has no data items and no UUIDs.
    """
    singleton = TextPrompt(prompt=prompt) if isinstance(prompt, str) else prompt

    raw_data = cast(MultiModalDataDict, singleton.get("multi_modal_data") or {})
    raw_uuids = cast(MultiModalUUIDDict, singleton.get("multi_modal_uuids") or {})
    if not (raw_data or raw_uuids):
        return

    # Only modalities with an actual value are handed to the data parser.
    present_data = {
        name: value for name, value in raw_data.items() if value is not None
    }
    parsed_data = self._parse_mm_items(present_data)

    # A bare string UUID is treated as a one-element list.
    normalized_uuids = {}
    for name, value in raw_uuids.items():
        if value is None:
            continue
        normalized_uuids[name] = [value] if isinstance(value, str) else value

    # NOTE: Iterate over the raw keys so that modalities whose value is
    # `None` are still checked.
    for modality in raw_data.keys() | raw_uuids.keys():
        data_items = cast(
            ModalityDataItems | list[Any], parsed_data.get(modality, [])
        )
        uuid_items = cast(list[str | None], normalized_uuids.get(modality, []))

        if not len(data_items) > 0:
            if len(uuid_items) == 0:
                raise ValueError(
                    f"multi_modal_data[{modality!r}] is empty but "
                    f"multi_modal_uuids[{modality!r}] is missing."
                )
            continue

        if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
            raise ValueError(
                f"If given, multi_modal_uuids[{modality!r}] must have "
                f"same length as multi_modal_data[{modality!r}], but "
                f"got {len(uuid_items)} vs {len(data_items)}."
            )

        for i, item in enumerate(data_items):
            if item is not None:
                continue

            if not uuid_items:
                raise ValueError(
                    f"multi_modal_data[{modality!r}][{i}] is empty but "
                    f"multi_modal_uuids[{modality!r}] is missing."
                )

            if uuid_items[i] is None:
                raise ValueError(
                    f"multi_modal_data[{modality!r}][{i}] is empty but "
                    f"multi_modal_uuids[{modality!r}][{i}] is missing."
                )
def _validate_mm_uuids(self, prompt: PromptType) -> None:
""" """
Validate that user-provided multi_modal_uuids align with Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s). multi_modal_data in the incoming request prompt(s).
...@@ -196,55 +265,13 @@ class InputProcessor: ...@@ -196,55 +265,13 @@ class InputProcessor:
auto-hashed downstream. auto-hashed downstream.
""" """
def _validate_single_prompt(single_prompt: dict | str) -> None: if is_explicit_encoder_decoder_prompt(prompt):
if not isinstance(single_prompt, dict): self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
return
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
import torch
def _get_len(items: object):
if isinstance(items, dict): # Embedding inputs
return _get_len(next(iter(items.values()))) if items else 1
if isinstance(items, list):
return len(items)
if isinstance(items, torch.Tensor):
# To keep backwards compatibility for single item embedding input
return 1 if getattr(items, "_is_single_item", False) else len(items)
return 1
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = _get_len(items)
uuid_len = _get_len(mm_uuids[modality])
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} "
"must have same length as data: got "
f"{uuid_len} uuids vs {data_len} items."
)
else:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} must "
"be provided if multi_modal_data is provided."
)
# Handle explicit encoder/decoder prompts or singleton prompt if (dec_prompt := prompt["decoder_prompt"]) is not None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt: self._validate_singleton_mm_uuids(dec_prompt)
enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt")
if enc is not None:
_validate_single_prompt(cast(dict | str, enc))
if dec is not None:
_validate_single_prompt(cast(dict | str, dec))
else: else:
_validate_single_prompt(prompt) # type: ignore[arg-type] self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None: def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None: if lora_request is None:
...@@ -379,6 +406,20 @@ class InputProcessor: ...@@ -379,6 +406,20 @@ class InputProcessor:
# roundtrip serialization/deserialization won't fail. # roundtrip serialization/deserialization won't fail.
params.structured_outputs.__post_init__() params.structured_outputs.__post_init__()
def _extract_singleton_mm_data(
    self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
    """Return the `multi_modal_data` entry of a singleton prompt.

    A plain-string prompt yields `None`; otherwise the result is whatever
    `prompt.get("multi_modal_data")` returns (possibly `None`).
    """
    is_plain_text = isinstance(prompt, str)
    return (
        None
        if is_plain_text
        else prompt.get("multi_modal_data")  # type: ignore[return-value]
    )
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
    """Return the multi-modal data carried by `prompt`, if any.

    For an explicit encoder/decoder prompt only the encoder prompt is
    inspected; any other prompt is inspected directly.
    """
    source = (
        prompt["encoder_prompt"]
        if is_explicit_encoder_decoder_prompt(prompt)
        else prompt
    )
    return self._extract_singleton_mm_data(source)
def _maybe_build_mm_uuids( def _maybe_build_mm_uuids(
self, self,
request_id: str, request_id: str,
...@@ -391,31 +432,18 @@ class InputProcessor: ...@@ -391,31 +432,18 @@ class InputProcessor:
Returns a dictionary of modality -> list[str] of overrides, or None if Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present. disabled or no multimodal data is present.
""" """
mm_data = self._extract_mm_data(prompt)
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data: if not mm_data:
return None return None
mm_uuids: dict[str, list[str | None] | str] = {} mm_items = self._parse_mm_items(
for modality, data in mm_data.items(): {k: v for k, v in mm_data.items() if v is not None}
# Hash each item for embedding inputs. )
n = (
len(data) return {
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data) modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
else 1 for modality, data_count in mm_items.get_all_counts().items()
) }
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
def _get_mm_identifier( def _get_mm_identifier(
self, self,
...@@ -494,7 +522,7 @@ class InputProcessor: ...@@ -494,7 +522,7 @@ class InputProcessor:
else: else:
# Otherwise, use user-provided uuids as multimodal hash overrides # Otherwise, use user-provided uuids as multimodal hash overrides
# if provided. # if provided.
self._validate_multi_modal_uuids(prompt) self._validate_mm_uuids(prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
mm_uuids = cast( mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids") MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment