Unverified Commit 5cd8005c authored by Indrajit Bhosale's avatar Indrajit Bhosale Committed by GitHub
Browse files

fix: Fix decode worker in vllm for qwen_vl models (#5281)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
Co-authored-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent 403ff669
...@@ -18,6 +18,7 @@ from ..multimodal_utils import ( ...@@ -18,6 +18,7 @@ from ..multimodal_utils import (
construct_mm_data, construct_mm_data,
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,10 +64,25 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -63,10 +64,25 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
request = vLLMMultimodalRequest.model_validate(request) request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.") logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# Decode worker doesn't process embeddings, so we pass None or empty tensor # For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
# image_grid_thw for position embeddings calculation. The decode worker
# receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM
# will expand it using the multi_modal_data, ensuring the block count
# matches what prefill computed.
#
# We pass unique placeholder embeddings (seeded by request_id) since the
# actual embeddings are already in the KV cache from prefill. The unique
# values prevent incorrect prefix cache matches between different images.
multi_modal_data = None
if is_qwen_vl_model(self.config.model):
multi_modal_data = construct_qwen_decode_mm_data(
request.image_grid_thw, request.embeddings_shape, request.request_id
)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"], prompt_token_ids=request.engine_prompt["prompt_token_ids"],
multi_modal_data=multi_modal_data,
), ),
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
request_id=request.request_id, request_id=request.request_id,
...@@ -254,9 +270,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -254,9 +270,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
decode_request = copy.deepcopy(request) decode_request = copy.deepcopy(request)
async for prefill_response in gen: async for prefill_response in gen:
# Update the prompt token id in the decode request to the one # For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
# in response, which has image templated filled in. So that # The decode worker will pass multi_modal_data which causes vLLM to
# the decode worker will fetch correct amount of KV blocks. # expand the prompt identically to prefill, ensuring block counts match.
#
# For other models: Use the expanded prompt from prefill response.
# These models don't pass multi_modal_data in decode, so they need
# the already-expanded prompt to match the KV cache layout.
if not is_qwen_vl_model(self.config.model):
decode_request.engine_prompt[ decode_request.engine_prompt[
"prompt_token_ids" "prompt_token_ids"
] = prefill_response.prompt_token_ids ] = prefill_response.prompt_token_ids
......
...@@ -177,3 +177,61 @@ def _construct_qwen_image_data( ...@@ -177,3 +177,61 @@ def _construct_qwen_image_data(
"image_grid_thw": grid_thw_tensor, "image_grid_thw": grid_thw_tensor,
} }
} }
def construct_qwen_decode_mm_data(
image_grid_thw: Optional[List[Any]],
embeddings_shape: Optional[Any],
request_id: str,
*,
dtype: torch.dtype = torch.float16,
) -> Dict[str, Dict[str, torch.Tensor]]:
"""Construct schema-valid Qwen multimodal data for vLLM v1 disagg decode.
This is a WORKAROUND (WAR) for vLLM's disaggregated multimodal decode limitations.
Notes:
- vLLM parses multimodal inputs and builds `mm_features` from `multi_modal_data`.
- For Qwen VL models, the parser enforces that image data contains BOTH
`image_embeds` and `image_grid_thw` keys.
- In disaggregated decode, the KV cache already includes the vision context
from prefill; decode still needs `mm_features` for mRoPE initialization.
WAR Details:
- We generate unique placeholder embeddings based on request_id to prevent
incorrect prefix cache matches between different images with same dimensions.
- Without this, zero embeddings + same image_grid_thw would create identical
cache signatures, causing decode to incorrectly reuse cached KV from
different images.
Caching Caveat:
- This WAR disables prefix cache reuse on the DECODE worker (each request
has unique placeholder embeddings).
- Prefix caching still works correctly on the PREFILL worker, which uses
actual image embeddings. This is where the caching benefit matters since
prefill does the heavy computation.
- Decode receives KV blocks from prefill via NIXL transfer anyway, so
decode-side prefix caching provides minimal benefit in disaggregated setup.
"""
if image_grid_thw is None or len(image_grid_thw) == 0:
raise ValueError("No image grid provided for Qwen model.")
if embeddings_shape is None:
raise ValueError("embeddings_shape is required for Qwen decode mm data.")
# WAR: Use request_id hash as seed for unique placeholder values.
# This prevents prefix cache from incorrectly matching different images
# that happen to have the same dimensions (same image_grid_thw).
seed = hash(request_id) & 0xFFFFFFFF # Convert to positive 32-bit int
generator = torch.Generator().manual_seed(seed)
image_embeds = torch.randn(
embeddings_shape, dtype=dtype, device="cpu", generator=generator
)
if image_embeds.ndim == 3:
image_embeds = image_embeds.squeeze(0)
return {
"image": {
"image_embeds": image_embeds,
"image_grid_thw": torch.tensor(image_grid_thw),
}
}
...@@ -93,7 +93,7 @@ CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-wo ...@@ -93,7 +93,7 @@ CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-wo
# Start decode worker # Start decode worker
echo "Starting decode worker on GPU 2..." echo "Starting decode worker on GPU 2..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \ VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
echo "==================================================" echo "=================================================="
echo "All components started. Waiting for initialization..." echo "All components started. Waiting for initialization..."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment