Unverified Commit d517fb80 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat(vllm): compute Qwen VL grid_thw from PIL images for P/D disagg (#7885)


Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent c09abf7d
...@@ -56,6 +56,10 @@ from .constants import DisaggregationMode, EmbeddingTransferMode ...@@ -56,6 +56,10 @@ from .constants import DisaggregationMode, EmbeddingTransferMode
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from .multimodal_utils.models.qwen import (
build_qwen_embedding_params,
load_qwen_grid_params,
)
from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader from .multimodal_utils.prefill_worker_utils import MultiModalEmbeddingLoader
# Multimodal data dictionary keys # Multimodal data dictionary keys
...@@ -1608,6 +1612,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1608,6 +1612,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
request_id, request_id,
) )
elif has_mm_data and request["multi_modal_data"].get(IMAGE_URL_KEY): elif has_mm_data and request["multi_modal_data"].get(IMAGE_URL_KEY):
# Guard is on IMAGE_URL_KEY (not just has_mm_data) so
# text-only requests pass through and video/audio fall
# through to re-download below (TODO: proper support).
msg = ( msg = (
"Decode worker received multimodal request without " "Decode worker received multimodal request without "
"prefill result" "prefill result"
...@@ -1818,6 +1825,20 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1818,6 +1825,20 @@ class PrefillWorkerHandler(BaseWorkerHandler):
encode_worker_client, encode_worker_client,
) )
# Cache Qwen VL grid parameters for computing image_grid_thw from
# PIL images in the P/D path (no separate encode worker).
if is_qwen_vl_model(config.model):
self._qwen_grid_params = load_qwen_grid_params(config.model)
if self._qwen_grid_params is None and self.embedding_loader is None:
logger.error(
"Qwen VL grid params failed to load and no encode worker "
"is configured. P/D multimodal requests will fail because "
"prefill cannot produce embedding_params for decode. "
"Use --route-to-encoder or ensure the model is cached."
)
else:
self._qwen_grid_params = None
async def generate(self, request, context): async def generate(self, request, context):
# Use context ID for request tracking and correlation with decode phase # Use context ID for request tracking and correlation with decode phase
request_id = context.id() request_id = context.id()
...@@ -1948,25 +1969,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1948,25 +1969,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def _build_embedding_params( def _build_embedding_params(
self, multi_modal_data: dict[str, Any] self, multi_modal_data: dict[str, Any]
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
""" if not is_qwen_vl_model(self.config.model):
Helper function to build mm embedding parameters to be consumed by the decode worker, typically return None
decode worker doesn't require any metadata for mm embedding as the content has been consumed by return build_qwen_embedding_params(multi_modal_data, self._qwen_grid_params)
prefill. However, especially found for Qwen models, vLLM's processor will try to expand image
tokens in the prompt which requires such a metadata to pass through the processor.
"""
embedding_params = {}
if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict
):
image_data = multi_modal_data["image"]
image_grid_thw = image_data.get("image_grid_thw")
image_embeds = image_data.get("image_embeds")
if image_grid_thw is not None:
embedding_params["image_grid_thw"] = (
image_grid_thw.tolist()
if isinstance(image_grid_thw, torch.Tensor)
else image_grid_thw
)
if image_embeds is not None:
embedding_params["embeddings_shape"] = list(image_embeds.shape)
return embedding_params if embedding_params else None
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from dataclasses import dataclass
from typing import Any, Dict
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class QwenGridParams:
"""Cached Qwen VL image processor parameters for grid_thw computation."""
patch_size: int
merge_size: int
factor: int
min_pixels: int
max_pixels: int
vision_hidden_dim: int
def load_qwen_grid_params(model_name: str) -> QwenGridParams | None:
"""Load Qwen VL grid parameters from model config.
Reads AutoImageProcessor and vision_config at init time so that
grid_thw can be computed from image dimensions alone (no GPU needed).
Returns None if loading fails (e.g. model not cached locally).
"""
try:
processor = AutoImageProcessor.from_pretrained(
model_name, trust_remote_code=True
)
vision_config = AutoConfig.from_pretrained(
model_name, trust_remote_code=True
).vision_config
patch_size: int = processor.patch_size
merge_size: int = processor.merge_size
factor = patch_size * merge_size
# Qwen2/2.5-VL use min_pixels/max_pixels directly.
# Qwen3-VL sets them to None and uses size.shortest_edge/longest_edge.
min_pixels: int = (
processor.min_pixels
if processor.min_pixels is not None
else processor.size.get("shortest_edge", factor)
)
max_pixels: int = (
processor.max_pixels
if processor.max_pixels is not None
else processor.size.get("longest_edge", factor * factor * 1280)
)
vision_hidden_dim: int = getattr(
vision_config, "out_hidden_size", vision_config.hidden_size
)
return QwenGridParams(
patch_size=patch_size,
merge_size=merge_size,
factor=factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
vision_hidden_dim=vision_hidden_dim,
)
except (OSError, ValueError) as exc:
logger.warning(
"Failed to load Qwen VL image processor for %s: %s. "
"P/D disaggregation without encode worker will not "
"produce embedding_params for decode.",
model_name,
exc,
exc_info=True,
)
return None
def _compute_qwen_grid_thw(
image_data: Any,
params: QwenGridParams,
) -> tuple[list[list[int]], list[int]] | tuple[None, None]:
"""Compute image_grid_thw and embeddings_shape from PIL images.
Uses smart_resize with cached processor parameters — the same
resize logic that vLLM/transformers uses internally, but without
the overhead of full image preprocessing (~0.4us vs ~5ms).
Args:
image_data: Single PIL.Image.Image or list of them.
params: Cached grid parameters from load_qwen_grid_params().
Returns:
(grid_thw, embeddings_shape) or (None, None) on failure.
grid_thw: list of [grid_t, grid_h, grid_w] per image.
embeddings_shape: [total_tokens, vision_hidden_dim].
"""
if isinstance(image_data, Image.Image):
images = [image_data]
elif isinstance(image_data, list):
images = [img for img in image_data if isinstance(img, Image.Image)]
else:
return None, None
if not images:
return None, None
grid_thw: list[list[int]] = []
total_tokens = 0
merge_sq = params.merge_size**2
for img in images:
w, h = img.size # PIL is (width, height)
rh, rw = smart_resize(
h,
w,
factor=params.factor,
min_pixels=params.min_pixels,
max_pixels=params.max_pixels,
)
grid_t = 1 # single image, temporal dim always 1
grid_h = rh // params.patch_size
grid_w = rw // params.patch_size
grid_thw.append([grid_t, grid_h, grid_w])
total_tokens += (grid_t * grid_h * grid_w) // merge_sq
return grid_thw, [total_tokens, params.vision_hidden_dim]
def build_qwen_embedding_params(
multi_modal_data: Dict[str, Any],
grid_params: QwenGridParams | None,
) -> Dict[str, Any] | None:
"""Build embedding parameters for Qwen VL decode.
Qwen VL's processor expands image tokens using image_grid_thw for mRoPE
position initialization. The decode worker needs this metadata even though
it doesn't re-encode images — the KV cache has the vision context but the
processor still needs grid dimensions to compute positions.
Two input paths depending on how prefill processed images:
1. **Encode worker path** (dict): The encode worker produced embeddings
via the embedding loader. ``multi_modal_data["image"]`` is a dict with
``image_embeds`` (tensor) and ``image_grid_thw`` (tensor/list).
We extract and serialize them for transfer to decode.
2. **PIL path** (no encode worker): Prefill loaded images directly as
PIL.Image objects. We compute grid_thw from image dimensions using
smart_resize with cached processor parameters (~0.4us per image).
Args:
multi_modal_data: The multimodal data dict from prefill processing.
grid_params: Cached Qwen VL processor parameters, or None if
loading failed at init time.
Returns:
Dict with ``image_grid_thw`` and ``embeddings_shape``, or None if
no image data or parameters are unavailable.
"""
embedding_params: Dict[str, Any] = {}
image_data = multi_modal_data.get("image")
if isinstance(image_data, dict):
# Path 1: encode worker produced embeddings as a dict
image_grid_thw = image_data.get("image_grid_thw")
image_embeds = image_data.get("image_embeds")
if image_grid_thw is not None:
embedding_params["image_grid_thw"] = (
image_grid_thw.tolist()
if isinstance(image_grid_thw, torch.Tensor)
else image_grid_thw
)
if image_embeds is not None:
embedding_params["embeddings_shape"] = list(image_embeds.shape)
elif image_data is not None and grid_params is not None:
# Path 2: PIL images — compute grid_thw from image dimensions
grid_thw, embeddings_shape = _compute_qwen_grid_thw(image_data, grid_params)
if grid_thw is not None:
embedding_params["image_grid_thw"] = grid_thw
embedding_params["embeddings_shape"] = embeddings_shape
# TODO(DIS-1679): handle np.ndarray from --frontend-decoding NIXL path
return embedding_params if embedding_params else None
...@@ -401,8 +401,6 @@ class TestDecodeWorkerMultimodalBranching: ...@@ -401,8 +401,6 @@ class TestDecodeWorkerMultimodalBranching:
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
disaggregation_mode="DECODE", disaggregation_mode="DECODE",
) )
# Return an error from _build_prompt_from_request so we don't need
# to mock the full engine — just verify we get past the decode guard.
handler._build_prompt_from_request = MagicMock( handler._build_prompt_from_request = MagicMock(
return_value=(None, None, {"status": "error", "message": "test stop"}) return_value=(None, None, {"status": "error", "message": "test stop"})
) )
...@@ -494,6 +492,66 @@ class TestBuildEmbeddingParams: ...@@ -494,6 +492,66 @@ class TestBuildEmbeddingParams:
assert "embeddings_shape" in result assert "embeddings_shape" in result
assert result["embeddings_shape"] == [1, 256, 1024] assert result["embeddings_shape"] == [1, 256, 1024]
def test_pil_image_qwen_computes_grid(self):
"""PIL image for Qwen VL with grid params -> computes valid embedding_params."""
from PIL import Image
from dynamo.vllm.multimodal_utils.models.qwen import QwenGridParams
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
# Qwen3-VL: patch=16, merge=2, factor=32
handler._qwen_grid_params = QwenGridParams(
patch_size=16,
merge_size=2,
factor=32,
min_pixels=65536,
max_pixels=16777216,
vision_hidden_dim=2048,
)
img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img})
assert result is not None
assert result["image_grid_thw"] == [[1, 30, 40]]
# total_tokens = 1*30*40 // 4 = 300
assert result["embeddings_shape"] == [300, 2048]
def test_pil_multi_image_qwen_computes_grid(self):
"""Multiple PIL images for Qwen VL -> computes combined embedding_params."""
from PIL import Image
from dynamo.vllm.multimodal_utils.models.qwen import QwenGridParams
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
handler._qwen_grid_params = QwenGridParams(
patch_size=16,
merge_size=2,
factor=32,
min_pixels=65536,
max_pixels=16777216,
vision_hidden_dim=2048,
)
imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (320, 320))]
result = handler._build_embedding_params({"image": imgs})
assert result is not None
assert len(result["image_grid_thw"]) == 2
assert result["image_grid_thw"][0] == [1, 30, 40]
assert result["embeddings_shape"][1] == 2048
def test_pil_image_qwen_params_unavailable_returns_none(self):
"""Qwen VL with no grid params -> returns None (fallback)."""
from PIL import Image
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
handler._qwen_grid_params = None
img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img})
assert result is None
def test_pil_image_list_non_qwen_returns_none(self): def test_pil_image_list_non_qwen_returns_none(self):
"""PIL image list for non-Qwen model -> returns None.""" """PIL image list for non-Qwen model -> returns None."""
handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf") handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf")
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Disaggregated multimodal P/D serving (no separate encode worker).
# Prefill handles image loading via PIL and computes image_grid_thw
# for the decode worker using Qwen2VLImageProcessor's smart_resize.
#
# This is a simpler deployment than E/P/D: only 2 workers instead of 3.
# Trade-off: prefill does vision encoding internally (no dedicated encoder),
# which uses more GPU memory on the prefill worker.
set -e
trap 'echo Cleaning up...; kill 0' EXIT
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/gpu_utils.sh"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
# Default values
MODEL_NAME="Qwen/Qwen3-VL-2B-Instruct"
SINGLE_GPU=false
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL_NAME=$2
shift 2
;;
--single-gpu)
SINGLE_GPU=true
shift
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo ""
echo "Disaggregated multimodal serving with Prefill/Decode workers (no encoder)"
echo "Prefill loads images via PIL and computes grid metadata for decode."
echo ""
echo "Options:"
echo " --model <model_name> Specify the VLM model (default: $MODEL_NAME)"
echo " --single-gpu Pack both workers on 1 GPU (for small models)"
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
done
HTTP_PORT="${DYN_HTTP_PORT:-8000}"
if [[ "$SINGLE_GPU" == "true" ]]; then
GPU_LABEL="1 GPU"
else
GPU_LABEL="2 GPUs"
fi
print_launch_banner --multimodal "Launching Disaggregated Multimodal P/D ($GPU_LABEL)" "$MODEL_NAME" "$HTTP_PORT"
# Start frontend
echo "Starting frontend..."
python -m dynamo.frontend &
EXTRA_ARGS=""
PD_EXTRA_ARGS=""
# GPU assignments
DYN_PREFILL_WORKER_GPU=${DYN_PREFILL_WORKER_GPU:-0}
DYN_DECODE_WORKER_GPU=${DYN_DECODE_WORKER_GPU:-1}
# GPU memory utilization
DYN_PREFILL_GPU_MEM=${DYN_PREFILL_GPU_MEM:-0.9}
DYN_DECODE_GPU_MEM=${DYN_DECODE_GPU_MEM:-0.9}
PD_KV_CACHE_BYTES=$((512 * 1024 * 1024))
if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_PREFILL_WORKER_GPU=0
DYN_DECODE_WORKER_GPU=0
DYN_PREFILL_GPU_MEM=0.45
DYN_DECODE_GPU_MEM=0.45
EXTRA_ARGS="--enforce-eager"
PD_EXTRA_ARGS="--max-model-len 4096 \
--kv-cache-memory-bytes $PD_KV_CACHE_BYTES \
--limit-mm-per-prompt {\"image\":1,\"video\":0,\"audio\":0}"
fi
# Start prefill worker (handles image loading internally, no --route-to-encoder)
echo "Starting prefill worker on GPU $DYN_PREFILL_WORKER_GPU (GPU mem: $DYN_PREFILL_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU \
python -m dynamo.vllm \
--disaggregation-mode prefill \
--enable-multimodal \
--model $MODEL_NAME \
--gpu-memory-utilization $DYN_PREFILL_GPU_MEM \
$EXTRA_ARGS \
$PD_EXTRA_ARGS \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
# Start decode worker
echo "Starting decode worker on GPU $DYN_DECODE_WORKER_GPU (GPU mem: $DYN_DECODE_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU \
python -m dynamo.vllm \
--disaggregation-mode decode \
--enable-multimodal \
--enable-mm-embeds \
--model $MODEL_NAME \
--gpu-memory-utilization $DYN_DECODE_GPU_MEM \
$EXTRA_ARGS \
$PD_EXTRA_ARGS \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="
# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
...@@ -447,6 +447,38 @@ vllm_configs = { ...@@ -447,6 +447,38 @@ vllm_configs = {
) )
], ],
), ),
# P/D multimodal (no encoder): prefill loads images via PIL,
# computes grid_thw for decode using smart_resize.
"multimodal_p_d_qwen": VLLMConfig(
name="multimodal_p_d_qwen",
directory=vllm_dir,
script_name="disagg_multimodal_p_d.sh",
marks=[
pytest.mark.gpu_1,
pytest.mark.pre_merge,
],
model="Qwen/Qwen3-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
timeout=300,
request_payloads=[
chat_payload(
[
{
"type": "text",
"text": "What colors are in the following image? Respond only with the colors.",
},
{
"type": "image_url",
"image_url": {"url": MULTIMODAL_IMG_URL},
},
],
repeat_count=1,
expected_response=["green"],
temperature=0.0,
max_tokens=100,
)
],
),
"multimodal_agg_qwen": VLLMConfig( "multimodal_agg_qwen": VLLMConfig(
name="multimodal_agg_qwen", name="multimodal_agg_qwen",
directory=vllm_dir, directory=vllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment