Unverified Commit 06bc1580 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: mm epd disagg (#4151)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 09b26bf6
......@@ -58,6 +58,7 @@ class Config:
multimodal_processor: bool = False
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
multimodal_decode_worker: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
# dump config to file
......@@ -147,6 +148,11 @@ def parse_args() -> Config:
action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data",
)
parser.add_argument(
"--multimodal-decode-worker",
action="store_true",
help="Run as multimodal decode worker in disaggregated mode",
)
parser.add_argument(
"--multimodal-encode-prefill-worker",
action="store_true",
......@@ -201,11 +207,12 @@ def parse_args() -> Config:
int(bool(args.multimodal_processor))
+ int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker))
+ int(bool(args.multimodal_decode_worker))
+ int(bool(args.multimodal_encode_prefill_worker))
)
if mm_flags > 1:
raise ValueError(
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, or --multimodal-encode-prefill-worker"
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, --multimodal-decode-worker, or --multimodal-encode-prefill-worker"
)
# Set component and endpoint based on worker type
......@@ -218,8 +225,14 @@ def parse_args() -> Config:
elif args.multimodal_encode_prefill_worker:
config.component = "encoder"
config.endpoint = "generate"
elif args.multimodal_decode_worker:
# Uses "decoder" component name because prefill worker connects to "decoder"
# (prefill uses "backend" to receive from encoder)
config.component = "decoder"
config.endpoint = "generate"
elif args.multimodal_worker and args.is_prefill_worker:
config.component = "prefill"
# Multimodal prefill worker stays as "backend" to maintain encoder connection
config.component = "backend"
config.endpoint = "generate"
elif args.is_prefill_worker:
config.component = "prefill"
......@@ -238,6 +251,7 @@ def parse_args() -> Config:
config.multimodal_processor = args.multimodal_processor
config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker
config.multimodal_decode_worker = args.multimodal_decode_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
......
......@@ -29,6 +29,7 @@ from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
ProcessorHandler,
)
......@@ -105,7 +106,11 @@ async def worker():
elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
logger.debug("init_multimodal_encode_worker completed")
elif config.multimodal_worker or config.multimodal_encode_prefill_worker:
elif (
config.multimodal_worker
or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker
):
await init_multimodal_worker(runtime, config)
logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker:
......@@ -129,7 +134,6 @@ def setup_kv_event_publisher(
"""
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
config: Worker configuration
component: Component for runtime integration
......@@ -632,12 +636,27 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
# For aggregated mode, no downstream client is needed
# TODO: Implement disaggregated mode with proper decode worker client
downstream_client = None
# Set up decode worker client for disaggregated mode
decode_worker_client = None
if config.is_prefill_worker:
# Prefill worker needs to connect to decode worker
decode_worker_client = (
await runtime.namespace(config.namespace)
.component("decoder")
.endpoint("generate")
.client()
)
await decode_worker_client.wait_for_instances()
logger.info("Connected to decode worker for disaggregated mode")
# Choose handler based on worker type
if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config
)
else:
handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, downstream_client
runtime, component, engine_client, config, decode_worker_client
)
await handler.async_init(runtime)
......
......@@ -22,10 +22,20 @@ while [[ $# -gt 0 ]]; do
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo ""
echo "Disaggregated multimodal serving with separate Encode/Prefill/Decode workers"
echo ""
echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates."
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use"
echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " -h, --help Show this help message"
echo ""
echo "Examples:"
echo " $0 --model llava-hf/llava-1.5-7b-hf"
echo " $0 --model microsoft/Phi-3.5-vision-instruct"
echo " $0 --model Qwen/Qwen2.5-VL-7B-Instruct"
echo ""
exit 0
;;
*)
......@@ -52,17 +62,44 @@ else
exit 1
fi
# run ingress
echo "=================================================="
echo "Disaggregated Multimodal Serving"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "Prompt Template: $PROMPT_TEMPLATE"
echo "=================================================="
# Start frontend (no router mode)
echo "Starting frontend..."
python -m dynamo.frontend --http-port=8000 &
# Start processor
echo "Starting processor..."
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
# Configure GPU memory optimization for specific models
EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" &
# Start encode worker
echo "Starting encode worker on GPU 1..."
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME $EXTRA_ARGS &
# run E/P/D workers
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME &
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill --enable-disagg &
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg &
# Start prefill worker
echo "Starting prefill worker on GPU 2..."
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --model $MODEL_NAME $EXTRA_ARGS &
# Start decode worker
echo "Starting decode worker on GPU 3..."
CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --model $MODEL_NAME $EXTRA_ARGS &
echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="
# Wait for all background processes to complete
wait
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment