Unverified Commit ac50dccf authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: batch process images in encode worker. Add qwen3 to supported models (#6021)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 3e417022
...@@ -6,9 +6,11 @@ import logging ...@@ -6,9 +6,11 @@ import logging
import os import os
import shutil import shutil
import time import time
from dataclasses import dataclass
from typing import AsyncGenerator, AsyncIterator from typing import AsyncGenerator, AsyncIterator
import safetensors import safetensors
import torch
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
...@@ -23,11 +25,12 @@ from ..multimodal_utils import ( ...@@ -23,11 +25,12 @@ from ..multimodal_utils import (
VLLMNativeEncoderRequest, VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse, VLLMNativeEncoderResponse,
encode_image_embeddings, encode_image_embeddings,
get_embedding_hash,
get_encoder_components, get_encoder_components,
load_vision_model, load_vision_model,
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
from ..multimodal_utils.embedding_cache import EmbeddingCache
from ..multimodal_utils.model import is_qwen_vl_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -49,6 +52,13 @@ CACHE_SIZE_MAXIMUM = 8 ...@@ -49,6 +52,13 @@ CACHE_SIZE_MAXIMUM = 8
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
@dataclass
class EmbeddingItem:
key: str
image_grid_thw: list
embeddings_cpu: torch.Tensor
class EncodeWorkerHandler: class EncodeWorkerHandler:
def __init__( def __init__(
self, self,
...@@ -74,7 +84,7 @@ class EncodeWorkerHandler: ...@@ -74,7 +84,7 @@ class EncodeWorkerHandler:
self._accumulated_time = 0.0 self._accumulated_time = 0.0
self._processed_requests = 0 self._processed_requests = 0
self.readables = [] self.readables = []
self.cached_embeddings = {} self.embedding_cache = EmbeddingCache()
def cleanup(self): def cleanup(self):
pass pass
...@@ -112,35 +122,56 @@ class EncodeWorkerHandler: ...@@ -112,35 +122,56 @@ class EncodeWorkerHandler:
try: try:
time_start = time.perf_counter() time_start = time.perf_counter()
# Before batch process images, check cache first
need_encode_indexes = []
embedding_lists = [None] * len(request.multimodal_inputs)
for idx in range(len(request.multimodal_inputs)): for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url: if not request.multimodal_inputs[idx].multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.") raise ValueError("image_url is required for the encode worker.")
image_url = request.multimodal_inputs[idx].multimodal_input.image_url image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache # see if we have local cache
if image_url in self.cached_embeddings: embedding_key = self.embedding_cache.generate_hash_key(image_url)
( if self.embedding_cache.has_key(embedding_key):
embedding_key, (image_grid_thw, embeddings_cpu) = self.embedding_cache.get(
image_grid_thw, embedding_key
embeddings_shape, )
) = self.cached_embeddings[image_url] embedding_lists[idx] = EmbeddingItem(
# [gluo FIXME] need mechanism to clean up local files embedding_key, image_grid_thw, embeddings_cpu
request.multimodal_inputs[ )
idx # compute
].serialized_request = ( else:
f"/tmp/encoder_cache.{embedding_key}.safetensors" # keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key))
# Load and generate image tensors
image_futures = []
image_to_load = []
for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url
image_futures.append(self.image_loader.load_image(url))
image_to_load.append(url)
results = await asyncio.gather(*image_futures, return_exceptions=True)
loaded_images = []
collective_exceptions = ""
for i, result in enumerate(results):
if isinstance(result, Exception):
url = image_to_load[i]
logger.error(f"Failed to load image from {url[:80]}...: {result}")
collective_exceptions += (
f"Failed to load image from {url[:80]}...: {result}\n"
) )
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = embeddings_shape
continue continue
loaded_images.append(result)
if collective_exceptions:
raise ValueError(
f"Errors occurred during image loading:\n{collective_exceptions}"
)
image = await self.image_loader.load_image(image_url) if loaded_images:
image_embeds = self.image_processor(
logger.debug( images=loaded_images, return_tensors="pt"
f"Processing image {image_url} for request: {{ id: {request_id} }}"
) )
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Encode the image embeddings using model-specific encoder # Encode the image embeddings using model-specific encoder
embeddings = await asyncio.to_thread( embeddings = await asyncio.to_thread(
...@@ -151,46 +182,76 @@ class EncodeWorkerHandler: ...@@ -151,46 +182,76 @@ class EncodeWorkerHandler:
projector=self.projector, projector=self.projector,
) )
# [gluo FIXME] This is specific to qwen vision processing..
# Split concatenated embeddings for each image item.
if is_qwen_vl_model(self.model):
merge_size = self.vision_encoder.spatial_merge_size
sizes = (
image_embeds["image_grid_thw"].prod(-1)
// merge_size
// merge_size
).tolist()
splitted_embeddings = embeddings.cpu().squeeze(0).split(sizes)
logger.debug(
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}"
)
else:
# Validated on llava (NOTE need to double check on other models) that the
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings.cpu()
image_grid_thw = ( image_grid_thw = (
image_embeds["image_grid_thw"].tolist() image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds if "image_grid_thw" in image_embeds
else None else None
) )
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
# Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues # fill in the embedding_lists with new computed embeddings and cache them
embeddings_cpu = embeddings.cpu() for split_idx, (list_idx, key) in enumerate(need_encode_indexes):
embedding_lists[list_idx] = EmbeddingItem(
key,
[image_grid_thw[split_idx]] if image_grid_thw else None,
splitted_embeddings[split_idx].unsqueeze(0),
)
# Cache the computed value for future use
self.embedding_cache.set(
embedding_lists[list_idx].key,
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu,
),
)
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw for idx, embedding_item in enumerate(embedding_lists):
# Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[
idx
].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple( request.multimodal_inputs[idx].embeddings_shape = tuple(
embeddings.shape embedding_item.embeddings_cpu.shape
) )
# Prepare transfer
if TRANSFER_LOCAL: if TRANSFER_LOCAL:
embedding_key = get_embedding_hash(image_url)
logger.debug( logger.debug(
f"ENCODER: saving local safetensors file with key {embedding_key}, {embeddings_cpu.numel()} * {embeddings_cpu.element_size()} bytes" f"ENCODER: saving local safetensors file with key {embedding_item.key}, {embedding_item.embeddings_cpu.numel()} * {embedding_item.embeddings_cpu.element_size()} bytes"
) )
tensors = {"ec_cache": embeddings_cpu} tensors = {"ec_cache": embedding_item.embeddings_cpu}
safetensors.torch.save_file( safetensors.torch.save_file(
tensors, f"/tmp/encoder_cache.{embedding_key}.safetensors" tensors,
f"/tmp/encoder_cache.{embedding_item.key}.safetensors",
) )
# [gluo FIXME] need mechanism to clean up local files # [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[ request.multimodal_inputs[
idx idx
].serialized_request = ( ].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors" f"/tmp/encoder_cache.{embedding_item.key}.safetensors"
)
self.cached_embeddings[image_url] = (
embedding_key,
request.multimodal_inputs[idx].image_grid_thw,
request.multimodal_inputs[idx].embeddings_shape,
) )
else: else:
# [gluo FIXME] nixl_connector path needs to be update to handle multiple embeddings descriptor = connect.Descriptor(embedding_item.embeddings_cpu)
descriptor = connect.Descriptor(embeddings_cpu)
self.readables.append( self.readables.append(
await self._connector.create_readable(descriptor) await self._connector.create_readable(descriptor)
) )
...@@ -198,9 +259,6 @@ class EncodeWorkerHandler: ...@@ -198,9 +259,6 @@ class EncodeWorkerHandler:
-1 -1
].metadata() ].metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_inputs[idx].multimodal_input.image_url = None
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
time_end = time.perf_counter() time_end = time.perf_counter()
......
...@@ -51,13 +51,32 @@ class PreprocessedHandler(ProcessMixIn): ...@@ -51,13 +51,32 @@ class PreprocessedHandler(ProcessMixIn):
pd_worker_client: Client, pd_worker_client: Client,
): ):
self.encode_worker_client = encode_worker_client self.encode_worker_client = encode_worker_client
self.encode_worker_count = 0
self.pd_worker_client = pd_worker_client self.pd_worker_client = pd_worker_client
self.engine_args = engine_args self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config() self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.stop = False
self._worker_count_task = asyncio.create_task(
self._update_encode_worker_count()
)
async def _update_encode_worker_count(self):
"""
Periodically updates the count of available encode workers.
"""
while self.stop is False:
try:
self.encode_worker_count = len(self.encode_worker_client.instance_ids())
logger.debug(f"Updated encode worker count: {self.encode_worker_count}")
except Exception as e:
logger.error(f"Failed to update encode worker count: {e}")
await asyncio.sleep(1) # Update every 1 second
def cleanup(self): def cleanup(self):
pass self.stop = True
if hasattr(self, "_worker_count_task"):
self._worker_count_task.cancel()
# Main method to parse the request and send the request to the vllm worker. # Main method to parse the request and send the request to the vllm worker.
async def _generate( async def _generate(
...@@ -85,8 +104,17 @@ class PreprocessedHandler(ProcessMixIn): ...@@ -85,8 +104,17 @@ class PreprocessedHandler(ProcessMixIn):
multimodal_inputs=[], multimodal_inputs=[],
) )
# [gluo WIP] experiment with batching.. # [gluo WIP] batching helps for encoding step to fully utilize GPU,
ENCODE_BATCH_SIZE = 1 # should handle dispatch in a more intelligent way, i.e. splitting
# jobs based on availability of encode worker, rather than fixed mm
# mm item size per request. Also need to consider encoding load and
# balancing it between encoders.
if self.encode_worker_count == 0:
raise RuntimeError(
"No encode workers available to process multimodal input"
)
total_items = sum(len(urls) for urls in multimodal_inputs.values())
encode_batch_size = max(1, total_items // self.encode_worker_count)
encode_res_gen = [] encode_res_gen = []
for mm_type, urls in multimodal_inputs.items(): for mm_type, urls in multimodal_inputs.items():
for url in urls: for url in urls:
...@@ -101,7 +129,7 @@ class PreprocessedHandler(ProcessMixIn): ...@@ -101,7 +129,7 @@ class PreprocessedHandler(ProcessMixIn):
MultiModalGroup(multimodal_input=multimodal_input) MultiModalGroup(multimodal_input=multimodal_input)
) )
if len(encode_request.multimodal_inputs) >= ENCODE_BATCH_SIZE: if len(encode_request.multimodal_inputs) >= encode_batch_size:
# model_dump_json() serializes the request to JSON string # model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams # This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will # in vLLMMultimodalRequest is not a Pydantic class and will
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
class EmbeddingCache:
def __init__(self):
# Initialize an empty dictionary to store key-value pairs
self.cache = {}
def generate_hash_key(self, *args):
"""
Generate a hashable key based on the provided arguments.
Args:
*args: A variable number of arguments to generate the key.
Returns:
A string representing the hashable key.
"""
key = hashlib.sha256()
for arg in args:
key.update(str(arg).encode("utf-8"))
return key.hexdigest()
def has_key(self, key):
"""
Check if a key exists in the cache.
Args:
key: The key to check.
Returns:
True if the key exists in the cache, False otherwise.
"""
return key in self.cache
def set(self, key, value):
"""
Store a key-value pair in the cache.
Args:
key: The key to store the value under.
value: The value to store, expected to be a tuple.
"""
self.cache[key] = value
def get(self, key):
"""
Retrieve the value associated with a key.
Args:
key: The key to look up.
Returns:
The value (tuple) associated with the key, or None if the key is not found.
"""
return self.cache.get(key)
...@@ -38,6 +38,7 @@ class SupportedModels: ...@@ -38,6 +38,7 @@ class SupportedModels:
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct" QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct" QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
...@@ -116,6 +117,7 @@ QWEN_VL_MODELS = [ ...@@ -116,6 +117,7 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_5_VL_3B, SupportedModels.QWEN_2_5_VL_3B,
SupportedModels.QWEN_2_5_VL_7B, SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_32B, SupportedModels.QWEN_2_5_VL_32B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8,
] ]
...@@ -145,6 +147,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -145,6 +147,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0", "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
} }
) )
# [NOTE] For vLLM pre-0.15.0, see https://github.com/vllm-project/vllm/pull/32605 for enhancement after 0.15.0
#
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage. # Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242. # Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# Model needs the class method get_language_model_spec to be defined for this to work. # Model needs the class method get_language_model_spec to be defined for this to work.
...@@ -157,6 +161,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -157,6 +161,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration,
) )
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
@classmethod @classmethod
def get_language_model_spec(cls): def get_language_model_spec(cls):
...@@ -169,6 +175,14 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -169,6 +175,14 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
get_language_model_spec get_language_model_spec
) )
@classmethod
def get_language_model_spec(cls):
return (Qwen3ForCausalLM, "language_model")
Qwen3VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM # Load only the vision model via vLLM
vllm_model = LLM( vllm_model = LLM(
model=model_id, model=model_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment