"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "98fcba1575da8d80e47d0540898015d2906d4720"
Unverified Commit ac50dccf authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: batch process images in encode worker. Add qwen3 to supported models (#6021)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 3e417022
......@@ -6,9 +6,11 @@ import logging
import os
import shutil
import time
from dataclasses import dataclass
from typing import AsyncGenerator, AsyncIterator
import safetensors
import torch
from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
......@@ -23,11 +25,12 @@ from ..multimodal_utils import (
VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse,
encode_image_embeddings,
get_embedding_hash,
get_encoder_components,
load_vision_model,
vLLMMultimodalRequest,
)
from ..multimodal_utils.embedding_cache import EmbeddingCache
from ..multimodal_utils.model import is_qwen_vl_model
logger = logging.getLogger(__name__)
......@@ -49,6 +52,13 @@ CACHE_SIZE_MAXIMUM = 8
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
@dataclass
class EmbeddingItem:
key: str
image_grid_thw: list
embeddings_cpu: torch.Tensor
class EncodeWorkerHandler:
def __init__(
self,
......@@ -74,7 +84,7 @@ class EncodeWorkerHandler:
self._accumulated_time = 0.0
self._processed_requests = 0
self.readables = []
self.cached_embeddings = {}
self.embedding_cache = EmbeddingCache()
def cleanup(self):
pass
......@@ -112,35 +122,56 @@ class EncodeWorkerHandler:
try:
time_start = time.perf_counter()
# Before batch process images, check cache first
need_encode_indexes = []
embedding_lists = [None] * len(request.multimodal_inputs)
for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image_url = request.multimodal_inputs[idx].multimodal_input.image_url
# see if we have local cache
if image_url in self.cached_embeddings:
(
embedding_key,
image_grid_thw,
embeddings_shape,
) = self.cached_embeddings[image_url]
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
embedding_key = self.embedding_cache.generate_hash_key(image_url)
if self.embedding_cache.has_key(embedding_key):
(image_grid_thw, embeddings_cpu) = self.embedding_cache.get(
embedding_key
)
embedding_lists[idx] = EmbeddingItem(
embedding_key, image_grid_thw, embeddings_cpu
)
# compute
else:
# keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key))
# Load and generate image tensors
image_futures = []
image_to_load = []
for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url
image_futures.append(self.image_loader.load_image(url))
image_to_load.append(url)
results = await asyncio.gather(*image_futures, return_exceptions=True)
loaded_images = []
collective_exceptions = ""
for i, result in enumerate(results):
if isinstance(result, Exception):
url = image_to_load[i]
logger.error(f"Failed to load image from {url[:80]}...: {result}")
collective_exceptions += (
f"Failed to load image from {url[:80]}...: {result}\n"
)
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = embeddings_shape
continue
loaded_images.append(result)
if collective_exceptions:
raise ValueError(
f"Errors occurred during image loading:\n{collective_exceptions}"
)
image = await self.image_loader.load_image(image_url)
logger.debug(
f"Processing image {image_url} for request: {{ id: {request_id} }}"
if loaded_images:
image_embeds = self.image_processor(
images=loaded_images, return_tensors="pt"
)
image_embeds = self.image_processor(images=image, return_tensors="pt")
# Encode the image embeddings using model-specific encoder
embeddings = await asyncio.to_thread(
......@@ -151,46 +182,76 @@ class EncodeWorkerHandler:
projector=self.projector,
)
# [gluo FIXME] This is specific to qwen vision processing..
# Split concatenated embeddings for each image item.
if is_qwen_vl_model(self.model):
merge_size = self.vision_encoder.spatial_merge_size
sizes = (
image_embeds["image_grid_thw"].prod(-1)
// merge_size
// merge_size
).tolist()
splitted_embeddings = embeddings.cpu().squeeze(0).split(sizes)
logger.debug(
f"Splitted embeddings lengths: {[e.shape for e in splitted_embeddings]}"
)
else:
# Validated on llava (NOTE need to double check on other models) that the
# embeddings already has batch dimension for images, so we can directly
# split by batch dimension
logger.debug(f"image embedding shape: {embeddings.shape}")
splitted_embeddings = embeddings.cpu()
image_grid_thw = (
image_embeds["image_grid_thw"].tolist()
if "image_grid_thw" in image_embeds
else None
)
logger.debug(
f"Pixel values stats: mean={image_embeds['pixel_values'].mean().item()}, std={image_embeds['pixel_values'].std().item()}, min={image_embeds['pixel_values'].min().item()}, max={image_embeds['pixel_values'].max().item()}"
)
# Move embeddings to CPU for NIXL transfer to avoid UCX/InfiniBand issues
embeddings_cpu = embeddings.cpu()
# fill in the embedding_lists with new computed embeddings and cache them
for split_idx, (list_idx, key) in enumerate(need_encode_indexes):
embedding_lists[list_idx] = EmbeddingItem(
key,
[image_grid_thw[split_idx]] if image_grid_thw else None,
splitted_embeddings[split_idx].unsqueeze(0),
)
# Cache the computed value for future use
self.embedding_cache.set(
embedding_lists[list_idx].key,
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings_cpu,
),
)
request.multimodal_inputs[idx].image_grid_thw = image_grid_thw
for idx, embedding_item in enumerate(embedding_lists):
# Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[
idx
].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple(
embeddings.shape
embedding_item.embeddings_cpu.shape
)
# Prepare transfer
if TRANSFER_LOCAL:
embedding_key = get_embedding_hash(image_url)
logger.debug(
f"ENCODER: saving local safetensors file with key {embedding_key}, {embeddings_cpu.numel()} * {embeddings_cpu.element_size()} bytes"
f"ENCODER: saving local safetensors file with key {embedding_item.key}, {embedding_item.embeddings_cpu.numel()} * {embedding_item.embeddings_cpu.element_size()} bytes"
)
tensors = {"ec_cache": embeddings_cpu}
tensors = {"ec_cache": embedding_item.embeddings_cpu}
safetensors.torch.save_file(
tensors, f"/tmp/encoder_cache.{embedding_key}.safetensors"
tensors,
f"/tmp/encoder_cache.{embedding_item.key}.safetensors",
)
# [gluo FIXME] need mechanism to clean up local files
request.multimodal_inputs[
idx
].serialized_request = (
f"/tmp/encoder_cache.{embedding_key}.safetensors"
)
self.cached_embeddings[image_url] = (
embedding_key,
request.multimodal_inputs[idx].image_grid_thw,
request.multimodal_inputs[idx].embeddings_shape,
f"/tmp/encoder_cache.{embedding_item.key}.safetensors"
)
else:
# [gluo FIXME] nixl_connector path needs to be update to handle multiple embeddings
descriptor = connect.Descriptor(embeddings_cpu)
descriptor = connect.Descriptor(embedding_item.embeddings_cpu)
self.readables.append(
await self._connector.create_readable(descriptor)
)
......@@ -198,9 +259,6 @@ class EncodeWorkerHandler:
-1
].metadata()
# Clear the image URL as hint that the image is passed as embeddings.
request.multimodal_inputs[idx].multimodal_input.image_url = None
logger.debug(f"Request: {request.model_dump_json()}")
time_end = time.perf_counter()
......
......@@ -51,13 +51,32 @@ class PreprocessedHandler(ProcessMixIn):
pd_worker_client: Client,
):
self.encode_worker_client = encode_worker_client
self.encode_worker_count = 0
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.stop = False
self._worker_count_task = asyncio.create_task(
self._update_encode_worker_count()
)
async def _update_encode_worker_count(self):
"""
Periodically updates the count of available encode workers.
"""
while self.stop is False:
try:
self.encode_worker_count = len(self.encode_worker_client.instance_ids())
logger.debug(f"Updated encode worker count: {self.encode_worker_count}")
except Exception as e:
logger.error(f"Failed to update encode worker count: {e}")
await asyncio.sleep(1) # Update every 1 second
def cleanup(self):
pass
self.stop = True
if hasattr(self, "_worker_count_task"):
self._worker_count_task.cancel()
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
......@@ -85,8 +104,17 @@ class PreprocessedHandler(ProcessMixIn):
multimodal_inputs=[],
)
# [gluo WIP] experiment with batching..
ENCODE_BATCH_SIZE = 1
# [gluo WIP] batching helps for encoding step to fully utilize GPU,
# should handle dispatch in a more intelligent way, i.e. splitting
# jobs based on availability of encode worker, rather than fixed mm
# mm item size per request. Also need to consider encoding load and
# balancing it between encoders.
if self.encode_worker_count == 0:
raise RuntimeError(
"No encode workers available to process multimodal input"
)
total_items = sum(len(urls) for urls in multimodal_inputs.values())
encode_batch_size = max(1, total_items // self.encode_worker_count)
encode_res_gen = []
for mm_type, urls in multimodal_inputs.items():
for url in urls:
......@@ -101,7 +129,7 @@ class PreprocessedHandler(ProcessMixIn):
MultiModalGroup(multimodal_input=multimodal_input)
)
if len(encode_request.multimodal_inputs) >= ENCODE_BATCH_SIZE:
if len(encode_request.multimodal_inputs) >= encode_batch_size:
# model_dump_json() serializes the request to JSON string
# This API could accept Pydantic class, but SamplingParams
# in vLLMMultimodalRequest is not a Pydantic class and will
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
class EmbeddingCache:
def __init__(self):
# Initialize an empty dictionary to store key-value pairs
self.cache = {}
def generate_hash_key(self, *args):
"""
Generate a hashable key based on the provided arguments.
Args:
*args: A variable number of arguments to generate the key.
Returns:
A string representing the hashable key.
"""
key = hashlib.sha256()
for arg in args:
key.update(str(arg).encode("utf-8"))
return key.hexdigest()
def has_key(self, key):
"""
Check if a key exists in the cache.
Args:
key: The key to check.
Returns:
True if the key exists in the cache, False otherwise.
"""
return key in self.cache
def set(self, key, value):
"""
Store a key-value pair in the cache.
Args:
key: The key to store the value under.
value: The value to store, expected to be a tuple.
"""
self.cache[key] = value
def get(self, key):
"""
Retrieve the value associated with a key.
Args:
key: The key to look up.
Returns:
The value (tuple) associated with the key, or None if the key is not found.
"""
return self.cache.get(key)
......@@ -38,6 +38,7 @@ class SupportedModels:
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
QWEN_3_VL_30B_A3B_FP8 = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
......@@ -116,6 +117,7 @@ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_5_VL_3B,
SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_32B,
SupportedModels.QWEN_3_VL_30B_A3B_FP8,
]
......@@ -145,6 +147,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0",
}
)
# [NOTE] For vLLM pre-0.15.0, see https://github.com/vllm-project/vllm/pull/32605 for enhancement after 0.15.0
#
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# Model needs the class method get_language_model_spec to be defined for this to work.
......@@ -157,6 +161,8 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
Qwen2_5_VLForConditionalGeneration,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
@classmethod
def get_language_model_spec(cls):
......@@ -169,6 +175,14 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
get_language_model_spec
)
@classmethod
def get_language_model_spec(cls):
return (Qwen3ForCausalLM, "language_model")
Qwen3VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM
vllm_model = LLM(
model=model_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment