Unverified Commit 80b18230 authored by Nithin Chalapathi's avatar Nithin Chalapathi Committed by GitHub
Browse files

[Frontend] Add multimodal support to /inference/v1/generate endpoint (#38405)


Signed-off-by: default avatarNithin Chalapathi <nithin.ch10@gmail.com>
Signed-off-by: default avatarNithin Chalapathi <nithinc@berkeley.edu>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent d0697cc7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Disaggregated multimodal serving: render → generate round-trip.
Demonstrates the two-phase disaggregated flow:
1. /v1/chat/completions/render – preprocesses a multimodal chat request
into token IDs and serialized tensor features.
2. /inference/v1/generate – runs inference on the preprocessed tokens.
The render response is passed *directly* to generate with only
``sampling_params`` added, showing that the two endpoints compose with
zero client-side transformation.
Launch the server first:
vllm serve Qwen/Qwen3-VL-2B-Instruct \
--dtype bfloat16 --max-model-len 4096 --enforce-eager
Then run this script:
python example_mm_serve.py
"""
import io
import pybase64 as base64
import requests
from PIL import Image
from transformers import AutoTokenizer
BASE_URL = "http://localhost:8000"
MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
def make_data_url(image: Image.Image) -> str:
"""Encode a PIL image as a base64 data URL."""
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode()
return f"data:image/png;base64,{b64}"
def main():
# -- Step 1: Create a test image (solid red) -------------------------
image = Image.new("RGB", (224, 224), color=(255, 0, 0))
data_url = make_data_url(image)
print("Created 224x224 red test image")
# -- Step 2: Render (preprocess) -------------------------------------
render_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What color is this image? Answer in one word.",
},
],
}
],
}
print("\n--- Render ---")
render_resp = requests.post(
f"{BASE_URL}/v1/chat/completions/render", json=render_payload
)
render_resp.raise_for_status()
render_data = render_resp.json()
print(f"Response keys: {list(render_data.keys())}")
print(f"Number of token_ids: {len(render_data['token_ids'])}")
features = render_data.get("features")
if features and features.get("kwargs_data"):
print(f"kwargs_data modalities: {list(features['kwargs_data'].keys())}")
for modality, items in features["kwargs_data"].items():
print(
f" {modality}: {len(items)} item(s), "
f"first item type: {type(items[0])} length: {len(items[0])}"
if items
else "First item: (empty)"
)
else:
print("WARNING: no kwargs_data in render response")
# -- Step 3: Generate (inference) ------------------------------------
# Pass the render output directly — only add sampling_params.
generate_payload = render_data
generate_payload["sampling_params"] = {
"max_tokens": 20,
"temperature": 0.0,
}
print("\n--- Generate ---")
gen_resp = requests.post(f"{BASE_URL}/inference/v1/generate", json=generate_payload)
gen_resp.raise_for_status()
gen_data = gen_resp.json()
# -- Step 4: Decode & print ------------------------------------------
output_ids = gen_data["choices"][0]["token_ids"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Output token count: {len(output_ids)}")
print(f"Generated text: {text!r}")
if "red" in text.lower():
print("\nModel correctly identified the red image.")
else:
print(f"\nWARNING: Expected 'red' in output, got: {text!r}")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Roundtrip tests for multimodal serde used by the disagg generate endpoint."""
import torch
from vllm.entrypoints.serve.disagg.mm_serde import (
decode_mm_kwargs_item,
encode_mm_kwargs_item,
)
from vllm.entrypoints.serve.disagg.protocol import (
MultiModalFeatures,
PlaceholderRangeInfo,
)
from vllm.multimodal.inputs import (
MultiModalBatchedField,
MultiModalFieldElem,
MultiModalFlatField,
MultiModalKwargsItem,
MultiModalSharedField,
)
def test_mm_kwargs_item_roundtrip():
"""Full roundtrip test with all three field types and multiple dtypes."""
e1 = MultiModalFieldElem(
data=torch.zeros(1000, dtype=torch.bfloat16),
field=MultiModalBatchedField(),
)
e2 = MultiModalFieldElem(
data=torch.ones(100, dtype=torch.int32),
field=MultiModalSharedField(batch_size=4),
)
e3 = MultiModalFieldElem(
data=torch.randn(20, dtype=torch.float32),
field=MultiModalFlatField(slices=[slice(0, 10), slice(10, 20)], dim=0),
)
item = MultiModalKwargsItem({"pixel_values": e1, "grid_thw": e2, "embeds": e3})
encoded = encode_mm_kwargs_item(item)
# Encoded result is a base64 string
assert isinstance(encoded, str)
decoded = decode_mm_kwargs_item(encoded)
assert set(decoded.keys()) == {"pixel_values", "grid_thw", "embeds"}
assert torch.equal(item["pixel_values"].data, decoded["pixel_values"].data)
assert torch.equal(item["grid_thw"].data, decoded["grid_thw"].data)
assert torch.equal(item["embeds"].data, decoded["embeds"].data)
assert isinstance(decoded["pixel_values"].field, MultiModalBatchedField)
assert isinstance(decoded["grid_thw"].field, MultiModalSharedField)
assert isinstance(decoded["embeds"].field, MultiModalFlatField)
def test_mm_kwargs_item_none_data():
"""Roundtrip with None data field."""
elem = MultiModalFieldElem(
data=None,
field=MultiModalSharedField(batch_size=2),
)
item = MultiModalKwargsItem({"empty": elem})
encoded = encode_mm_kwargs_item(item)
decoded = decode_mm_kwargs_item(encoded)
assert decoded["empty"].data is None
assert isinstance(decoded["empty"].field, MultiModalSharedField)
def test_mm_kwargs_item_nested_tensors():
"""Roundtrip with nested tensor data."""
nested = [torch.randn(3, 4), torch.randn(5, 4)]
elem = MultiModalFieldElem(
data=nested,
field=MultiModalBatchedField(),
)
item = MultiModalKwargsItem({"nested": elem})
encoded = encode_mm_kwargs_item(item)
decoded = decode_mm_kwargs_item(encoded)
decoded_data = decoded["nested"].data
assert len(decoded_data) == 2
assert torch.equal(nested[0], decoded_data[0])
assert torch.equal(nested[1], decoded_data[1])
def test_mm_features_with_kwargs_data():
"""Test that MultiModalFeatures can carry serialized tensor data."""
elem = MultiModalFieldElem(
data=torch.randn(5, 3, dtype=torch.float32),
field=MultiModalBatchedField(),
)
item = MultiModalKwargsItem({"pixel_values": elem})
encoded = encode_mm_kwargs_item(item)
features = MultiModalFeatures(
mm_hashes={"image": ["abc123"]},
mm_placeholders={"image": [PlaceholderRangeInfo(offset=0, length=10)]},
kwargs_data={"image": [encoded]},
)
# JSON roundtrip
json_str = features.model_dump_json()
features2 = MultiModalFeatures.model_validate_json(json_str)
assert features2.mm_hashes == {"image": ["abc123"]}
assert features2.kwargs_data is not None
assert len(features2.kwargs_data["image"]) == 1
decoded = decode_mm_kwargs_item(features2.kwargs_data["image"][0])
assert torch.equal(elem.data, decoded["pixel_values"].data)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for multimodal features through the /inference/v1/generate endpoint.
Mirrors test_serving_tokens.py but exercises the multimodal piping
using Qwen/Qwen3-VL-2B-Instruct end-to-end via the server's /render ->
/generate -> /detokenize path. Intentionally avoids running the HF
processor in the pytest parent process to keep os.fork() in sibling
tests (e.g. test_weight_transfer_llm.py) deadlock-free.
"""
import os
import httpx
import pytest
import pytest_asyncio
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_url
MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
GEN_ENDPOINT = "/inference/v1/generate"
RENDER_ENDPOINT = "/v1/chat/completions/render"
DETOKENIZE_ENDPOINT = "/detokenize"
@pytest.fixture(scope="module")
def test_image():
return Image.new("RGB", (224, 224), color=(255, 0, 0))
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
"--no-enable-prefix-caching",
]
envs = os.environ.copy()
envs["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
async with httpx.AsyncClient(
transport=transport,
base_url=server.url_root,
timeout=600,
headers=headers,
) as c:
yield c
@pytest.mark.asyncio
async def test_render_to_generate_roundtrip(client, test_image):
"""End-to-end: render a multimodal chat -> feed into generate -> decode.
All preprocessing and detokenization happens in the server subprocess;
the pytest parent never imports transformers or touches torch tensors.
"""
data_url = encode_image_url(test_image, format="PNG")
render_payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": data_url}},
{
"type": "text",
"text": "What color is this image? Answer in one word.",
},
],
}
],
}
render_resp = await client.post(RENDER_ENDPOINT, json=render_payload)
render_resp.raise_for_status()
render_data = render_resp.json()
# Validate render output structure: keys exist and values are non-empty
# and well-typed.
assert "token_ids" in render_data
assert isinstance(render_data["token_ids"], list)
assert len(render_data["token_ids"]) > 0
assert all(isinstance(t, int) for t in render_data["token_ids"])
assert "features" in render_data
features = render_data["features"]
assert features is not None
assert isinstance(features, dict)
assert "mm_hashes" in features
assert "image" in features["mm_hashes"]
image_hashes = features["mm_hashes"]["image"]
assert isinstance(image_hashes, list)
assert len(image_hashes) > 0
assert all(isinstance(h, str) and h for h in image_hashes)
assert "mm_placeholders" in features
assert "image" in features["mm_placeholders"]
image_placeholders = features["mm_placeholders"]["image"]
assert isinstance(image_placeholders, list)
assert len(image_placeholders) > 0
for p in image_placeholders:
assert isinstance(p.get("offset"), int)
assert isinstance(p.get("length"), int)
assert p["length"] > 0
assert "kwargs_data" in features
assert "image" in features["kwargs_data"]
assert len(features["kwargs_data"]["image"]) > 0
# Build generate request from render output
generate_payload = render_data
generate_payload["sampling_params"] = {
"max_tokens": 10,
"temperature": 0.0,
}
gen_resp = await client.post(GEN_ENDPOINT, json=generate_payload)
gen_resp.raise_for_status()
gen_data = gen_resp.json()
assert "choices" in gen_data
assert isinstance(gen_data["choices"], list)
assert len(gen_data["choices"]) >= 1
choice = gen_data["choices"][0]
assert "token_ids" in choice
assert isinstance(choice["token_ids"], list)
assert len(choice["token_ids"]) > 0
assert all(isinstance(t, int) for t in choice["token_ids"])
detok_resp = await client.post(
DETOKENIZE_ENDPOINT,
json={"model": MODEL_NAME, "tokens": choice["token_ids"]},
)
detok_resp.raise_for_status()
detok_data = detok_resp.json()
assert "prompt" in detok_data
text = detok_data["prompt"]
assert isinstance(text, str)
assert len(text) > 0
assert "red" in text.lower(), (
f"Expected model to identify the red image, got: {text!r}"
)
......@@ -7,17 +7,39 @@ from tests.models.utils import check_embeddings_close
from vllm.utils.serial_utils import (
EMBED_DTYPES,
ENDIANNESS,
MM_METADATA_DTYPES,
EmbedDType,
Endianness,
MmMetadataDType,
binary2tensor,
tensor2binary,
)
FLOAT_EMBED_DTYPES = tuple(EMBED_DTYPES.keys())
INTEGER_EMBED_DTYPES = tuple(MM_METADATA_DTYPES.keys())
def _build_integer_tensor(
embed_dtype: MmMetadataDType, shape: tuple[int, ...]
) -> torch.Tensor:
torch_dtype = MM_METADATA_DTYPES[embed_dtype].torch_dtype
if torch_dtype is torch.bool:
return torch.randint(0, 2, shape, dtype=torch.int32).to(torch.bool)
if torch_dtype is torch.uint8:
return torch.randint(0, 256, shape, dtype=torch.uint8)
if torch_dtype is torch.int32:
return torch.randint(-(2**20), 2**20, shape, dtype=torch.int32)
if torch_dtype is torch.int64:
return torch.randint(-(2**62), 2**62, shape, dtype=torch.int64)
raise AssertionError(f"Unsupported non-floating embed dtype: {embed_dtype}")
@pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPES.keys())
@pytest.mark.parametrize("embed_dtype", FLOAT_EMBED_DTYPES)
@torch.inference_mode()
def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
def test_encode_and_decode_floats(embed_dtype: EmbedDType, endianness: Endianness):
for i in range(10):
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
shape = tensor.shape
......@@ -40,3 +62,20 @@ def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
name_1="new",
tol=1e-2,
)
@pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", INTEGER_EMBED_DTYPES)
@torch.inference_mode()
def test_encode_and_decode_integers(
embed_dtype: MmMetadataDType, endianness: Endianness
):
shape = (2, 3, 5, 7, 11, 13)
for i in range(10):
tensor = _build_integer_tensor(embed_dtype, shape)
binary = tensor2binary(tensor, embed_dtype, endianness)
new_tensor = binary2tensor(binary, shape, embed_dtype, endianness)
assert new_tensor.dtype == MM_METADATA_DTYPES[embed_dtype].torch_dtype
torch.testing.assert_close(tensor, new_tensor, atol=0, rtol=0)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Encode/decode utilities for multimodal tensors and field metadata
over JSON/HTTP, used by the disaggregated generate endpoint."""
from __future__ import annotations
import pybase64
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
_encoder = MsgpackEncoder(size_threshold=2**62) # force all tensors inline
_decoder = MsgpackDecoder(t=MultiModalKwargsItem)
def encode_mm_kwargs_item(item: MultiModalKwargsItem) -> str:
"""Serialize a MultiModalKwargsItem to a base64 string."""
bufs = _encoder.encode(item)
assert len(bufs) == 1, "All tensors should be inline"
return pybase64.b64encode(bufs[0]).decode("ascii")
def decode_mm_kwargs_item(data: str) -> MultiModalKwargsItem:
"""Deserialize a base64 string back to a MultiModalKwargsItem."""
raw = pybase64.b64decode(data)
return _decoder.decode(raw)
......@@ -35,14 +35,6 @@ class MultiModalFeatures(BaseModel):
Carries hashes (for cache lookup / identification) and placeholder
positions so the downstream `/generate` service knows *where* in
the token sequence each multimodal item lives.
Note:
Phase 1 — metadata only.
Phase 2 should add `mm_kwargs` (processed tensor data) using a
binary transport so the ``/generate` side can skip re-processing.
The `/generate` endpoint must also be updated to inject these
features into `EngineInput` before passing to
`InputProcessor.process_inputs`.
"""
mm_hashes: dict[str, list[str]]
......@@ -51,6 +43,15 @@ class MultiModalFeatures(BaseModel):
mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
"""Per-modality placeholder ranges in the token sequence."""
kwargs_data: dict[str, list[str | None]] | None = None
"""Per-modality serialized tensor data.
Each value is a list parallel to ``mm_hashes[modality]``. A ``str``
entry is a base64-encoded ``MultiModalKwargsItem``; ``None`` means
the item should be resolved from cache. The entire field is
``None`` for metadata-only (cache-hit) responses.
"""
class GenerateRequest(BaseModel):
request_id: str = Field(
......
......@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.disagg.mm_serde import decode_mm_kwargs_item
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
......@@ -34,8 +35,14 @@ from vllm.entrypoints.serve.disagg.protocol import (
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import should_include_usage
from vllm.inputs import EngineInput, mm_input
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.multimodal.inputs import (
MultiModalKwargsItem,
MultiModalKwargsItems,
PlaceholderRange,
)
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.utils.collection_utils import as_list
......@@ -103,10 +110,41 @@ class ServingTokens(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
engine_input: EngineInput
if features := request.features:
# Convert PlaceholderRangeInfo → PlaceholderRange per modality.
mm_placeholders: dict[str, list[PlaceholderRange]] = {
modality: [
PlaceholderRange(offset=p.offset, length=p.length) for p in ranges
]
for modality, ranges in features.mm_placeholders.items()
}
# Deserialize tensor data when present; None → cache hit.
mm_kwargs: dict[str, list[MultiModalKwargsItem | None]] = {}
if features.kwargs_data is not None:
for modality, items in features.kwargs_data.items():
mm_kwargs[modality] = [
decode_mm_kwargs_item(item) if item is not None else None
for item in items
]
else:
for modality, hashes in features.mm_hashes.items():
mm_kwargs[modality] = [None] * len(hashes)
engine_input = mm_input(
prompt_token_ids=request.token_ids,
mm_kwargs=MultiModalKwargsItems(mm_kwargs),
mm_hashes=features.mm_hashes,
mm_placeholders=mm_placeholders,
cache_salt=request.cache_salt,
)
else:
(engine_input,) = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,
skip_mm_cache=True,
)
# Schedule the request and get the result generator.
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from http import HTTPStatus
from typing import Any
from typing import Any, cast
from openai_harmony import Message as OpenAIMessage
......@@ -25,6 +25,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
render_for_completion,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.mm_serde import encode_mm_kwargs_item
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
MultiModalFeatures,
......@@ -37,6 +38,7 @@ from vllm.entrypoints.utils import (
from vllm.inputs import (
EngineInput,
MultiModalHashes,
MultiModalInput,
MultiModalPlaceholders,
PromptType,
SingletonPrompt,
......@@ -251,6 +253,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
skip_mm_cache=True,
reasoning_parser=self.reasoning_parser,
)
else:
......@@ -342,6 +345,7 @@ class OpenAIServingRender:
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
skip_mm_cache=True,
)
return engine_inputs
......@@ -357,9 +361,10 @@ class OpenAIServingRender:
if engine_input.get("type") != "multimodal":
return None
# At this point engine_input is a MultiModalInputs TypedDict.
mm_hashes: MultiModalHashes = engine_input["mm_hashes"] # type: ignore[typeddict-item]
raw_placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"] # type: ignore[typeddict-item]
# At this point engine_input is a MultiModalInput TypedDict.
mm_engine_input = cast(MultiModalInput, engine_input)
mm_hashes: MultiModalHashes = mm_engine_input["mm_hashes"]
raw_placeholders: MultiModalPlaceholders = mm_engine_input["mm_placeholders"]
mm_placeholders = {
modality: [
......@@ -368,9 +373,20 @@ class OpenAIServingRender:
for modality, ranges in raw_placeholders.items()
}
# Serialize tensor data per modality.
kwargs_data: dict[str, list[str | None]] | None = None
if raw_mm_kwargs := mm_engine_input.get("mm_kwargs"):
kwargs_data = {}
for modality, items in raw_mm_kwargs.items():
kwargs_data[modality] = [
encode_mm_kwargs_item(item) if item is not None else None
for item in items
]
return MultiModalFeatures(
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
kwargs_data=kwargs_data,
)
def _make_request_with_harmony(
......
......@@ -27,6 +27,7 @@ class DTypeInfo:
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
MmMetadataDType = Literal["int32", "int64", "uint8", "bool"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
......@@ -42,6 +43,15 @@ EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
"fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
"fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
}
MM_METADATA_DTYPES: Mapping[MmMetadataDType, DTypeInfo] = {
"int32": DTypeInfo(torch.int32, torch.int32, np.int32),
"int64": DTypeInfo(torch.int64, torch.int64, np.int64),
"uint8": DTypeInfo(torch.uint8, torch.uint8, np.uint8),
"bool": DTypeInfo(torch.bool, torch.uint8, np.uint8),
}
_ALL_SERIAL_DTYPES: Mapping[str, DTypeInfo] = {
k: v for d in (EMBED_DTYPES, MM_METADATA_DTYPES) for k, v in d.items()
}
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)
......@@ -56,14 +66,14 @@ def tensor2base64(x: torch.Tensor) -> str:
def tensor2binary(
tensor: torch.Tensor,
embed_dtype: EmbedDType,
embed_dtype: "EmbedDType | MmMetadataDType",
endianness: Endianness,
) -> bytes:
assert isinstance(tensor, torch.Tensor)
assert embed_dtype in EMBED_DTYPES
assert embed_dtype in _ALL_SERIAL_DTYPES
assert endianness in ENDIANNESS
dtype_info = EMBED_DTYPES[embed_dtype]
dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]
np_array = (
tensor.to(dtype_info.torch_dtype)
......@@ -82,13 +92,13 @@ def tensor2binary(
def binary2tensor(
binary: bytes,
shape: tuple[int, ...],
embed_dtype: EmbedDType,
embed_dtype: "EmbedDType | MmMetadataDType",
endianness: Endianness,
) -> torch.Tensor:
assert embed_dtype in EMBED_DTYPES
assert embed_dtype in _ALL_SERIAL_DTYPES
assert endianness in ENDIANNESS
dtype_info = EMBED_DTYPES[embed_dtype]
dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]
np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment