Unverified Commit 3ccf39e9 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

refactor: [vLLM] Abstract embedding loading to a class. Encapsulate frontend...


refactor: [vLLM] Abstract embedding loading to a class. Encapsulate frontend decoding receiver detail into ImageLoader (#7482)
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 19fc7660
......@@ -21,6 +21,7 @@ from safetensors import torch as safetensors_torch
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.runtime import run_async
logger = logging.getLogger(__name__)
......@@ -826,19 +827,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
self.aggregated_op_wait_time = 0
self.warmedup_descriptors: Queue[nixl_connect.Descriptor] = Queue()
self.inuse_descriptors: dict[int, tuple[nixl_connect.Descriptor, bool]] = {}
# Handle both sync and async contexts
try:
asyncio.get_running_loop() # Check if we're in async context
# If we're in an async context, we need to run the connection creation in a separate thread to avoid blocking the event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
connection = pool.submit(
asyncio.run, self.connector._create_connection()
).result(timeout=10)
except RuntimeError:
# No running loop - safe to use asyncio.run()
connection = asyncio.run(self.connector._create_connection())
connection = run_async(self.connector._create_connection)
# Create descriptor for our allocated tensor
for _ in range(max_items):
encodings_tensor = torch.zeros(
......
......@@ -19,7 +19,7 @@ import binascii
import logging
import os
from io import BytesIO
from typing import Any, Dict, Final, List, Optional
from typing import Any, Dict, Final, List
from urllib.parse import urlparse
import httpx
......@@ -28,6 +28,7 @@ from PIL import Image
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
from .http_client import get_http_client
......@@ -43,11 +44,35 @@ class ImageLoader:
CACHE_SIZE_MAXIMUM = int(os.environ.get("DYN_MM_IMAGE_CACHE_SIZE", "8"))
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
self,
cache_size: int = CACHE_SIZE_MAXIMUM,
http_timeout: float = 30.0,
enable_frontend_decoding: bool = False,
):
"""
Initialize the ImageLoader with caching, HTTP settings, and optional NIXL config for
receiving frontend decoding.
Args:
cache_size: Maximum number of images to store in the in-memory LRU cache.
Defaults to CACHE_SIZE_MAXIMUM.
http_timeout: Timeout in seconds for HTTP requests when fetching remote images.
Defaults to 30.0 seconds.
enable_frontend_decoding: If True, enables NIXL RDMA for transferring
decoded images directly from frontend memory, bypassing standard
network transport. Defaults to False.
"""
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
self._enable_frontend_decoding = enable_frontend_decoding
# Create and initialize the NIXL connector up front, but only when frontend decoding is enabled
self._nixl_connector = None
if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector()
run_async(
self._nixl_connector.initialize
) # Synchronously wait for async init
@_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image:
......@@ -137,8 +162,6 @@ class ImageLoader:
async def load_image_batch(
self,
image_mm_items: List[Dict[str, Any]],
enable_frontend_decoding: bool = False,
nixl_connector: Optional["nixl_connect.Connector"] = None,
) -> List[Any]:
"""
Load a batch of images from multimodal data items.
......@@ -149,8 +172,6 @@ class ImageLoader:
Args:
image_mm_items: List of multimodal data items for images
enable_frontend_decoding: If True, enables NIXL RDMA for decoded images
nixl_connector: NIXL connector for frontend decoding (required if enable_frontend_decoding=True)
Returns:
List of loaded image data
......@@ -168,19 +189,10 @@ class ImageLoader:
image_futures.append(self.load_image(url))
logger.debug(f"Preparing to load image from URL: {url[:80]}...")
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if enable_frontend_decoding:
if nixl_connector is None:
logger.error(
"Frontend decoding enabled but nixl_connector not provided. "
"Caller must pass an initialized NIXL connector."
)
raise ValueError(
"nixl_connector required when enable_frontend_decoding=True"
)
if self._enable_frontend_decoding:
metadata = item[DECODED_VARIANT_KEY]
image_futures.append(
read_decoded_media_via_nixl(nixl_connector, metadata)
read_decoded_media_via_nixl(self._nixl_connector, metadata)
)
else:
logger.error(
......
......@@ -6,8 +6,9 @@ Common runtime utilities shared across Dynamo engine backends.
Provides:
- parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings
- graceful_shutdown: Shutdown DistributedRuntime with optional event signaling
- create_runtime: Create DistributedRuntime.
- run_async: Helper to run async functions in non-async functions that
may be run in either sync or async context.
"""
import asyncio
......@@ -70,3 +71,28 @@ def create_runtime(
runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats)
return runtime, loop
def run_async(func, *args, **kwargs):
    """Run an async function to completion from synchronous code.

    Works whether or not the calling thread already has a running event
    loop. In both cases this call blocks the calling thread until ``func``
    has finished.

    Args:
        func: An async function to execute.
        *args: Positional arguments to pass to the function.
        **kwargs: Keyword arguments to pass to the function.

    Returns:
        The result of the async function.

    Raises:
        Exception: Whatever ``func`` raises is propagated to the caller.
    """
    # asyncio.get_running_loop() raises RuntimeError when this thread has no
    # running loop. Only record the outcome here: running 'func' inside the
    # 'except' block would attach the probe's RuntimeError as __context__ to
    # every error raised by 'func', producing misleading chained tracebacks.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        in_async_context = False
    else:
        in_async_context = True

    if not in_async_context:
        # No running loop in this thread - safe to use asyncio.run() directly.
        return asyncio.run(func(*args, **kwargs))

    # asyncio.run() cannot be called while a loop is running in this thread,
    # so drive the coroutine on a fresh loop in a worker thread. Note that
    # .result() still blocks the calling thread until the coroutine is done.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, func(*args, **kwargs)).result()
......@@ -23,7 +23,6 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
......@@ -352,14 +351,9 @@ class BaseWorkerHandler(ABC):
self.generate_endpoint = generate_endpoint
self.config = config
self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len
self.enable_multimodal = enable_multimodal
self.enable_frontend_decoding = enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
self._nixl_connector: nixl_connect.Connector | None = None
self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
......@@ -367,6 +361,10 @@ class BaseWorkerHandler(ABC):
# Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock()
self.image_loader = ImageLoader(
enable_frontend_decoding=enable_frontend_decoding
)
self.use_vllm_tokenizer = use_vllm_tokenizer
self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config)
......@@ -1014,18 +1012,9 @@ class BaseWorkerHandler(ABC):
mm_map = request["multi_modal_data"]
vllm_mm_data = {}
# Lazy-init NIXL connector only when frontend decoding is enabled
if self.enable_frontend_decoding:
async with self._nixl_connector_lock:
if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
# Process image_url entries
images = await self.image_loader.load_image_batch(
mm_map.get(IMAGE_URL_KEY, []),
enable_frontend_decoding=self.enable_frontend_decoding,
nixl_connector=self._nixl_connector,
)
if images:
......
......@@ -4,14 +4,12 @@
import copy
import logging
import uuid
from collections import defaultdict
from typing import Any
import torch
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
......@@ -34,7 +32,7 @@ from ..multimodal_utils import (
vLLMMultimodalRequest,
)
from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
from ..multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
logger = logging.getLogger(__name__)
......@@ -71,17 +69,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
)
self.config = config
self.encode_worker_client = encode_worker_client
self.decode_worker_client = decode_worker_client
self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int(
config.multimodal_embedding_cache_capacity_gb * 1024**3
)
self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes
)
# Initialize multimodal-specific components
logger.info("Multimodal PD Worker startup started.")
......@@ -91,12 +80,13 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
else:
self.EMBEDDINGS_DTYPE = torch.float16
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
# Note: This is synchronous initialization, async initialization happens in async_init
self._connector: connect.Connector | None = (
None # Will be initialized in async_init
)
# The embedding loader consists of two main components:
# 1) A remote encode worker client and matching embedding receiver,
# which can request remote encoding and handle the transfer of embeddings
# from the encode worker to this prefill worker.
# 2) A local embedding cache manager, which stores previously fetched embeddings
# and is used to determine whether remote encoding is necessary for a given mm data item.
self.encode_worker_client = encode_worker_client
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver = LocalEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
......@@ -109,13 +99,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
raise ValueError(
f"Invalid embedding transfer mode: {config.embedding_transfer_mode}"
)
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int(
config.multimodal_embedding_cache_capacity_gb * 1024**3
)
self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
capacity_bytes
)
self.embedding_loader = MultiModalEmbeddingLoader(
encode_worker_client=self.encode_worker_client, # type: ignore
receiver=self.embedding_receiver,
embedding_cache_manager=self.embedding_cache_manager,
)
logger.info("Multimodal PD Worker has been initialized")
async def async_init(self, runtime: DistributedRuntime):
"""Async initialization for connector that requires async setup"""
# Initialize the connector asynchronously
self._connector = connect.Connector()
logger.info("Multimodal PD Worker async initialization completed.")
def _parse_frontend_request(
......@@ -164,17 +165,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
Returns an empty dict when no encode worker is configured or no images
are present.
"""
if not self.encode_worker_client or not image_urls:
return defaultdict(list)
return await load_multimodal_embeddings(
self.encode_worker_client, # type: ignore[arg-type]
return await self.embedding_loader.load_multimodal_embeddings(
image_urls,
request_id,
self.embedding_receiver,
model=self.config.model,
embeddings_dtype=self.EMBEDDINGS_DTYPE,
cache=self.embedding_cache_manager,
context=context,
)
......@@ -212,10 +208,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if not value:
del multi_modal_data[key]
else:
# [gluo FIXME] should be mindful to default dict, move this evaluation logic to here
# so that we don't accidentally add empty keys to the dict which causes vLLM misbehavior
logger.debug(
f"Prepared multimodal data size: {len(multi_modal_data[key])}"
f"Prepared multimodal data key {key}, number of items: {len(multi_modal_data[key])}"
)
logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))
......
......@@ -19,7 +19,7 @@ from dynamo.vllm.multimodal_utils.model import (
construct_mm_data,
load_vision_model,
)
from dynamo.vllm.multimodal_utils.prefill_worker_utils import load_multimodal_embeddings
from dynamo.vllm.multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
......@@ -48,5 +48,5 @@ __all__ = [
"MultiModalRequest",
"MyRequestOutput",
"vLLMMultimodalRequest",
"load_multimodal_embeddings",
"MultiModalEmbeddingLoader",
]
......@@ -4,7 +4,6 @@
import asyncio
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List
import torch
......@@ -92,12 +91,12 @@ def _accumulate_embeddings(
image_grid_thw=image_grid_thw,
)
if "image" not in multi_modal_data:
multi_modal_data["image"] = mm_data["image"]
return
if isinstance(mm_data["image"], dict):
# Qwen-VL style: dict with image_embeds + image_grid_thw tensors
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
# [gluo FIXME] need to understand how Qwen consumes multi-image embeddings
multi_modal_data["image"]["image_embeds"] = torch.cat(
(
multi_modal_data["image"]["image_embeds"],
......@@ -110,14 +109,14 @@ def _accumulate_embeddings(
mm_data["image"]["image_grid_thw"],
)
)
else:
# [gluo FIXME] embedding with multiple images?
if multi_modal_data["image"] == []:
multi_modal_data["image"] = mm_data["image"]
else:
elif isinstance(mm_data["image"], torch.Tensor):
multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"])
)
else:
raise ValueError(
f"Unexpected image data format from construct_mm_data: {type(mm_data['image'])}"
)
def _ensure_owned_tensors(multi_modal_data: Dict[str, Any]) -> None:
......@@ -287,17 +286,28 @@ async def _fetch_embeddings(
# ── Public API (single entry point) ─────────────────────────────────
async def load_multimodal_embeddings(
class MultiModalEmbeddingLoader:
"""Helper class for requesting remote encoding and receiving embeddings."""
def __init__(
self,
encode_worker_client: Client,
receiver: AbstractEmbeddingReceiver,
embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None,
):
self._encode_worker_client = encode_worker_client
self._receiver = receiver
self._embedding_cache_manager = embedding_cache_manager
async def load_multimodal_embeddings(
self,
image_urls: list[str],
request_id: str,
receiver: AbstractEmbeddingReceiver,
*,
model: str,
embeddings_dtype: torch.dtype,
cache: MultimodalEmbeddingCacheManager | None = None,
context=None,
) -> Dict[str, Any]:
) -> Dict[str, Any]:
"""Fetch embeddings and build engine-ready ``multi_modal_data``.
Full pipeline:
......@@ -305,16 +315,19 @@ async def load_multimodal_embeddings(
Returns a dict suitable for passing to ``TokensPrompt(multi_modal_data=...)``.
"""
if not self._encode_worker_client or not image_urls:
return {}
groups, pending = await _fetch_embeddings(
encode_worker_client,
self._encode_worker_client,
image_urls,
request_id,
receiver,
cache=cache,
self._receiver,
cache=self._embedding_cache_manager,
context=context,
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
multi_modal_data: Dict[str, Any] = {}
with time_and_log_code_section(
f"[PREFILL] request: {request_id} accumulate embeddings"
):
......
......@@ -182,7 +182,7 @@ class TestLoadMultimodalData:
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) # type: ignore
with patch.object(
mod,
handler.embedding_loader,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=fake_mm_data,
......@@ -192,24 +192,6 @@ class TestLoadMultimodalData:
mock_load.assert_awaited_once()
assert result is fake_mm_data
@pytest.mark.asyncio
async def test_passes_cache_to_load_multimodal_embeddings(self):
"""With cache enabled -> passes cache manager kwarg."""
mock_client = MagicMock()
config = _make_config(multimodal_embedding_cache_capacity_gb=1.0)
handler = _make_handler(config=config, encode_worker_client=mock_client)
with patch.object(
mod,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=defaultdict(list),
) as mock_load:
await handler._load_multimodal_data(["http://img.png"], "req-1")
mock_load.assert_awaited_once()
assert mock_load.call_args.kwargs["cache"] is handler.embedding_cache_manager
@pytest.mark.asyncio
async def test_passes_model_and_dtype(self):
"""Model name and embeddings dtype are forwarded."""
......@@ -217,7 +199,7 @@ class TestLoadMultimodalData:
handler = _make_handler(encode_worker_client=mock_client)
with patch.object(
mod,
handler.embedding_loader,
"load_multimodal_embeddings",
new_callable=AsyncMock,
return_value=defaultdict(list),
......
......@@ -26,7 +26,7 @@ MODEL = "test-model"
DTYPE = torch.float16
class TestLoadMultimodalEmbeddings:
class TestMultimodalEmbeddingsLoader:
@pytest.mark.asyncio
async def test_all_cached(self):
"""All URLs cached -> no encode worker call, returns accumulated mm_data."""
......@@ -42,14 +42,12 @@ class TestLoadMultimodalEmbeddings:
"_fetch_from_encode_workers",
new_callable=AsyncMock,
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
mm_data = await embedding_loader.load_multimodal_embeddings(
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_not_awaited()
......@@ -73,14 +71,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
mm_data = await embedding_loader.load_multimodal_embeddings(
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
......@@ -107,14 +103,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, None)
mm_data = await embedding_loader.load_multimodal_embeddings(
[url],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=None,
)
mock_fetch.assert_awaited_once()
......@@ -145,14 +139,12 @@ class TestLoadMultimodalEmbeddings:
new_callable=AsyncMock,
return_value=([fake_group], None),
) as mock_fetch:
mm_data = await mod.load_multimodal_embeddings(
AsyncMock(),
embedding_loader = mod.MultiModalEmbeddingLoader(AsyncMock(), None, cache)
mm_data = await embedding_loader.load_multimodal_embeddings(
[url_cached, url_miss],
"req-1",
None,
model=MODEL,
embeddings_dtype=DTYPE,
cache=cache,
)
mock_fetch.assert_awaited_once()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment