Unverified Commit 3ccf39e9 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

refactor: [vLLM] Abstract embedding loading to a class. Encapsulate frontend...


refactor: [vLLM] Abstract embedding loading to a class. Encapsulate frontend decoding receiver detail into ImageLoader (#7482)
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 19fc7660
...@@ -21,6 +21,7 @@ from safetensors import torch as safetensors_torch ...@@ -21,6 +21,7 @@ from safetensors import torch as safetensors_torch
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.runtime import run_async
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -826,19 +827,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -826,19 +827,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
self.aggregated_op_wait_time = 0 self.aggregated_op_wait_time = 0
self.warmedup_descriptors: Queue[nixl_connect.Descriptor] = Queue() self.warmedup_descriptors: Queue[nixl_connect.Descriptor] = Queue()
self.inuse_descriptors: dict[int, tuple[nixl_connect.Descriptor, bool]] = {} self.inuse_descriptors: dict[int, tuple[nixl_connect.Descriptor, bool]] = {}
# Handle both sync and async contexts connection = run_async(self.connector._create_connection)
try:
asyncio.get_running_loop() # Check if we're in async context
# If we're in an async context, we need to run the connection creation in a separate thread to avoid blocking the event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
connection = pool.submit(
asyncio.run, self.connector._create_connection()
).result(timeout=10)
except RuntimeError:
# No running loop - safe to use asyncio.run()
connection = asyncio.run(self.connector._create_connection())
# Create descriptor for our allocated tensor # Create descriptor for our allocated tensor
for _ in range(max_items): for _ in range(max_items):
encodings_tensor = torch.zeros( encodings_tensor = torch.zeros(
......
...@@ -19,7 +19,7 @@ import binascii ...@@ -19,7 +19,7 @@ import binascii
import logging import logging
import os import os
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Final, List, Optional from typing import Any, Dict, Final, List
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
...@@ -28,6 +28,7 @@ from PIL import Image ...@@ -28,6 +28,7 @@ from PIL import Image
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
from .http_client import get_http_client from .http_client import get_http_client
...@@ -43,11 +44,35 @@ class ImageLoader: ...@@ -43,11 +44,35 @@ class ImageLoader:
CACHE_SIZE_MAXIMUM = int(os.environ.get("DYN_MM_IMAGE_CACHE_SIZE", "8")) CACHE_SIZE_MAXIMUM = int(os.environ.get("DYN_MM_IMAGE_CACHE_SIZE", "8"))
def __init__( def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0 self,
cache_size: int = CACHE_SIZE_MAXIMUM,
http_timeout: float = 30.0,
enable_frontend_decoding: bool = False,
): ):
"""
Initialize the ImageLoader with caching, HTTP settings, and optional NIXL config for
receiving frontend decoding.
Args:
cache_size: Maximum number of images to store in the in-memory LRU cache.
Defaults to CACHE_SIZE_MAXIMUM.
http_timeout: Timeout in seconds for HTTP requests when fetching remote images.
Defaults to 30.0 seconds.
enable_frontend_decoding: If True, enables NIXL RDMA for transferring
decoded images directly from frontend memory, bypassing standard
network transport. Defaults to False.
"""
self._http_timeout = http_timeout self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {} self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
self._enable_frontend_decoding = enable_frontend_decoding
# Lazy-init NIXL connector only when frontend decoding is enabled
self._nixl_connector = None
if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector()
run_async(
self._nixl_connector.initialize
) # Synchronously wait for async init
@_nvtx.annotate("mm:img:load_image", color="lime") @_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image: async def load_image(self, image_url: str) -> Image.Image:
...@@ -137,8 +162,6 @@ class ImageLoader: ...@@ -137,8 +162,6 @@ class ImageLoader:
async def load_image_batch( async def load_image_batch(
self, self,
image_mm_items: List[Dict[str, Any]], image_mm_items: List[Dict[str, Any]],
enable_frontend_decoding: bool = False,
nixl_connector: Optional["nixl_connect.Connector"] = None,
) -> List[Any]: ) -> List[Any]:
""" """
Load a batch of images from multimodal data items. Load a batch of images from multimodal data items.
...@@ -149,8 +172,6 @@ class ImageLoader: ...@@ -149,8 +172,6 @@ class ImageLoader:
Args: Args:
image_mm_items: List of multimodal data items for images image_mm_items: List of multimodal data items for images
enable_frontend_decoding: If True, enables NIXL RDMA for decoded images
nixl_connector: NIXL connector for frontend decoding (required if enable_frontend_decoding=True)
Returns: Returns:
List of loaded image data List of loaded image data
...@@ -168,19 +189,10 @@ class ImageLoader: ...@@ -168,19 +189,10 @@ class ImageLoader:
image_futures.append(self.load_image(url)) image_futures.append(self.load_image(url))
logger.debug(f"Preparing to load image from URL: {url[:80]}...") logger.debug(f"Preparing to load image from URL: {url[:80]}...")
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item: elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if enable_frontend_decoding: if self._enable_frontend_decoding:
if nixl_connector is None:
logger.error(
"Frontend decoding enabled but nixl_connector not provided. "
"Caller must pass an initialized NIXL connector."
)
raise ValueError(
"nixl_connector required when enable_frontend_decoding=True"
)
metadata = item[DECODED_VARIANT_KEY] metadata = item[DECODED_VARIANT_KEY]
image_futures.append( image_futures.append(
read_decoded_media_via_nixl(nixl_connector, metadata) read_decoded_media_via_nixl(self._nixl_connector, metadata)
) )
else: else:
logger.error( logger.error(
......
...@@ -6,8 +6,9 @@ Common runtime utilities shared across Dynamo engine backends. ...@@ -6,8 +6,9 @@ Common runtime utilities shared across Dynamo engine backends.
Provides: Provides:
- parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings - parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings
- graceful_shutdown: Shutdown DistributedRuntime with optional event signaling
- create_runtime: Create DistributedRuntime. - create_runtime: Create DistributedRuntime.
- run_async: Helper to run async functions in non-async functions that
may be run in either sync or async context.
""" """
import asyncio import asyncio
...@@ -70,3 +71,28 @@ def create_runtime( ...@@ -70,3 +71,28 @@ def create_runtime(
runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats) runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats)
return runtime, loop return runtime, loop
def run_async(func, *args, **kwargs):
"""Run an async function as if it is synchronous, handling both sync and async contexts.
Args:
func: An async function to execute.
*args: Positional arguments to pass to the function.
**kwargs: Keyword arguments to pass to the function.
Returns:
The result of the async function.
"""
# Check if we're in async context, exception is raised if not and we can safely
# run 'func' with asyncio.run()
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(func(*args, **kwargs))
# In an async context, we want to run 'func' in a separate thread to avoid blocking the event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, func(*args, **kwargs)).result()
...@@ -23,7 +23,6 @@ from vllm.outputs import RequestOutput ...@@ -23,7 +23,6 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
...@@ -352,14 +351,9 @@ class BaseWorkerHandler(ABC): ...@@ -352,14 +351,9 @@ class BaseWorkerHandler(ABC):
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
self.config = config self.config = config
self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event) self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len self.model_max_len = model_max_len
self.enable_multimodal = enable_multimodal self.enable_multimodal = enable_multimodal
self.enable_frontend_decoding = enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
self._nixl_connector: nixl_connect.Connector | None = None
self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking: name -> LoRAInfo(id, path) # LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {} self.loaded_loras: dict[str, LoRAInfo] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA # Per-LoRA locks to prevent concurrent load operations for the same LoRA
...@@ -367,6 +361,10 @@ class BaseWorkerHandler(ABC): ...@@ -367,6 +361,10 @@ class BaseWorkerHandler(ABC):
# Guard lock-map access in case handlers are invoked from multiple threads. # Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock() self._lora_load_locks_guard = threading.Lock()
self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config) self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config)
...@@ -1014,18 +1012,9 @@ class BaseWorkerHandler(ABC): ...@@ -1014,18 +1012,9 @@ class BaseWorkerHandler(ABC):
mm_map = request["multi_modal_data"] mm_map = request["multi_modal_data"]
vllm_mm_data = {} vllm_mm_data = {}
# Lazy-init NIXL connector only when frontend decoding is enabled
if self.enable_frontend_decoding:
async with self._nixl_connector_lock:
if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
# Process image_url entries # Process image_url entries
images = await self.image_loader.load_image_batch( images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []), mm_map.get(IMAGE_URL_KEY, []),
enable_frontend_decoding=self.enable_frontend_decoding,
nixl_connector=self._nixl_connector,
) )
if images: if images:
......
...@@ -4,14 +4,12 @@ ...@@ -4,14 +4,12 @@
import copy import copy
import logging import logging
import uuid import uuid
from collections import defaultdict
from typing import Any from typing import Any
import torch import torch
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
...@@ -34,7 +32,7 @@ from ..multimodal_utils import ( ...@@ -34,7 +32,7 @@ from ..multimodal_utils import (
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
from ..multimodal_utils.model import is_qwen_vl_model from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings from ..multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -71,17 +69,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -71,17 +69,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
) )
self.config = config self.config = config
self.encode_worker_client = encode_worker_client
self.decode_worker_client = decode_worker_client self.decode_worker_client = decode_worker_client
self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int(
config.multimodal_embedding_cache_capacity_gb * 1024**3
)
self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes
)
# Initialize multimodal-specific components # Initialize multimodal-specific components
logger.info("Multimodal PD Worker startup started.") logger.info("Multimodal PD Worker startup started.")
...@@ -91,12 +80,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -91,12 +80,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
else: else:
self.EMBEDDINGS_DTYPE = torch.float16 self.EMBEDDINGS_DTYPE = torch.float16
# Create and initialize a dynamo connector for this worker. # Embedding loader consist of two main components:
# We'll need this to move data between this worker and remote workers efficiently. # 1) An remote encode worker client and matching embedding receiver,
# Note: This is synchronous initialization, async initialization happens in async_init # which can request remote encode and handle the transfer of embeddings
self._connector: connect.Connector | None = ( # from the encode worker to this prefill worker.
None # Will be initialized in async_init # 2) A local embedding cache manager, which can store previously fetched embeddings
) # and used to determine whether remote encode is necessary for a given mm data.
self.encode_worker_client = encode_worker_client
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL: if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver = LocalEmbeddingReceiver() self.embedding_receiver = LocalEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE: elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
...@@ -109,13 +99,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -109,13 +99,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
raise ValueError( raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}" f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
) )
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int(
config.multimodal_embedding_cache_capacity_gb * 1024**3
)
self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes
)
self.embedding_loader = MultiModalEmbeddingLoader(
encode_worker_client=self.encode_worker_client, # type: ignore
receiver=self.embedding_receiver,
embedding_cache_manager=self.embedding_cache_manager,
)
logger.info("Multimodal PD Worker has been initialized") logger.info("Multimodal PD Worker has been initialized")
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
"""Async initialization for connector that requires async setup""" """Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self._connector = connect.Connector()
logger.info("Multimodal PD Worker async initialization completed.") logger.info("Multimodal PD Worker async initialization completed.")
def _parse_frontend_request( def _parse_frontend_request(
...@@ -164,17 +165,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -164,17 +165,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
Returns an empty dict when no encode worker is configured or no images Returns an empty dict when no encode worker is configured or no images
are present. are present.
""" """
if not self.encode_worker_client or not image_urls:
return defaultdict(list)
return await load_multimodal_embeddings( return await self.embedding_loader.load_multimodal_embeddings(
self.encode_worker_client, # type: ignore[arg-type]
image_urls, image_urls,
request_id, request_id,
self.embedding_receiver,
model=self.config.model, model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE, embeddings_dtype=self.EMBEDDINGS_DTYPE,
cache=self.embedding_cache_manager,
context=context, context=context,
) )
...@@ -212,10 +208,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -212,10 +208,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if not value: if not value:
del multi_modal_data[key] del multi_modal_data[key]
else: else:
# [gluo FIXME] should be mindful to default dict, move this evaluation logic to here
# so that we don't accidentally add empty keys to the dict which causes vLLM misbehavior
logger.debug( logger.debug(
f"Prepared multimodal data size: {len(multi_modal_data[key])}" f"Prepared multimodal data key {key}, number of items: {len(multi_modal_data[key])}"
) )
logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys())) logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))
......
...@@ -19,7 +19,7 @@ from dynamo.vllm.multimodal_utils.model import ( ...@@ -19,7 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data, construct_mm_data,
load_vision_model, load_vision_model,
) )
from dynamo.vllm.multimodal_utils.prefill_worker_utils import load_multimodal_embeddings from dynamo.vllm.multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup, MultiModalGroup,
MultiModalInput, MultiModalInput,
...@@ -48,5 +48,5 @@ __all__ = [ ...@@ -48,5 +48,5 @@ __all__ = [
"MultiModalRequest", "MultiModalRequest",
"MyRequestOutput", "MyRequestOutput",
"vLLMMultimodalRequest", "vLLMMultimodalRequest",
"load_multimodal_embeddings", "MultiModalEmbeddingLoader",
] ]
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import asyncio import asyncio
import logging import logging
import os import os
from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
...@@ -92,32 +91,32 @@ def _accumulate_embeddings( ...@@ -92,32 +91,32 @@ def _accumulate_embeddings(
image_grid_thw=image_grid_thw, image_grid_thw=image_grid_thw,
) )
if "image" not in multi_modal_data:
multi_modal_data["image"] = mm_data["image"]
return
if isinstance(mm_data["image"], dict): if isinstance(mm_data["image"], dict):
# Qwen-VL style: dict with image_embeds + image_grid_thw tensors # Qwen-VL style: dict with image_embeds + image_grid_thw tensors
if multi_modal_data["image"] == []: multi_modal_data["image"]["image_embeds"] = torch.cat(
multi_modal_data["image"] = mm_data["image"] (
else: multi_modal_data["image"]["image_embeds"],
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings mm_data["image"]["image_embeds"],
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
mm_data["image"]["image_embeds"],
)
) )
multi_modal_data["image"]["image_grid_thw"] = torch.cat( )
( multi_modal_data["image"]["image_grid_thw"] = torch.cat(
multi_modal_data["image"]["image_grid_thw"], (
mm_data["image"]["image_grid_thw"], multi_modal_data["image"]["image_grid_thw"],
) mm_data["image"]["image_grid_thw"],
) )
)
elif isinstance(mm_data["image"], torch.Tensor):
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else: else:
# [gluo FIXME] embedding with multiple images? raise ValueError(
if multi_modal_data["image"] == []: f"Unexpected image data format from construct_mm_data: {type(mm_data['image'])}"
multi_modal_data["image"] = mm_data["image"] )
else:
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
def _ensure_owned_tensors(multi_modal_data: Dict[str, Any]) -> None: def _ensure_owned_tensors(multi_modal_data: Dict[str, Any]) -> None:
...@@ -287,53 +286,67 @@ async def _fetch_embeddings( ...@@ -287,53 +286,67 @@ async def _fetch_embeddings(
# ── Public API (single entry point) ───────────────────────────────── # ── Public API (single entry point) ─────────────────────────────────
async def load_multimodal_embeddings( class MultiModalEmbeddingLoader:
encode_worker_client: Client, """Helper class for requesting remote encode and receive embeddings."""
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
*,
model: str,
embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
Full pipeline:
cache check → remote fetch → cache update → accumulate → release NIXL buffers.
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
groups, pending = await _fetch_embeddings(
encode_worker_client,
image_urls,
request_id,
receiver,
cache=cache,
context=context,
)
multi_modal_data: Dict[str, Any] = defaultdict(list) def __init__(
with time_and_log_code_section( self,
f"[PREFILL] request: {request_id} accumulate embeddings" encode_worker_client: Client,
receiver: AbstractEmbeddingReceiver,
embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None,
): ):
for group in groups: self._encode_worker_client = encode_worker_client
assert group.loaded_embedding is not None self._receiver = receiver
_accumulate_embeddings( self._embedding_cache_manager = embedding_cache_manager
multi_modal_data,
model, async def load_multimodal_embeddings(
embeddings_dtype, self,
group.loaded_embedding, image_urls: list[str],
group.image_grid_thw, request_id: str,
) *,
model: str,
embeddings_dtype: torch.dtype,
context=None,
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
Full pipeline:
cache check → remote fetch → cache update → accumulate → release NIXL buffers.
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
if not self._encode_worker_client or not image_urls:
return {}
groups, pending = await _fetch_embeddings(
self._encode_worker_client,
image_urls,
request_id,
self._receiver,
cache=self._embedding_cache_manager,
context=context,
)
if pending is not None: multi_modal_data: Dict[str, Any] = {}
# Multi-image: torch.cat in _accumulate_embeddings already created with time_and_log_code_section(
# owned tensors. Single-image: the data is still a view into the f"[PREFILL] request: {request_id} accumulate embeddings"
# NIXL buffer, so we must clone before releasing. ):
if len(groups) == 1: for group in groups:
_ensure_owned_tensors(multi_modal_data) assert group.loaded_embedding is not None
pending.release_all() _accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
return multi_modal_data if pending is not None:
# Multi-image: torch.cat in _accumulate_embeddings already created
# owned tensors. Single-image: the data is still a view into the
# NIXL buffer, so we must clone before releasing.
if len(groups) == 1:
_ensure_owned_tensors(multi_modal_data)
pending.release_all()
return multi_modal_data
...@@ -182,7 +182,7 @@ class TestLoadMultimodalData: ...@@ -182,7 +182,7 @@ class TestLoadMultimodalData:
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) # type: ignore fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) # type: ignore
with patch.object( with patch.object(
mod, handler.embedding_loader,
"load_multimodal_embeddings", "load_multimodal_embeddings",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=fake_mm_data, return_value=fake_mm_data,
...@@ -192,24 +192,6 @@ class TestLoadMultimodalData: ...@@ -192,24 +192,6 @@ class TestLoadMultimodalData:
mock_load.assert_awaited_once() mock_load.assert_awaited_once()
assert result is fake_mm_data assert result is fake_mm_data
@pytest.mark.asyncio
async def test_passes_cache_to_load_multimodal_embeddings(self):
"""With cache enabled -> passes cache manager kwarg."""
mock_client = MagicMock()
config = _make_config(multimodal_embedding_cache_capacity_gb=1.0)
handler = _make_handler(config=config, encode_worker_client=mock_client)
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=defaultdict(list),
) as mock_load:
await handler._load_multimodal_data(["http://img.png"], "req-1")
mock_load.assert_awaited_once()
assert mock_load.call_args.kwargs["cache"] is handler.embedding_cache_manager
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_passes_model_and_dtype(self): async def test_passes_model_and_dtype(self):
"""Model name and embeddings dtype are forwarded.""" """Model name and embeddings dtype are forwarded."""
...@@ -217,7 +199,7 @@ class TestLoadMultimodalData: ...@@ -217,7 +199,7 @@ class TestLoadMultimodalData:
handler = _make_handler(encode_worker_client=mock_client) handler = _make_handler(encode_worker_client=mock_client)
with patch.object( with patch.object(
mod, handler.embedding_loader,
"load_multimodal_embeddings", "load_multimodal_embeddings",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=defaultdict(list), return_value=defaultdict(list),
......
...@@ -26,7 +26,7 @@ MODEL = "test-model" ...@@ -26,7 +26,7 @@ MODEL = "test-model"
DTYPE = torch.float16 DTYPE = torch.float16
class TestLoadMultimodalEmbeddings: class TestMultimodalEmbeddingsLoader:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_cached(self): async def test_all_cached(self):
"""All URLs cached -> no encode worker call, returns accumulated mm_data.""" """All URLs cached -> no encode worker call, returns accumulated mm_data."""
...@@ -42,14 +42,12 @@ class TestLoadMultimodalEmbeddings: ...@@ -42,14 +42,12 @@ class TestLoadMultimodalEmbeddings:
"_fetch_from_encode_workers", "_fetch_from_encode_workers",
new_callable=AsyncMock, new_callable=AsyncMock,
) as mock_fetch: ) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings( embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
AsyncMock(), mm_data = await embedding_loader.load_multimodal_embeddings(
[url], [url],
"req-1", "req-1",
None,
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE, embeddings_dtype=DTYPE,
cache=cache,
) )
mock_fetch.assert_not_awaited() mock_fetch.assert_not_awaited()
...@@ -73,14 +71,12 @@ class TestLoadMultimodalEmbeddings: ...@@ -73,14 +71,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=([fake_group], None), return_value=([fake_group], None),
) as mock_fetch: ) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings( embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
AsyncMock(), mm_data = await embedding_loader.load_multimodal_embeddings(
[url], [url],
"req-1", "req-1",
None,
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE, embeddings_dtype=DTYPE,
cache=cache,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
...@@ -107,14 +103,12 @@ class TestLoadMultimodalEmbeddings: ...@@ -107,14 +103,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=([fake_group], None), return_value=([fake_group], None),
) as mock_fetch: ) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings( embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, None)
AsyncMock(), mm_data = await embedding_loader.load_multimodal_embeddings(
[url], [url],
"req-1", "req-1",
None,
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE, embeddings_dtype=DTYPE,
cache=None,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
...@@ -145,14 +139,12 @@ class TestLoadMultimodalEmbeddings: ...@@ -145,14 +139,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=([fake_group], None), return_value=([fake_group], None),
) as mock_fetch: ) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings( embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
AsyncMock(), mm_data = await embedding_loader.load_multimodal_embeddings(
[url_cached, url_miss], [url_cached, url_miss],
"req-1", "req-1",
None,
model=MODEL, model=MODEL,
embeddings_dtype=DTYPE, embeddings_dtype=DTYPE,
cache=cache,
) )
mock_fetch.assert_awaited_once() mock_fetch.assert_awaited_once()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment