Unverified Commit e71f1d2b authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files
parent c68d159f
......@@ -1685,6 +1685,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
embedding_params = prefill_result.get("disaggregated_params", {}).get(
"embedding_params"
)
# Normalize embedding_params to None if it is an empty dict
if not embedding_params:
embedding_params = None
else:
kv_params = None
embedding_params = None
......@@ -1723,6 +1726,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
logger.error("Request %s: %s", request_id, msg)
yield {"status": "error", "message": msg}
return
else:
# Non-qwen model, assume the multi_modal_data has been consumed
# in prefill, so we can use the expanded prompt token ids
# without multimodal data
if embedding_params and "expanded_prompt_token_ids" in embedding_params:
request["token_ids"] = embedding_params["expanded_prompt_token_ids"]
has_mm_data = False
# TODO(DIS-1661): video/audio re-downloaded on decode.
# TODO(DIS-1664): mixed image+video in disagg decode is not
# supported — synthetic image data would be overwritten.
......@@ -1967,7 +1977,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
context,
mm_processor_kwargs=mm_processor_kwargs,
)
embedding_params = self._build_embedding_params(multi_modal_data or {})
# Build prompt from request (handles both prompt_embeds and token_ids)
prompt, embedding_sequence_length, error = self._build_prompt_from_request(
......@@ -2049,10 +2058,16 @@ class PrefillWorkerHandler(BaseWorkerHandler):
token_ids = res.outputs[0].token_ids if res.outputs else []
# For prefill worker, only one res will be generated,
# so we can always build embedding params here without conditionals
embedding_params = self._build_embedding_params(
multi_modal_data or {}, res.prompt_token_ids
)
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"disaggregated_params": self._build_disaggregated_params(
res.kv_transfer_params, embedding_params
res.kv_transfer_params,
embedding_params,
),
"completion_usage": BaseWorkerHandler._build_completion_usage(
request_output=res,
......@@ -2073,18 +2088,36 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield output
def _build_disaggregated_params(self, kv_transfer_params, embedding_params=None):
def _build_disaggregated_params(
self, kv_transfer_params, embedding_params=None, expanded_prompt_token_ids=None
):
disaggregated_params = {}
if kv_transfer_params is not None:
disaggregated_params["kv_transfer_params"] = kv_transfer_params
if embedding_params is not None:
disaggregated_params["embedding_params"] = embedding_params
if expanded_prompt_token_ids is not None:
disaggregated_params[
"expanded_prompt_token_ids"
] = expanded_prompt_token_ids
return disaggregated_params if disaggregated_params else None
def _build_embedding_params(
self, multi_modal_data: dict[str, Any]
self, multi_modal_data: dict[str, Any], prompt_token_ids: list[int]
) -> Dict[str, Any] | None:
# [gluo NOTE] there could be different model architectures that
# need different embedding params, will add more logic if needed
if not is_qwen_vl_model(self.config.model):
return None
# For non-qwen models, vLLM doesn't trigger mm preprocess so
# decode worker only needs expanded prompt to properly fetch KV blocks
# from prefill.
if multi_modal_data:
return {"expanded_prompt_token_ids": prompt_token_ids}
else:
# For qwen models, vLLM triggers mm preprocess so decode worker will
# perform token expansion unconditionally, so we need to pass
# original prompt and sufficient metadata to reconstruct mm embedding
# as request input.
return build_qwen_embedding_params(multi_modal_data, self._qwen_grid_params)
return None
......@@ -6,6 +6,7 @@ from typing import Any, Sequence
import blake3
import numpy as np
import torch
logger = logging.getLogger(__name__)
......@@ -20,6 +21,10 @@ def image_to_bytes(img: Any) -> bytes:
if isinstance(img, Image.Image | np.ndarray):
return img.tobytes()
if isinstance(img, torch.Tensor):
# Make sure the bytes are on the CPU
return img.cpu().numpy().tobytes()
raise TypeError(f"Unsupported image type for hashing: {type(img)}")
......
......@@ -495,7 +495,7 @@ class TestBuildEmbeddingParams:
"image_grid_thw": torch.tensor([[1, 16, 16]]),
}
}
result = handler._build_embedding_params(mm_data)
result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result is not None
assert "image_grid_thw" in result
......@@ -520,7 +520,7 @@ class TestBuildEmbeddingParams:
)
img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img})
result = handler._build_embedding_params({"image": img}, [1, 2, 3])
assert result is not None
assert result["image_grid_thw"] == [[1, 30, 40]]
......@@ -544,7 +544,7 @@ class TestBuildEmbeddingParams:
)
imgs = [Image.new("RGB", (640, 480)), Image.new("RGB", (320, 320))]
result = handler._build_embedding_params({"image": imgs})
result = handler._build_embedding_params({"image": imgs}, [1, 2, 3])
assert result is not None
assert len(result["image_grid_thw"]) == 2
......@@ -559,21 +559,21 @@ class TestBuildEmbeddingParams:
handler._qwen_grid_params = None
img = Image.new("RGB", (640, 480))
result = handler._build_embedding_params({"image": img})
result = handler._build_embedding_params({"image": img}, [1, 2, 3])
assert result is None
def test_pil_image_list_non_qwen_returns_none(self):
"""PIL image list for non-Qwen model -> returns None."""
def test_pil_image_list_llava_returns_expanded_prompt_token_ids(self):
"""PIL image list for LLaVA model -> returns expanded prompt token ids."""
handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf")
mm_data = {"image": [MagicMock()]}
result = handler._build_embedding_params(mm_data)
assert result is None
result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result["expanded_prompt_token_ids"] == [1, 2, 3]
def test_no_image_data_returns_none(self):
"""No image data -> returns None."""
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
mm_data = {}
result = handler._build_embedding_params(mm_data)
result = handler._build_embedding_params(mm_data, [1, 2, 3])
assert result is None
......@@ -86,7 +86,7 @@ if [[ "$SINGLE_GPU" == "true" ]]; then
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-0}
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.1}
DYN_PD_GPU_MEM=${DYN_PD_GPU_MEM:-0.7}
EXTRA_ARGS="--enforce-eager"
EXTRA_ARGS="--enforce-eager --max-model-len $PD_MAX_MODEL_LEN"
else
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-1}
DYN_PD_WORKER_GPU=${DYN_PD_WORKER_GPU:-2}
......@@ -112,7 +112,6 @@ python -m dynamo.vllm \
--enable-multimodal \
--enable-mm-embeds \
--model "$MODEL_NAME" \
--max-model-len "$PD_MAX_MODEL_LEN" \
--gpu-memory-utilization "$DYN_PD_GPU_MEM" \
$EXTRA_ARGS \
"${EXTRA_PD_ARGS[@]}" &
......
......@@ -108,4 +108,23 @@ VLLM_MULTIMODAL_PROFILES: list[MultimodalModelProfile] = [
extra_vllm_args=["--dtype", "bfloat16"],
gated=True,
),
# [gluo NOTE] LLaVA 1.5 7B is big model and require at least 3 GPUs to run.
# We may use less GPUs by squeezing the model onto 2 GPUs.
MultimodalModelProfile(
name="llava-hf/llava-1.5-7b-hf",
short_name="llava-1.5-7b",
topologies={
"e_pd": TopologyConfig(
marks=[pytest.mark.pre_merge],
timeout_s=340,
gpu_marker="gpu_4",
),
"epd": TopologyConfig(
marks=[pytest.mark.pre_merge],
timeout_s=300,
gpu_marker="gpu_4",
),
},
request_payloads=[make_image_payload(["green"])],
),
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment