Unverified Commit 0b7e1271 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix: [vLLM multimodal] launch image loading in parallel (#5444)


Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 7ebd5f82
......@@ -757,6 +757,48 @@ class BaseWorkerHandler(ABC):
return prompt, sequence_length, embeddings_tensor
async def _load_image_batch(
    self, image_mm_items: list[Dict[str, Any]]
) -> list[Any]:
    """
    Load a batch of images concurrently from multimodal data items.

    Items that are dicts containing ``URL_VARIANT_KEY`` are passed to
    ``self.image_loader.load_image`` and awaited together through
    ``asyncio.gather``. Dict items containing ``DECODED_VARIANT_KEY`` only
    produce a warning; any other item is ignored.

    Args:
        image_mm_items: List of multimodal data items for images

    Returns:
        List of loaded image data, one entry per URL item, in the order
        the URL items appear in ``image_mm_items``.

    Raises:
        Exception: If any image fails to load. The message contains one
            line per failed URL.
    """
    # Keep the URL of every scheduled load next to its awaitable so that a
    # result can always be matched to the URL that produced it, even when
    # some input items were skipped and the input indices no longer line up.
    scheduled_urls = []
    pending_loads = []
    for item in image_mm_items:
        if not isinstance(item, dict):
            continue
        if URL_VARIANT_KEY in item:
            url = item[URL_VARIANT_KEY]
            pending_loads.append(self.image_loader.load_image(url))
            scheduled_urls.append(url)
            logger.debug(f"Preparing to load image from URL: {url[:80]}...")
        elif DECODED_VARIANT_KEY in item:
            logger.warning(
                "Decoded multimodal data not yet supported in standard worker"
            )

    # return_exceptions=True lets every load run to completion so that all
    # failures can be reported at once instead of only the first one.
    results = await asyncio.gather(*pending_loads, return_exceptions=True)

    loaded_images = []
    failure_messages = []
    for url, result in zip(scheduled_urls, results):
        # gather() may hand back BaseException instances (for example
        # asyncio.CancelledError), which are not Exception subclasses;
        # they must never be mistaken for loaded image data.
        if isinstance(result, BaseException):
            message = f"Failed to load image from {url[:80]}...: {result}"
            logger.error(message)
            failure_messages.append(message)
            continue
        loaded_images.append(result)

    if failure_messages:
        # One line per failure, each terminated by a newline.
        raise Exception("\n".join(failure_messages) + "\n")
    return loaded_images
async def _extract_multimodal_data(
self, request: Dict[str, Any]
) -> Dict[str, Any] | None:
......@@ -777,25 +819,7 @@ class BaseWorkerHandler(ABC):
vllm_mm_data = {}
# Process image_url entries
images = []
for item in mm_map.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and URL_VARIANT_KEY in item:
url = item[URL_VARIANT_KEY]
try:
# ImageLoader supports both data: and http(s): URLs with caching
image = await self.image_loader.load_image(url)
images.append(image)
logger.debug(f"Loaded image from URL: {url[:80]}...")
except Exception:
logger.exception(f"Failed to load image from {url[:80]}...")
raise
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
# Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
# Will contain NIXL metadata for direct memory access
# TODO: Implement NIXL read when PRs merge
logger.warning(
"Decoded multimodal data not yet supported in standard worker"
)
images = await self._load_image_batch(mm_map.get(IMAGE_URL_KEY, []))
if images:
# vLLM expects single image or list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment