Unverified commit f46498f2, authored by Wang, Yi and committed by GitHub
Browse files

feat: add sglang embedding cache (#7674)


Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 93530057
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import hashlib
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional
......@@ -17,6 +18,10 @@ from sglang.srt.parser.conversation import chat_templates
from transformers import AutoTokenizer
from dynamo._core import Client, Context
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal import EMBEDDING_SENDER_FACTORIES
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.sglang.args import Config
......@@ -123,9 +128,133 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
)
self.embedding_sender = sender()
# Optional CPU-side LRU embedding cache
self._embedding_cache: MultimodalEmbeddingCacheManager | None = None
capacity_gb = config.dynamo_args.multimodal_embedding_cache_capacity_gb
if capacity_gb > 0:
capacity_bytes = int(capacity_gb * 1024**3)
self._embedding_cache = MultimodalEmbeddingCacheManager(capacity_bytes)
logger.info("Multimodal embedding cache enabled: %.2f GB", capacity_gb)
def cleanup(self) -> None:
    """Release handler resources; this handler has nothing to release."""
    return None
@staticmethod
def _url_hash(url: str) -> str:
"""Stable blake2b hash of an image URL, used as embedding cache key."""
return hashlib.blake2b(url.encode(), digest_size=32).hexdigest()
@staticmethod
def _split_token_counts(grid_list: list, total_tokens: int) -> list[int]:
"""Compute per-image embedding token counts from image_grid_thw shapes.
Each entry in grid_list is [t, h, w]. The spatial grid size (h*w) is
proportional to the number of tokens for that image. We infer the shared
merge factor from the ratio of total grid tokens to total embedding tokens,
then apply it per image.
"""
if total_tokens <= 0:
raise ValueError("Invalid token count for embeddings")
grid_sizes = []
for image_grid_thw in grid_list:
if not isinstance(image_grid_thw, list) or len(image_grid_thw) != 3:
raise ValueError(f"Invalid image_grid_thw: {image_grid_thw}")
grid_sizes.append(int(image_grid_thw[1] * image_grid_thw[2]))
total_grid_tokens = sum(grid_sizes)
if total_grid_tokens <= 0:
raise ValueError("Invalid grid statistics for embeddings")
if total_grid_tokens % total_tokens != 0:
raise ValueError(
"Cannot infer merge factor: grid token total is not divisible "
"by embedding token total"
)
merge_factor = total_grid_tokens // total_tokens
token_counts = []
for grid_count in grid_sizes:
if grid_count % merge_factor != 0:
raise ValueError(
"Cannot split embeddings: per-image grid token count not "
"divisible by inferred merge factor"
)
token_counts.append(grid_count // merge_factor)
if sum(token_counts) != total_tokens:
raise ValueError(
"Cannot split embeddings: per-image token counts do not match "
"embedding token total"
)
return token_counts
async def _encode_with_cache(
    self, image_urls: list[str]
) -> tuple[Any, torch.Tensor]:
    """Encode images, reusing per-URL embeddings from the CPU cache.

    Each URL is looked up in ``self._embedding_cache`` under its
    ``_url_hash`` key. URLs that miss are encoded together in a single
    ``self.encoder._encode`` call, split into per-image tensors, stored in
    the cache, and then merged with the hits in the original URL order.

    Args:
        image_urls: Image URLs in request order. ``torch.cat`` rejects an
            empty list, so at least one URL is expected.

    Returns:
        ``(image_grid_thw, embeddings)``: a tensor built from the per-image
        grid entries, and the per-image embeddings concatenated along
        dim 0, both ordered like ``image_urls``.

    Raises:
        ValueError: If the encoder output is not a 2D tensor, if the number
            of grid entries differs from the number of encoded URLs, or if
            ``_split_token_counts`` cannot split the embeddings per image.
    """
    assert self._embedding_cache is not None

    # Hash every URL once; the key is needed for the lookup and, on a miss,
    # again for the insertion.
    cache_keys = [self._url_hash(url) for url in image_urls]

    entries: dict[int, CachedEmbedding] = {}
    miss_indices: list[int] = []
    for i, key in enumerate(cache_keys):
        hit = self._embedding_cache.get(key)
        if hit is not None:
            logger.debug("Embedding cache hit for URL index %d", i)
            entries[i] = hit
        else:
            miss_indices.append(i)

    if miss_indices:
        miss_urls = [image_urls[i] for i in miss_indices]
        grid_dim, new_embeddings = await self.encoder._encode(miss_urls)

        # Validate the type before touching tensor attributes, so a bad
        # encoder result raises ValueError rather than AttributeError.
        if not (
            isinstance(new_embeddings, torch.Tensor) and new_embeddings.ndim == 2
        ):
            raise ValueError(
                f"Unsupported embeddings type from encoder: {type(new_embeddings)}"
            )

        # Cached tensors are kept on CPU; the encoder is expected to return
        # CPU tensors already, so a move here is logged as unexpected.
        target_device = torch.device("cpu")
        if new_embeddings.device != target_device:
            logger.warning(
                "SGLang _encode returned embeddings on %s, "
                "expected CPU. Moving to CPU.",
                new_embeddings.device,
            )
            new_embeddings = new_embeddings.to(target_device)

        grid_list: list = (
            grid_dim.tolist() if isinstance(grid_dim, torch.Tensor) else grid_dim
        )
        # zip() below would silently drop items on a length mismatch and
        # surface later as a KeyError; fail here with a clear message.
        if len(grid_list) != len(miss_urls):
            raise ValueError(
                f"Encoder returned {len(grid_list)} image_grid_thw entries "
                f"for {len(miss_urls)} image URLs"
            )

        token_counts = self._split_token_counts(grid_list, new_embeddings.shape[0])
        split_tensors = torch.split(new_embeddings, token_counts, dim=0)

        for idx, tensor, grid_thw in zip(miss_indices, split_tensors, grid_list):
            # torch.split yields views of the batch tensor, and contiguous()
            # returns the same view when it is already contiguous. clone()
            # gives each entry its own storage, so evicting one entry can
            # actually free its memory.
            entry = CachedEmbedding(
                tensor=tensor.clone(),
                image_grid_thw=grid_thw,
            )
            self._embedding_cache.set(cache_keys[idx], entry)
            entries[idx] = entry

    # Reassemble hits and fresh entries in the original URL order.
    ordered = [entries[i] for i in range(len(image_urls))]
    all_grid_thw = [entry.image_grid_thw for entry in ordered]
    full_embeddings = torch.cat([entry.tensor for entry in ordered], dim=0)
    return torch.tensor(all_grid_thw), full_embeddings
def _extract_image_urls(self, request: Dict[str, Any]) -> list[str]:
"""
Extract image URLs from the multi_modal_data field of a PreprocessedRequest.
......@@ -200,9 +329,15 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
try:
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
)
if self._embedding_cache is not None:
(
image_grid_dim,
precomputed_embeddings,
) = await self._encode_with_cache(image_urls)
else:
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
)
image_grid_thw_list = (
image_grid_dim.tolist()
......@@ -213,54 +348,15 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
if len(image_grid_thw_list) != len(multimodal_groups):
raise ValueError("image_grid_thw size mismatch")
def _build_token_counts(total_tokens: int) -> list[int]:
if total_tokens <= 0:
raise ValueError("Invalid token statistics for embeddings")
# image_grid_thw is [t, h, w]. We derive per-item relative sizes
# from spatial grid (h * w), then infer merge factor
# from the total embedding token count.
grid_sizes = []
for image_grid_thw in image_grid_thw_list:
if not isinstance(image_grid_thw, list) or len(image_grid_thw) != 3:
raise ValueError(
"Cannot split embeddings: invalid image_grid_thw"
)
grid_sizes.append(int(image_grid_thw[1] * image_grid_thw[2]))
total_grid_tokens = sum(grid_sizes)
if total_grid_tokens <= 0:
raise ValueError("Invalid grid statistics for embeddings")
if total_grid_tokens % total_tokens != 0:
raise ValueError(
"Cannot infer merge factor: grid token total is not divisible by embedding token total"
)
merge_factor = total_grid_tokens // total_tokens
token_counts = []
for grid_count in grid_sizes:
if grid_count % merge_factor != 0:
raise ValueError(
"Cannot split embeddings: per-image grid token count not divisible by inferred merge factor"
)
token_counts.append(grid_count // merge_factor)
if sum(token_counts) != total_tokens:
raise ValueError(
"Cannot split embeddings: per-image token counts do not match embedding token total"
)
return token_counts
if isinstance(precomputed_embeddings, torch.Tensor):
if precomputed_embeddings.ndim != 2:
raise ValueError(
"Unsupported embeddings tensor rank from encoder: "
f"{precomputed_embeddings.ndim}. Expected 2D [tokens, hidden]."
)
token_counts = _build_token_counts(precomputed_embeddings.shape[0])
token_counts = self._split_token_counts(
image_grid_thw_list, precomputed_embeddings.shape[0]
)
else:
raise ValueError(
"Unsupported embeddings type from encoder: "
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for SGLang multimodal embedding cache behavior."""
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.sglang.request_handlers.multimodal.encode_worker_handler import (
MultimodalEncodeWorkerHandler,
)
# Module-level marks: pytest applies every mark in this list to each test in
# the file.
pytestmark = [
pytest.mark.unit,
pytest.mark.sglang,
pytest.mark.gpu_1, # sglang tests run on GPU-enabled workers
# NOTE(review): presumably declares that these tests need no GPU memory
# themselves — confirm against the project's marker definition.
pytest.mark.max_vram_gib(0),
pytest.mark.pre_merge,
]
@pytest.fixture
def cache_handler() -> MultimodalEncodeWorkerHandler:
    """Build a minimal handler for exercising the embedding-cache path.

    The instance is allocated with ``__new__`` so ``__init__`` never runs;
    only the two attributes ``_encode_with_cache`` reads are populated: a
    32 MiB embedding cache and an encoder stub whose ``_encode`` is an
    ``AsyncMock``.
    """
    cache_capacity = 32 * 1024**2
    stub = MultimodalEncodeWorkerHandler.__new__(MultimodalEncodeWorkerHandler)
    stub.encoder = SimpleNamespace(_encode=AsyncMock())
    stub._embedding_cache = MultimodalEmbeddingCacheManager(
        capacity_bytes=cache_capacity
    )
    return stub
@pytest.mark.asyncio
async def test_encode_with_cache_partial_hit_and_reuse(
    cache_handler: MultimodalEncodeWorkerHandler,
) -> None:
    """A partial hit encodes only the misses and keeps the URL order.

    A repeat call with the same URLs must then be served entirely from
    the cache.
    """
    url_a, url_b, url_c = (
        "http://example.com/a.jpg",
        "http://example.com/b.jpg",
        "http://example.com/c.jpg",
    )
    urls = [url_a, url_b, url_c]
    encode_mock = cache_handler.encoder._encode

    # Seed the cache with the middle URL only (4 tokens x 3 hidden).
    preloaded = torch.full((4, 3), fill_value=-1.0)
    cache_handler._embedding_cache.set(
        cache_handler._url_hash(url_b),
        CachedEmbedding(tensor=preloaded, image_grid_thw=[1, 2, 2]),
    )

    # The encoder sees only the two misses, worth 8 and 4 tokens.
    fresh = torch.arange(36, dtype=torch.float32).reshape(12, 3)
    encode_mock.return_value = (torch.tensor([[1, 2, 4], [1, 2, 2]]), fresh)

    grid, full_embeddings = await cache_handler._encode_with_cache(urls)

    # Exactly one encoder call, restricted to the uncached URLs.
    encode_mock.assert_awaited_once_with([url_a, url_c])

    # Output follows the request order: a (8 rows), b (4 cached rows), c (4 rows).
    assert grid.tolist() == [[1, 2, 4], [1, 2, 2], [1, 2, 2]]
    expected_rows = [
        (slice(0, 8), fresh[:8]),
        (slice(8, 12), preloaded),
        (slice(12, 16), fresh[8:12]),
    ]
    for rows, expected in expected_rows:
        assert torch.equal(full_embeddings[rows], expected)

    # Everything is cached now, so the await count must stay at one.
    grid_again, embeddings_again = await cache_handler._encode_with_cache(urls)
    assert encode_mock.await_count == 1
    assert grid_again.tolist() == grid.tolist()
    assert torch.equal(embeddings_again, full_embeddings)
@pytest.mark.asyncio
async def test_encode_with_cache_all_hit_no_remote_call(
    cache_handler: MultimodalEncodeWorkerHandler,
) -> None:
    """When every URL is already cached the encoder must never be called."""
    seeded = {
        "http://example.com/x.jpg": (torch.ones(2, 3), [1, 1, 2]),
        "http://example.com/y.jpg": (torch.full((1, 3), 9.0), [1, 1, 1]),
    }
    for url, (tensor, grid_thw) in seeded.items():
        cache_handler._embedding_cache.set(
            cache_handler._url_hash(url),
            CachedEmbedding(tensor=tensor, image_grid_thw=grid_thw),
        )
    urls = list(seeded)

    grid, full_embeddings = await cache_handler._encode_with_cache(urls)

    cache_handler.encoder._encode.assert_not_called()
    assert grid.tolist() == [[1, 1, 2], [1, 1, 1]]
    expected = torch.cat([tensor for tensor, _ in seeded.values()], dim=0)
    assert torch.equal(full_embeddings, expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment