Unverified Commit 80b18230 authored by Nithin Chalapathi's avatar Nithin Chalapathi Committed by GitHub
Browse files

[Frontend] Add multimodal support to /inference/v1/generate endpoint (#38405)


Signed-off-by: default avatarNithin Chalapathi <nithin.ch10@gmail.com>
Signed-off-by: default avatarNithin Chalapathi <nithinc@berkeley.edu>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent d0697cc7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Disaggregated multimodal serving: render → generate round-trip.
Demonstrates the two-phase disaggregated flow:
1. /v1/chat/completions/render – preprocesses a multimodal chat request
into token IDs and serialized tensor features.
2. /inference/v1/generate – runs inference on the preprocessed tokens.
The render response is passed *directly* to generate with only
``sampling_params`` added, showing that the two endpoints compose with
zero client-side transformation.
Launch the server first:
vllm serve Qwen/Qwen3-VL-2B-Instruct \
--dtype bfloat16 --max-model-len 4096 --enforce-eager
Then run this script:
python example_mm_serve.py
"""
import io
import pybase64 as base64
import requests
from PIL import Image
from transformers import AutoTokenizer
BASE_URL = "http://localhost:8000"
MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
def make_data_url(image: Image.Image) -> str:
"""Encode a PIL image as a base64 data URL."""
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
def main():
# -- Step 1: Create a test image (solid red) -------------------------
image = Image.new("RGB", (224, 224), color=(255, 0, 0))
data_url = make_data_url(image)
print("Created 224x224 red test image")
# -- Step 2: Render (preprocess) -------------------------------------
render_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What color is this image? Answer in one word.",
},
],
}
],
}
print("\n--- Render ---")
render_resp = requests.post(
f"{BASE_URL}/v1/chat/completions/render", json=render_payload
)
render_resp.raise_for_status()
render_data = render_resp.json()
print(f"Response keys: {list(render_data.keys())}")
print(f"Number of token_ids: {len(render_data['token_ids'])}")
features = render_data.get("features")
if features and features.get("kwargs_data"):
print(f"kwargs_data modalities: {list(features['kwargs_data'].keys())}")
for modality, items in features["kwargs_data"].items():
print(
f" {modality}: {len(items)} item(s), "
f"first item type: {type(items[0])} length: {len(items[0])}"
if items
else "First item: (empty)"
)
else:
print("WARNING: no kwargs_data in render response")
# -- Step 3: Generate (inference) ------------------------------------
# Pass the render output directly — only add sampling_params.
generate_payload = render_data
generate_payload["sampling_params"] = {
"max_tokens": 20,
"temperature": 0.0,
}
print("\n--- Generate ---")
gen_resp = requests.post(f"{BASE_URL}/inference/v1/generate", json=generate_payload)
gen_resp.raise_for_status()
gen_data = gen_resp.json()
# -- Step 4: Decode & print ------------------------------------------
output_ids = gen_data["choices"][0]["token_ids"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output token count: {len(output_ids)}")
print(f"Generated text: {text!r}")
if "red" in text.lower():
print("\nModel correctly identified the red image.")
else:
print(f"\nWARNING: Expected 'red' in output, got: {text!r}")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Roundtrip tests for multimodal serde used by the disagg generate endpoint."""
import torch
from vllm.entrypoints.serve.disagg.mm_serde import (
decode_mm_kwargs_item,
encode_mm_kwargs_item,
)
from vllm.entrypoints.serve.disagg.protocol import (
MultiModalFeatures,
PlaceholderRangeInfo,
)
from vllm.multimodal.inputs import (
MultiModalBatchedField,
MultiModalFieldElem,
MultiModalFlatField,
MultiModalKwargsItem,
MultiModalSharedField,
)
def test_mm_kwargs_item_roundtrip():
"""Full roundtrip test with all three field types and multiple dtypes."""
e1 = MultiModalFieldElem(
data=torch.zeros(1000, dtype=torch.bfloat16),
field=MultiModalBatchedField(),
)
e2 = MultiModalFieldElem(
data=torch.ones(100, dtype=torch.int32),
field=MultiModalSharedField(batch_size=4),
)
e3 = MultiModalFieldElem(
data=torch.randn(20, dtype=torch.float32),
field=MultiModalFlatField(slices=[slice(0, 10), slice(10, 20)], dim=0),
)
item = MultiModalKwargsItem({"pixel_values": e1, "grid_thw": e2, "embeds": e3})
encoded = encode_mm_kwargs_item(item)
# Encoded result is a base64 string
assert isinstance(encoded, str)
decoded = decode_mm_kwargs_item(encoded)
assert set(decoded.keys()) == {"pixel_values", "grid_thw", "embeds"}
assert torch.equal(item["pixel_values"].data, decoded["pixel_values"].data)
assert torch.equal(item["grid_thw"].data, decoded["grid_thw"].data)
assert torch.equal(item["embeds"].data, decoded["embeds"].data)
assert isinstance(decoded["pixel_values"].field, MultiModalBatchedField)
assert isinstance(decoded["grid_thw"].field, MultiModalSharedField)
assert isinstance(decoded["embeds"].field, MultiModalFlatField)
def test_mm_kwargs_item_none_data():
"""Roundtrip with None data field."""
elem = MultiModalFieldElem(
data=None,
field=MultiModalSharedField(batch_size=2),
)
item = MultiModalKwargsItem({"empty": elem})
encoded = encode_mm_kwargs_item(item)
decoded = decode_mm_kwargs_item(encoded)
assert decoded["empty"].data is None
assert isinstance(decoded["empty"].field, MultiModalSharedField)
def test_mm_kwargs_item_nested_tensors():
"""Roundtrip with nested tensor data."""
nested = [torch.randn(3, 4), torch.randn(5, 4)]
elem = MultiModalFieldElem(
data=nested,
field=MultiModalBatchedField(),
)
item = MultiModalKwargsItem({"nested": elem})
encoded = encode_mm_kwargs_item(item)
decoded = decode_mm_kwargs_item(encoded)
decoded_data = decoded["nested"].data
assert len(decoded_data) == 2
assert torch.equal(nested[0], decoded_data[0])
assert torch.equal(nested[1], decoded_data[1])
def test_mm_features_with_kwargs_data():
"""Test that MultiModalFeatures can carry serialized tensor data."""
elem = MultiModalFieldElem(
data=torch.randn(5, 3, dtype=torch.float32),
field=MultiModalBatchedField(),
)
item = MultiModalKwargsItem({"pixel_values": elem})
encoded = encode_mm_kwargs_item(item)
features = MultiModalFeatures(
mm_hashes={"image": ["abc123"]},
mm_placeholders={"image": [PlaceholderRangeInfo(offset=0, length=10)]},
kwargs_data={"image": [encoded]},
)
# JSON roundtrip
json_str = features.model_dump_json()
features2 = MultiModalFeatures.model_validate_json(json_str)
assert features2.mm_hashes == {"image": ["abc123"]}
assert features2.kwargs_data is not None
assert len(features2.kwargs_data["image"]) == 1
decoded = decode_mm_kwargs_item(features2.kwargs_data["image"][0])
assert torch.equal(elem.data, decoded["pixel_values"].data)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for multimodal features through the /inference/v1/generate endpoint.
Mirrors test_serving_tokens.py but exercises the multimodal piping
using Qwen/Qwen3-VL-2B-Instruct end-to-end via the server's /render ->
/generate -> /detokenize path. Intentionally avoids running the HF
processor in the pytest parent process to keep os.fork() in sibling
tests (e.g. test_weight_transfer_llm.py) deadlock-free.
"""
import os
import httpx
import pytest
import pytest_asyncio
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_url
MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
GEN_ENDPOINT = "/inference/v1/generate"
RENDER_ENDPOINT = "/v1/chat/completions/render"
DETOKENIZE_ENDPOINT = "/detokenize"
@pytest.fixture(scope="module")
def test_image():
return Image.new("RGB", (224, 224), color=(255, 0, 0))
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
"--no-enable-prefix-caching",
]
envs = os.environ.copy()
envs["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
async with httpx.AsyncClient(
transport=transport,
base_url=server.url_root,
timeout=600,
headers=headers,
) as c:
yield c
@pytest.mark.asyncio
async def test_render_to_generate_roundtrip(client, test_image):
"""End-to-end: render a multimodal chat -> feed into generate -> decode.
All preprocessing and detokenization happens in the server subprocess;
the pytest parent never imports transformers or touches torch tensors.
"""
data_url = encode_image_url(test_image, format="PNG")
render_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What color is this image? Answer in one word.",
},
],
}
],
}
render_resp = await client.post(RENDER_ENDPOINT, json=render_payload)
render_resp.raise_for_status()
render_data = render_resp.json()
# Validate render output structure: keys exist and values are non-empty
# and well-typed.
assert "token_ids" in render_data
assert isinstance(render_data["token_ids"], list)
assert len(render_data["token_ids"]) > 0
assert all(isinstance(t, int) for t in render_data["token_ids"])
assert "features" in render_data
features = render_data["features"]
assert features is not None
assert isinstance(features, dict)
assert "mm_hashes" in features
assert "image" in features["mm_hashes"]
image_hashes = features["mm_hashes"]["image"]
assert isinstance(image_hashes, list)
assert len(image_hashes) > 0
assert all(isinstance(h, str) and h for h in image_hashes)
assert "mm_placeholders" in features
assert "image" in features["mm_placeholders"]
image_placeholders = features["mm_placeholders"]["image"]
assert isinstance(image_placeholders, list)
assert len(image_placeholders) > 0
for p in image_placeholders:
assert isinstance(p.get("offset"), int)
assert isinstance(p.get("length"), int)
assert p["length"] > 0
assert "kwargs_data" in features
assert "image" in features["kwargs_data"]
assert len(features["kwargs_data"]["image"]) > 0
# Build generate request from render output
generate_payload = render_data
generate_payload["sampling_params"] = {
"max_tokens": 10,
"temperature": 0.0,
}
gen_resp = await client.post(GEN_ENDPOINT, json=generate_payload)
gen_resp.raise_for_status()
gen_data = gen_resp.json()
assert "choices" in gen_data
assert isinstance(gen_data["choices"], list)
assert len(gen_data["choices"]) >= 1
choice = gen_data["choices"][0]
assert "token_ids" in choice
assert isinstance(choice["token_ids"], list)
assert len(choice["token_ids"]) > 0
assert all(isinstance(t, int) for t in choice["token_ids"])
detok_resp = await client.post(
DETOKENIZE_ENDPOINT,
json={"model": MODEL_NAME, "tokens": choice["token_ids"]},
)
detok_resp.raise_for_status()
detok_data = detok_resp.json()
assert "prompt" in detok_data
text = detok_data["prompt"]
assert isinstance(text, str)
assert len(text) > 0
assert "red" in text.lower(), (
f"Expected model to identify the red image, got: {text!r}"
)
...@@ -7,17 +7,39 @@ from tests.models.utils import check_embeddings_close ...@@ -7,17 +7,39 @@ from tests.models.utils import check_embeddings_close
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EMBED_DTYPES, EMBED_DTYPES,
ENDIANNESS, ENDIANNESS,
MM_METADATA_DTYPES,
EmbedDType, EmbedDType,
Endianness, Endianness,
MmMetadataDType,
binary2tensor, binary2tensor,
tensor2binary, tensor2binary,
) )
FLOAT_EMBED_DTYPES = tuple(EMBED_DTYPES.keys())
INTEGER_EMBED_DTYPES = tuple(MM_METADATA_DTYPES.keys())
def _build_integer_tensor(
embed_dtype: MmMetadataDType, shape: tuple[int, ...]
) -> torch.Tensor:
torch_dtype = MM_METADATA_DTYPES[embed_dtype].torch_dtype
if torch_dtype is torch.bool:
return torch.randint(0, 2, shape, dtype=torch.int32).to(torch.bool)
if torch_dtype is torch.uint8:
return torch.randint(0, 256, shape, dtype=torch.uint8)
if torch_dtype is torch.int32:
return torch.randint(-(2**20), 2**20, shape, dtype=torch.int32)
if torch_dtype is torch.int64:
return torch.randint(-(2**62), 2**62, shape, dtype=torch.int64)
raise AssertionError(f"Unsupported non-floating embed dtype: {embed_dtype}")
@pytest.mark.parametrize("endianness", ENDIANNESS) @pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPES.keys()) @pytest.mark.parametrize("embed_dtype", FLOAT_EMBED_DTYPES)
@torch.inference_mode() @torch.inference_mode()
def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness): def test_encode_and_decode_floats(embed_dtype: EmbedDType, endianness: Endianness):
for i in range(10): for i in range(10):
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32) tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
shape = tensor.shape shape = tensor.shape
...@@ -40,3 +62,20 @@ def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness): ...@@ -40,3 +62,20 @@ def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
name_1="new", name_1="new",
tol=1e-2, tol=1e-2,
) )
@pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", INTEGER_EMBED_DTYPES)
@torch.inference_mode()
def test_encode_and_decode_integers(
embed_dtype: MmMetadataDType, endianness: Endianness
):
shape = (2, 3, 5, 7, 11, 13)
for i in range(10):
tensor = _build_integer_tensor(embed_dtype, shape)
binary = tensor2binary(tensor, embed_dtype, endianness)
new_tensor = binary2tensor(binary, shape, embed_dtype, endianness)
assert new_tensor.dtype == MM_METADATA_DTYPES[embed_dtype].torch_dtype
torch.testing.assert_close(tensor, new_tensor, atol=0, rtol=0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Encode/decode utilities for multimodal tensors and field metadata
over JSON/HTTP, used by the disaggregated generate endpoint."""
from __future__ import annotations
import pybase64
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
_encoder = MsgpackEncoder(size_threshold=2**62) # force all tensors inline
_decoder = MsgpackDecoder(t=MultiModalKwargsItem)
def encode_mm_kwargs_item(item: MultiModalKwargsItem) -> str:
"""Serialize a MultiModalKwargsItem to a base64 string."""
bufs = _encoder.encode(item)
assert len(bufs) == 1, "All tensors should be inline"
return pybase64.b64encode(bufs[0]).decode("ascii")
def decode_mm_kwargs_item(data: str) -> MultiModalKwargsItem:
"""Deserialize a base64 string back to a MultiModalKwargsItem."""
raw = pybase64.b64decode(data)
return _decoder.decode(raw)
...@@ -35,14 +35,6 @@ class MultiModalFeatures(BaseModel): ...@@ -35,14 +35,6 @@ class MultiModalFeatures(BaseModel):
Carries hashes (for cache lookup / identification) and placeholder Carries hashes (for cache lookup / identification) and placeholder
positions so the downstream `/generate` service knows *where* in positions so the downstream `/generate` service knows *where* in
the token sequence each multimodal item lives. the token sequence each multimodal item lives.
Note:
Phase 1 — metadata only.
Phase 2 should add `mm_kwargs` (processed tensor data) using a
binary transport so the ``/generate` side can skip re-processing.
The `/generate` endpoint must also be updated to inject these
features into `EngineInput` before passing to
`InputProcessor.process_inputs`.
""" """
mm_hashes: dict[str, list[str]] mm_hashes: dict[str, list[str]]
...@@ -51,6 +43,15 @@ class MultiModalFeatures(BaseModel): ...@@ -51,6 +43,15 @@ class MultiModalFeatures(BaseModel):
mm_placeholders: dict[str, list[PlaceholderRangeInfo]] mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
"""Per-modality placeholder ranges in the token sequence.""" """Per-modality placeholder ranges in the token sequence."""
kwargs_data: dict[str, list[str | None]] | None = None
"""Per-modality serialized tensor data.
Each value is a list parallel to ``mm_hashes[modality]``. A ``str``
entry is a base64-encoded ``MultiModalKwargsItem``; ``None`` means
the item should be resolved from cache. The entire field is
``None`` for metadata-only (cache-hit) responses.
"""
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
request_id: str = Field( request_id: str = Field(
......
...@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.disagg.mm_serde import decode_mm_kwargs_item
from vllm.entrypoints.serve.disagg.protocol import ( from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest, GenerateRequest,
GenerateResponse, GenerateResponse,
...@@ -34,8 +35,14 @@ from vllm.entrypoints.serve.disagg.protocol import ( ...@@ -34,8 +35,14 @@ from vllm.entrypoints.serve.disagg.protocol import (
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import should_include_usage from vllm.entrypoints.utils import should_include_usage
from vllm.inputs import EngineInput, mm_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.multimodal.inputs import (
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
...@@ -103,11 +110,42 @@ class ServingTokens(OpenAIServing): ...@@ -103,11 +110,42 @@ class ServingTokens(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
(engine_input,) = await self.openai_serving_render.preprocess_completion( engine_input: EngineInput
request, if features := request.features:
prompt_input=request.token_ids, # Convert PlaceholderRangeInfo → PlaceholderRange per modality.
prompt_embeds=None, mm_placeholders: dict[str, list[PlaceholderRange]] = {
) modality: [
PlaceholderRange(offset=p.offset, length=p.length) for p in ranges
]
for modality, ranges in features.mm_placeholders.items()
}
# Deserialize tensor data when present; None → cache hit.
mm_kwargs: dict[str, list[MultiModalKwargsItem | None]] = {}
if features.kwargs_data is not None:
for modality, items in features.kwargs_data.items():
mm_kwargs[modality] = [
decode_mm_kwargs_item(item) if item is not None else None
for item in items
]
else:
for modality, hashes in features.mm_hashes.items():
mm_kwargs[modality] = [None] * len(hashes)
engine_input = mm_input(
prompt_token_ids=request.token_ids,
mm_kwargs=MultiModalKwargsItems(mm_kwargs),
mm_hashes=features.mm_hashes,
mm_placeholders=mm_placeholders,
cache_salt=request.cache_salt,
)
else:
(engine_input,) = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,
skip_mm_cache=True,
)
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None result_generator: AsyncGenerator[RequestOutput, None] | None = None
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any, cast
from openai_harmony import Message as OpenAIMessage from openai_harmony import Message as OpenAIMessage
...@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item
from vllm.entrypoints.serve.disagg.protocol import ( from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest, GenerateRequest,
MultiModalFeatures, MultiModalFeatures,
...@@ -37,6 +38,7 @@ from vllm.entrypoints.utils import ( ...@@ -37,6 +38,7 @@ from vllm.entrypoints.utils import (
from vllm.inputs import ( from vllm.inputs import (
EngineInput, EngineInput,
MultiModalHashes, MultiModalHashes,
MultiModalInput,
MultiModalPlaceholders, MultiModalPlaceholders,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
...@@ -251,6 +253,7 @@ class OpenAIServingRender: ...@@ -251,6 +253,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs, default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
skip_mm_cache=True,
reasoning_parser=self.reasoning_parser, reasoning_parser=self.reasoning_parser,
) )
else: else:
...@@ -342,6 +345,7 @@ class OpenAIServingRender: ...@@ -342,6 +345,7 @@ class OpenAIServingRender:
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
skip_mm_cache=True,
) )
return engine_inputs return engine_inputs
...@@ -357,9 +361,10 @@ class OpenAIServingRender: ...@@ -357,9 +361,10 @@ class OpenAIServingRender:
if engine_input.get("type") != "multimodal": if engine_input.get("type") != "multimodal":
return None return None
# At this point engine_input is a MultiModalInputs TypedDict. # At this point engine_input is a MultiModalInput TypedDict.
mm_hashes: MultiModalHashes = engine_input["mm_hashes"] # type: ignore[typeddict-item] mm_engine_input = cast(MultiModalInput, engine_input)
raw_placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"] # type: ignore[typeddict-item] mm_hashes: MultiModalHashes = mm_engine_input["mm_hashes"]
raw_placeholders: MultiModalPlaceholders = mm_engine_input["mm_placeholders"]
mm_placeholders = { mm_placeholders = {
modality: [ modality: [
...@@ -368,9 +373,20 @@ class OpenAIServingRender: ...@@ -368,9 +373,20 @@ class OpenAIServingRender:
for modality, ranges in raw_placeholders.items() for modality, ranges in raw_placeholders.items()
} }
# Serialize tensor data per modality.
kwargs_data: dict[str, list[str | None]] | None = None
if raw_mm_kwargs := mm_engine_input.get("mm_kwargs"):
kwargs_data = {}
for modality, items in raw_mm_kwargs.items():
kwargs_data[modality] = [
encode_mm_kwargs_item(item) if item is not None else None
for item in items
]
return MultiModalFeatures( return MultiModalFeatures(
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders, mm_placeholders=mm_placeholders,
kwargs_data=kwargs_data,
) )
def _make_request_with_harmony( def _make_request_with_harmony(
......
...@@ -27,6 +27,7 @@ class DTypeInfo: ...@@ -27,6 +27,7 @@ class DTypeInfo:
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
MmMetadataDType = Literal["int32", "int64", "uint8", "bool"]
Endianness = Literal["native", "big", "little"] Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"] EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
...@@ -42,6 +43,15 @@ EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = { ...@@ -42,6 +43,15 @@ EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
"fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8), "fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
"fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8), "fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
} }
MM_METADATA_DTYPES: Mapping[MmMetadataDType, DTypeInfo] = {
"int32": DTypeInfo(torch.int32, torch.int32, np.int32),
"int64": DTypeInfo(torch.int64, torch.int64, np.int64),
"uint8": DTypeInfo(torch.uint8, torch.uint8, np.uint8),
"bool": DTypeInfo(torch.bool, torch.uint8, np.uint8),
}
_ALL_SERIAL_DTYPES: Mapping[str, DTypeInfo] = {
k: v for d in (EMBED_DTYPES, MM_METADATA_DTYPES) for k, v in d.items()
}
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness) ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)
...@@ -56,14 +66,14 @@ def tensor2base64(x: torch.Tensor) -> str: ...@@ -56,14 +66,14 @@ def tensor2base64(x: torch.Tensor) -> str:
def tensor2binary( def tensor2binary(
tensor: torch.Tensor, tensor: torch.Tensor,
embed_dtype: EmbedDType, embed_dtype: "EmbedDType | MmMetadataDType",
endianness: Endianness, endianness: Endianness,
) -> bytes: ) -> bytes:
assert isinstance(tensor, torch.Tensor) assert isinstance(tensor, torch.Tensor)
assert embed_dtype in EMBED_DTYPES assert embed_dtype in _ALL_SERIAL_DTYPES
assert endianness in ENDIANNESS assert endianness in ENDIANNESS
dtype_info = EMBED_DTYPES[embed_dtype] dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]
np_array = ( np_array = (
tensor.to(dtype_info.torch_dtype) tensor.to(dtype_info.torch_dtype)
...@@ -82,13 +92,13 @@ def tensor2binary( ...@@ -82,13 +92,13 @@ def tensor2binary(
def binary2tensor( def binary2tensor(
binary: bytes, binary: bytes,
shape: tuple[int, ...], shape: tuple[int, ...],
embed_dtype: EmbedDType, embed_dtype: "EmbedDType | MmMetadataDType",
endianness: Endianness, endianness: Endianness,
) -> torch.Tensor: ) -> torch.Tensor:
assert embed_dtype in EMBED_DTYPES assert embed_dtype in _ALL_SERIAL_DTYPES
assert endianness in ENDIANNESS assert endianness in ENDIANNESS
dtype_info = EMBED_DTYPES[embed_dtype] dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]
np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape) np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment