Unverified Commit 06bc1580 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: mm epd disagg (#4151)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 09b26bf6
...@@ -58,6 +58,7 @@ class Config: ...@@ -58,6 +58,7 @@ class Config:
multimodal_processor: bool = False multimodal_processor: bool = False
multimodal_encode_worker: bool = False multimodal_encode_worker: bool = False
multimodal_worker: bool = False multimodal_worker: bool = False
multimodal_decode_worker: bool = False
multimodal_encode_prefill_worker: bool = False multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:" mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
# dump config to file # dump config to file
...@@ -147,6 +148,11 @@ def parse_args() -> Config: ...@@ -147,6 +148,11 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data", help="Run as multimodal worker component for LLM inference with multimodal data",
) )
parser.add_argument(
"--multimodal-decode-worker",
action="store_true",
help="Run as multimodal decode worker in disaggregated mode",
)
parser.add_argument( parser.add_argument(
"--multimodal-encode-prefill-worker", "--multimodal-encode-prefill-worker",
action="store_true", action="store_true",
...@@ -201,11 +207,12 @@ def parse_args() -> Config: ...@@ -201,11 +207,12 @@ def parse_args() -> Config:
int(bool(args.multimodal_processor)) int(bool(args.multimodal_processor))
+ int(bool(args.multimodal_encode_worker)) + int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker)) + int(bool(args.multimodal_worker))
+ int(bool(args.multimodal_decode_worker))
+ int(bool(args.multimodal_encode_prefill_worker)) + int(bool(args.multimodal_encode_prefill_worker))
) )
if mm_flags > 1: if mm_flags > 1:
raise ValueError( raise ValueError(
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, or --multimodal-encode-prefill-worker" "Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, --multimodal-decode-worker, or --multimodal-encode-prefill-worker"
) )
# Set component and endpoint based on worker type # Set component and endpoint based on worker type
...@@ -218,8 +225,14 @@ def parse_args() -> Config: ...@@ -218,8 +225,14 @@ def parse_args() -> Config:
elif args.multimodal_encode_prefill_worker: elif args.multimodal_encode_prefill_worker:
config.component = "encoder" config.component = "encoder"
config.endpoint = "generate" config.endpoint = "generate"
elif args.multimodal_decode_worker:
# Uses "decoder" component name because prefill worker connects to "decoder"
# (prefill uses "backend" to receive from encoder)
config.component = "decoder"
config.endpoint = "generate"
elif args.multimodal_worker and args.is_prefill_worker: elif args.multimodal_worker and args.is_prefill_worker:
config.component = "prefill" # Multimodal prefill worker stays as "backend" to maintain encoder connection
config.component = "backend"
config.endpoint = "generate" config.endpoint = "generate"
elif args.is_prefill_worker: elif args.is_prefill_worker:
config.component = "prefill" config.component = "prefill"
...@@ -238,6 +251,7 @@ def parse_args() -> Config: ...@@ -238,6 +251,7 @@ def parse_args() -> Config:
config.multimodal_processor = args.multimodal_processor config.multimodal_processor = args.multimodal_processor
config.multimodal_encode_worker = args.multimodal_encode_worker config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker config.multimodal_worker = args.multimodal_worker
config.multimodal_decode_worker = args.multimodal_decode_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv config.store_kv = args.store_kv
......
...@@ -29,6 +29,7 @@ from dynamo.runtime import DistributedRuntime ...@@ -29,6 +29,7 @@ from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import ( from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler, EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler, MultimodalPDWorkerHandler,
ProcessorHandler, ProcessorHandler,
) )
...@@ -105,7 +106,11 @@ async def worker(): ...@@ -105,7 +106,11 @@ async def worker():
elif config.multimodal_encode_worker: elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config) await init_multimodal_encode_worker(runtime, config)
logger.debug("init_multimodal_encode_worker completed") logger.debug("init_multimodal_encode_worker completed")
elif config.multimodal_worker or config.multimodal_encode_prefill_worker: elif (
config.multimodal_worker
or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker
):
await init_multimodal_worker(runtime, config) await init_multimodal_worker(runtime, config)
logger.debug("init_multimodal_worker completed") logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker: elif config.is_prefill_worker:
...@@ -129,7 +134,6 @@ def setup_kv_event_publisher( ...@@ -129,7 +134,6 @@ def setup_kv_event_publisher(
""" """
Set up KV event publishers for prefix caching if enabled. Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port. Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args: Args:
config: Worker configuration config: Worker configuration
component: Component for runtime integration component: Component for runtime integration
...@@ -632,13 +636,28 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -632,13 +636,28 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config) engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
# For aggregated mode, no downstream client is needed # Set up decode worker client for disaggregated mode
# TODO: Implement disaggregated mode with proper decode worker client decode_worker_client = None
downstream_client = None if config.is_prefill_worker:
# Prefill worker needs to connect to decode worker
decode_worker_client = (
await runtime.namespace(config.namespace)
.component("decoder")
.endpoint("generate")
.client()
)
await decode_worker_client.wait_for_instances()
logger.info("Connected to decode worker for disaggregated mode")
handler = MultimodalPDWorkerHandler( # Choose handler based on worker type
runtime, component, engine_client, config, downstream_client if config.multimodal_decode_worker:
) handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config
)
else:
handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, decode_worker_client
)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -22,10 +22,20 @@ while [[ $# -gt 0 ]]; do ...@@ -22,10 +22,20 @@ while [[ $# -gt 0 ]]; do
;; ;;
-h|--help) -h|--help)
echo "Usage: $0 [OPTIONS]" echo "Usage: $0 [OPTIONS]"
echo ""
echo "Disaggregated multimodal serving with separate Encode/Prefill/Decode workers"
echo ""
echo "Options:" echo "Options:"
echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)" echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " --prompt-template <template> Specify the multi-modal prompt template to use. LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates." echo " --prompt-template <template> Specify the multi-modal prompt template to use"
echo " -h, --help Show this help message" echo " LLaVA 1.5 7B, Qwen2.5-VL, and Phi3V models have predefined templates"
echo " -h, --help Show this help message"
echo ""
echo "Examples:"
echo " $0 --model llava-hf/llava-1.5-7b-hf"
echo " $0 --model microsoft/Phi-3.5-vision-instruct"
echo " $0 --model Qwen/Qwen2.5-VL-7B-Instruct"
echo ""
exit 0 exit 0
;; ;;
*) *)
...@@ -52,17 +62,44 @@ else ...@@ -52,17 +62,44 @@ else
exit 1 exit 1
fi fi
# run ingress echo "=================================================="
echo "Disaggregated Multimodal Serving"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "Prompt Template: $PROMPT_TEMPLATE"
echo "=================================================="
# Start frontend (no router mode)
echo "Starting frontend..."
python -m dynamo.frontend --http-port=8000 & python -m dynamo.frontend --http-port=8000 &
# Start processor
echo "Starting processor..."
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "$PROMPT_TEMPLATE" &
# Configure GPU memory optimization for specific models
EXTRA_ARGS=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048"
fi
# run processor # Start encode worker
python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_TEMPLATE" & echo "Starting encode worker on GPU 1..."
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --model $MODEL_NAME $EXTRA_ARGS &
# run E/P/D workers # Start prefill worker
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python3 components/encode_worker.py --model $MODEL_NAME & echo "Starting prefill worker on GPU 2..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill --enable-disagg & CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --model $MODEL_NAME $EXTRA_ARGS &
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 CUDA_VISIBLE_DEVICES=2 python3 components/worker.py --model $MODEL_NAME --worker-type decode --enable-disagg &
# Start decode worker
echo "Starting decode worker on GPU 3..."
CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --model $MODEL_NAME $EXTRA_ARGS &
echo "=================================================="
echo "All components started. Waiting for initialization..."
echo "=================================================="
# Wait for all background processes to complete # Wait for all background processes to complete
wait wait
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment