"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "cf5f65f7dfe87c88ac33cfadf3cd17b8ad96c8e3"
Unverified Commit 763264ff authored by Daniel Socek's avatar Daniel Socek Committed by GitHub
Browse files

fix: Multimodal disaggregation improvements (#5895)


Signed-off-by: default avatarDaniel Socek <daniel.socek@intel.com>
parent c7f6f6d9
...@@ -82,9 +82,17 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -82,9 +82,17 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
# values prevent incorrect prefix cache matches between different images. # values prevent incorrect prefix cache matches between different images.
multi_modal_data = None multi_modal_data = None
if is_qwen_vl_model(self.config.model): if is_qwen_vl_model(self.config.model):
multi_modal_data = construct_qwen_decode_mm_data( image_grid_thw = getattr(request, "image_grid_thw", None)
request.image_grid_thw, request.embeddings_shape, request.request_id embeddings_shape = getattr(request, "embeddings_shape", None)
) if image_grid_thw is None or embeddings_shape is None:
logger.warning(
"Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); "
"skipping multi_modal_data construction."
)
else:
multi_modal_data = construct_qwen_decode_mm_data(
image_grid_thw, embeddings_shape, request.request_id
)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
...@@ -277,6 +285,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -277,6 +285,24 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
await self.image_loader.load_image(mi.multimodal_input.image_url) await self.image_loader.load_image(mi.multimodal_input.image_url)
) )
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
# from the constructed multimodal data so decode can reconstruct its
# multi_modal_data consistently for multiple images.
if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict
):
image_data = multi_modal_data["image"]
image_grid_thw = image_data.get("image_grid_thw")
image_embeds = image_data.get("image_embeds")
if image_grid_thw is not None:
request.image_grid_thw = (
image_grid_thw.tolist()
if isinstance(image_grid_thw, torch.Tensor)
else image_grid_thw
)
if image_embeds is not None:
request.embeddings_shape = list(image_embeds.shape)
# Remove the image features from the request as they are not required # Remove the image features from the request as they are not required
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade # Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
request.multimodal_inputs = [] request.multimodal_inputs = []
......
...@@ -35,8 +35,9 @@ class SupportedModels: ...@@ -35,8 +35,9 @@ class SupportedModels:
LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf" LLAVA_1_5_7B = "llava-hf/llava-1.5-7b-hf"
QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct" QWEN_2_VL_2B = "Qwen/Qwen2-VL-2B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN_2_5_VL_3B = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
QWEN_2_5_VL_32B = "Qwen/Qwen2.5-VL-32B-Instruct"
LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf" LLAVA_NEXT_VIDEO_7B = "llava-hf/LLaVA-NeXT-Video-7B-hf"
...@@ -112,8 +113,9 @@ def is_model_supported(model_name: str, supported_model: str) -> bool: ...@@ -112,8 +113,9 @@ def is_model_supported(model_name: str, supported_model: str) -> bool:
# List of all Qwen VL model variants for easy extension # List of all Qwen VL model variants for easy extension
QWEN_VL_MODELS = [ QWEN_VL_MODELS = [
SupportedModels.QWEN_2_VL_2B, SupportedModels.QWEN_2_VL_2B,
SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_3B, SupportedModels.QWEN_2_5_VL_3B,
SupportedModels.QWEN_2_5_VL_7B,
SupportedModels.QWEN_2_5_VL_32B,
] ]
...@@ -143,13 +145,38 @@ def load_vision_model(model_id: str) -> torch.nn.Module: ...@@ -143,13 +145,38 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
"VLLM_ENABLE_V1_MULTIPROCESSING": "0", "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
} }
) )
# [gluo NOTE] this actually loads the full model, # Load only the vision model via vLLM on encoder workers to avoid loading the full LLM weights, significantly reducing memory usage.
# which require more GPU memory than needed. # Uses native vLLM encoder only model loading added in https://github.com/vllm-project/vllm/pull/30242.
# Model needs the class method get_language_model_spec to be defined for this to work.
# TODO(gluo/dsocek): Remove this monkey patch once vLLM upstream adds
# get_language_model_spec to Qwen VL model classes.
# Monkey patch to vLLM's Qwen 2 VL and Qwen 2.5 VL classes to add get_language_model_spec
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLForConditionalGeneration
@classmethod
def get_language_model_spec(cls):
return (Qwen2ForCausalLM, "language_model")
Qwen2_5_VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
Qwen2VLForConditionalGeneration.get_language_model_spec = (
get_language_model_spec
)
# Load only the vision model via vLLM
vllm_model = LLM( vllm_model = LLM(
model=model_id, model=model_id,
enforce_eager=True, enforce_eager=True,
gpu_memory_utilization=0.4, gpu_memory_utilization=0.4,
max_model_len=10, max_model_len=10,
convert="mm_encoder_only",
enable_prefix_caching=False,
) )
return ( return (
vllm_model.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.model.visual vllm_model.llm_engine.engine_core.engine_core.model_executor.driver_worker.worker.model_runner.model.visual
......
...@@ -175,7 +175,11 @@ class MultiModalGroup(BaseModel): ...@@ -175,7 +175,11 @@ class MultiModalGroup(BaseModel):
class vLLMMultimodalRequest(vLLMGenerateRequest): class vLLMMultimodalRequest(vLLMGenerateRequest):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list) # Decode-only worker can have None for multimodal_inputs
multimodal_inputs: Optional[List[MultiModalGroup]] = Field(default_factory=list)
# Add these fields for Qwen VL (mRoPE) decode-only worker
image_grid_thw: Optional[List[List[int]]] = None
embeddings_shape: Optional[List[int]] = None
class VLLMNativeEncoderRequest(BaseModel): class VLLMNativeEncoderRequest(BaseModel):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment