Unverified commit 763264ff, authored by Daniel Socek and committed by GitHub
Browse files

fix: Multimodal disaggregation improvements (#5895)


Signed-off-by: Daniel Socek <daniel.socek@intel.com>
parent c7f6f6d9
......@@ -82,8 +82,16 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
# values prevent incorrect prefix cache matches between different images.
multi_modal_data = None
if is_qwen_vl_model(self.config.model):
image_grid_thw = getattr(request, "image_grid_thw", None)
embeddings_shape = getattr(request, "embeddings_shape", None)
if image_grid_thw is None or embeddings_shape is None:
logger.warning(
"Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); "
"skipping multi_modal_data construction."
)
else:
multi_modal_data = construct_qwen_decode_mm_data(
request.image_grid_thw, request.embeddings_shape, request.request_id
image_grid_thw, embeddings_shape, request.request_id
)
gen = self.engine_client.generate(
......@@ -277,6 +285,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
await self.image_loader.load_image(mi.multimodal_input.image_url)
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
# from the constructed multimodal data so decode can reconstruct its
# multi_modal_data consistently for multiple images.
if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict
):
image_data = multi_modal_data["image"]
image_grid_thw = image_data.get("image_grid_thw")
image_embeds = image_data.get("image_embeds")
if image_grid_thw is not None:
request.image_grid_thw = (
image_grid_thw.tolist()
if isinstance(image_grid_thw, torch.Tensor)
else image_grid_thw
)
if image_embeds is not None:
request.embeddings_shape = list(image_embeds.shape)
# Remove the image features from the request as they are not required
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
request.multimodal_inputs = []
......
......@@ -35,8 +35,9 @@ class SupportedModels:
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
......@@ -112,8 +113,9 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
# List of all Qwen VL model variants for easy extension.
# Each variant is listed exactly once; add new Qwen VL checkpoints here after
# declaring them on SupportedModels.
QWEN_VL_MODELS = [
    SupportedModels.QWEN_2_VL_2B,
    SupportedModels.QWEN_2_5_VL_3B,
    SupportedModels.QWEN_2_5_VL_7B,
    SupportedModels.QWEN_2_5_VL_32B,
]
......@@ -143,13 +145,38 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0",
}
)
# [gluo NOTE] this actually loads the full model,
# which requires more GPU memory than needed.
# Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, which significantly reduces memory usage.
# This uses the native vLLM encoder-only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# The model class must define the class method get_language_model_spec for this to work.
# TODO(gluo/dsocek): Remove this monkey patch once vLLM upstream adds
# get_language_model_spec to the Qwen VL model classes.
# Monkey-patch vLLM's Qwen 2 VL and Qwen 2.5 VL classes to add get_language_model_spec.
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
@classmethod
def get_language_model_spec(cls):
    """Return the language-model spec as a ``(class, name)`` pair.

    The pair is ``(Qwen2ForCausalLM, "language_model")``.
    NOTE(review): the string is presumably the attribute/prefix under which
    the language model lives on the VL model class -- confirm against the
    vLLM encoder-only loading code before relying on it.
    """
    language_model_cls = Qwen2ForCausalLM
    language_model_name = "language_model"
    return (language_model_cls, language_model_name)
Qwen2_5_VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
Qwen2VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM
vllm_model = LLM(
model=model_id,
enforce_eager=True,
gpu_memory_utilization=0.4,
max_model_len=10,
convert="mm_encoder_only",
enable_prefix_caching=False,
)
return (
vllm_model.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.model.visual
......
......@@ -175,7 +175,11 @@ class MultiModalGroup(BaseModel):
class vLLMMultimodalRequest(vLLMGenerateRequest):
    """Generate request carrying multimodal inputs and Qwen VL decode metadata.

    Extends ``vLLMGenerateRequest`` with the multimodal groups attached to the
    request and two optional fields used by the Qwen VL (mRoPE) decode-only
    worker. All three added fields are optional, so requests that omit them
    still validate.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Decode-only worker can have None for multimodal_inputs; when the field is
    # omitted it defaults to a fresh empty list.
    multimodal_inputs: Optional[List[MultiModalGroup]] = Field(default_factory=list)

    # Fields for the Qwen VL (mRoPE) decode-only worker. Both default to None
    # when the request does not carry them.
    image_grid_thw: Optional[List[List[int]]] = None
    embeddings_shape: Optional[List[int]] = None
class VLLMNativeEncoderRequest(BaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment