Unverified Commit e71f1d2b authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files
parent c68d159f
...@@ -1685,6 +1685,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1685,6 +1685,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
embedding_params = prefill_result.get("disaggregated_params", {}).get( embedding_params = prefill_result.get("disaggregated_params", {}).get(
"embedding_params" "embedding_params"
) )
# Normalize embedding_params to None if it is an empty dict
if not embedding_params:
embedding_params = None
else: else:
kv_params = None kv_params = None
embedding_params = None embedding_params = None
...@@ -1723,6 +1726,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1723,6 +1726,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
logger.error("Request %s: %s", request_id, msg) logger.error("Request %s: %s", request_id, msg)
yield {"status": "error", "message": msg} yield {"status": "error", "message": msg}
return return
else:
# Non-qwen model, assume the multi_modal_data has been consumed
# in prefill, so we can use the expanded prompt token ids
# without multimodal data
if embedding_params and "expanded_prompt_token_ids" in embedding_params:
request["token_ids"] = embedding_params["expanded_prompt_token_ids"]
has_mm_data = False
# TODO(DIS-1661): video/audio re-downloaded on decode. # TODO(DIS-1661): video/audio re-downloaded on decode.
# TODO(DIS-1664): mixed image+video in disagg decode is not # TODO(DIS-1664): mixed image+video in disagg decode is not
# supported — synthetic image data would be overwritten. # supported — synthetic image data would be overwritten.
...@@ -1967,7 +1977,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1967,7 +1977,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
context, context,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
embedding_params = self._build_embedding_params(multi_modal_data or {})
# Build prompt from request (handles both prompt_embeds and token_ids) # Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request( prompt, embedding_sequence_length, error = self._build_prompt_from_request(
...@@ -2049,10 +2058,16 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -2049,10 +2058,16 @@ class PrefillWorkerHandler(BaseWorkerHandler):
token_ids = res.outputs[0].token_ids if res.outputs else [] token_ids = res.outputs[0].token_ids if res.outputs else []
# For prefill worker, only one res will be generated,
# so we can always build embedding params here without conditionals
embedding_params = self._build_embedding_params(
multi_modal_data or {}, res.prompt_token_ids
)
output: Dict[str, Any] = { output: Dict[str, Any] = {
"token_ids": list(token_ids), "token_ids": list(token_ids),
"disaggregated_params": self._build_disaggregated_params( "disaggregated_params": self._build_disaggregated_params(
res.kv_transfer_params, embedding_params res.kv_transfer_params,
embedding_params,
), ),
"completion_usage": BaseWorkerHandler._build_completion_usage( "completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res, request_output=res,
...@@ -2073,18 +2088,36 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -2073,18 +2088,36 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield output yield output
def _build_disaggregated_params(self, kv_transfer_params, embedding_params=None): def _build_disaggregated_params(
self, kv_transfer_params, embedding_params=None, expanded_prompt_token_ids=None
):
disaggregated_params = {} disaggregated_params = {}
if kv_transfer_params is not None: if kv_transfer_params is not None:
disaggregated_params["kv_transfer_params"] = kv_transfer_params disaggregated_params["kv_transfer_params"] = kv_transfer_params
if embedding_params is not None: if embedding_params is not None:
disaggregated_params["embedding_params"] = embedding_params disaggregated_params["embedding_params"] = embedding_params
if expanded_prompt_token_ids is not None:
disaggregated_params[
"expanded_prompt_token_ids"
] = expanded_prompt_token_ids
return disaggregated_params if disaggregated_params else None return disaggregated_params if disaggregated_params else None
def _build_embedding_params( def _build_embedding_params(
self, multi_modal_data: dict[str, Any] self, multi_modal_data: dict[str, Any], prompt_token_ids: list[int]
) -> Dict[str, Any] | None: ) -> Dict[str, Any] | None:
# [gluo NOTE] Different model architectures may need different
# embedding params; more logic will be added here if needed.
if not is_qwen_vl_model(self.config.model): if not is_qwen_vl_model(self.config.model):
return None # For non-qwen models, vLLM doesn't trigger mm preprocess so
# decode worker only needs expanded prompt to properly fetch KV blocks
# from prefill.
if multi_modal_data:
return {"expanded_prompt_token_ids": prompt_token_ids}
else:
# For qwen models, vLLM triggers mm preprocess so decode worker will
# perform token expansion unconditionally, so we need to pass
# original prompt and sufficient metadata to reconstruct mm embedding
# as request input.
return build_qwen_embedding_params(multi_modal_data, self._qwen_grid_params) return build_qwen_embedding_params(multi_modal_data, self._qwen_grid_params)
return None
...@@ -6,6 +6,7 @@ from typing import Any, Sequence ...@@ -6,6 +6,7 @@ from typing import Any, Sequence
import blake3 import blake3
import numpy as np import numpy as np
import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -20,6 +21,10 @@ def image_to_bytes(img: Any) -> bytes: ...@@ -20,6 +21,10 @@ def image_to_bytes(img: Any) -> bytes:
if isinstance(img, Image.Image | np.ndarray): if isinstance(img, Image.Image | np.ndarray):
return img.tobytes() return img.tobytes()
if isinstance(img, torch.Tensor):
# Make sure the bytes are on the CPU
return img.cpu().numpy().tobytes()
raise TypeError(f"Unsupported image type for hashing: {type(img)}") raise TypeError(f"Unsupported image type for hashing: {type(img)}")
......
...@@ -495,7 +495,7 @@ class TestBuildEmbeddingParams: ...@@ -495,7 +495,7 @@ class TestBuildEmbeddingParams:
"image_grid_thw": torch.tensor([[1, 16, 16]]), "image_grid_thw": torch.tensor([[1, 16, 16]]),
} }
} }
result = handler._build_embedding_params(mm_data) result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result is not None assert result is not None
assert "image_grid_thw" in result assert "image_grid_thw" in result
...@@ -520,7 +520,7 @@ class TestBuildEmbeddingParams: ...@@ -520,7 +520,7 @@ class TestBuildEmbeddingParams:
) )
img = Image.new("RGB", (640, 480)) img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img}) result = handler._build_embedding_params({"image": img}, [1, 2, 3])
assert result is not None assert result is not None
assert result["image_grid_thw"] == [[1, 30, 40]] assert result["image_grid_thw"] == [[1, 30, 40]]
...@@ -544,7 +544,7 @@ class TestBuildEmbeddingParams: ...@@ -544,7 +544,7 @@ class TestBuildEmbeddingParams:
) )
imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (320, 320))] imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (320, 320))]
result = handler._build_embedding_params({"image": imgs}) result = handler._build_embedding_params({"image": imgs}, [1, 2, 3])
assert result is not None assert result is not None
assert len(result["image_grid_thw"]) == 2 assert len(result["image_grid_thw"]) == 2
...@@ -559,21 +559,21 @@ class TestBuildEmbeddingParams: ...@@ -559,21 +559,21 @@ class TestBuildEmbeddingParams:
handler._qwen_grid_params = None handler._qwen_grid_params = None
img = Image.new("RGB", (640, 480)) img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img}) result = handler._build_embedding_params({"image": img}, [1, 2, 3])
assert result is None assert result is None
def test_pil_image_list_non_qwen_returns_none(self): def test_pil_image_list_llava_returns_expanded_prompt_token_ids(self):
"""PIL image list for non-Qwen model -> returns None.""" """PIL image list for LLaVA model -> returns expanded prompt token ids."""
handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf") handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf")
mm_data = {"image": [MagicMock()]} mm_data = {"image": [MagicMock()]}
result = handler._build_embedding_params(mm_data) result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result is None assert result["expanded_prompt_token_ids"] == [1, 2, 3]
def test_no_image_data_returns_none(self): def test_no_image_data_returns_none(self):
"""No image data -> returns None.""" """No image data -> returns None."""
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct") handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
mm_data = {} mm_data = {}
result = handler._build_embedding_params(mm_data) result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result is None assert result is None
...@@ -86,7 +86,7 @@ if [[ "$SINGLE_GPU" == "true" ]]; then ...@@ -86,7 +86,7 @@ if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0} DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.1} DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.1}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.7} DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.7}
EXTRA_ARGS="--enforce-eager" EXTRA_ARGS="--enforce-eager --max-model-len $PD_MAX_MODEL_LEN"
else else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1} DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2} DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2}
...@@ -112,7 +112,6 @@ python -m dynamo.vllm \ ...@@ -112,7 +112,6 @@ python -m dynamo.vllm \
--enable-multimodal \ --enable-multimodal \
--enable-mm-embeds \ --enable-mm-embeds \
--model "$MODEL_NAME" \ --model "$MODEL_NAME" \
--max-model-len "$PD_MAX_MODEL_LEN" \
--gpu-memory-utilization "$DYN_PD_GPU_MEM" \ --gpu-memory-utilization "$DYN_PD_GPU_MEM" \
$EXTRA_ARGS \ $EXTRA_ARGS \
"${EXTRA_PD_ARGS[@]}" & "${EXTRA_PD_ARGS[@]}" &
......
...@@ -108,4 +108,23 @@ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [ ...@@ -108,4 +108,23 @@ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [
extra_vllm_args=["--dtype", "bfloat16"], extra_vllm_args=["--dtype", "bfloat16"],
gated=True, gated=True,
), ),
# [gluo NOTE] LLaVA 1.5 7B is a big model and requires at least 3 GPUs to run.
# We may be able to use fewer GPUs by squeezing the model onto 2 GPUs.
MultimodalModelProfile(
name="llava-hf/llava-1.5-7b-hf",
short_name="llava-1.5-7b",
topologies={
"e_pd": TopologyConfig(
marks=[pytest.mark.pre_merge],
timeout_s=340,
gpu_marker="gpu_4",
),
"epd": TopologyConfig(
marks=[pytest.mark.pre_merge],
timeout_s=300,
gpu_marker="gpu_4",
),
},
request_payloads=[make_image_payload(["green"])],
),
] ]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment