Unverified commit 22d910a5, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: support for agg llama4 multimodal (#3984)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent f2a3c638
......@@ -11,11 +11,10 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
python -m dynamo.frontend --http-port=8000 &
# run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also
# handle image encoding.
# run EP/D workers
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --tensor-parallel-size=8 --max-model-len=208960 &
python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
# Llama 4 doesn't support image embedding input, so use encode+prefill worker
# that handles image encoding inline
python -m dynamo.vllm --multimodal-encode-prefill-worker --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
# Wait for all background processes to complete
wait
......@@ -69,6 +69,7 @@ class Config:
multimodal_processor: bool = False
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
# dump config to file
dump_config_to: Optional[str] = None
......@@ -169,6 +170,11 @@ def parse_args() -> Config:
action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data",
)
parser.add_argument(
"--multimodal-encode-prefill-worker",
action="store_true",
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)",
)
parser.add_argument(
"--mm-prompt-template",
type=str,
......@@ -212,10 +218,11 @@ def parse_args() -> Config:
int(bool(args.multimodal_processor))
+ int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker))
+ int(bool(args.multimodal_encode_prefill_worker))
)
if mm_flags > 1:
raise ValueError(
"Use only one of --multimodal-processor, --multimodal-encode-worker, or --multimodal-worker"
"Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, or --multimodal-encode-prefill-worker"
)
# Set component and endpoint based on worker type
......@@ -225,6 +232,9 @@ def parse_args() -> Config:
elif args.multimodal_encode_worker:
config.component = "encoder"
config.endpoint = "generate"
elif args.multimodal_encode_prefill_worker:
config.component = "encoder"
config.endpoint = "generate"
elif args.multimodal_worker and args.is_prefill_worker:
config.component = "prefill"
config.endpoint = "generate"
......@@ -248,6 +258,7 @@ def parse_args() -> Config:
config.multimodal_processor = args.multimodal_processor
config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template
# Validate custom Jinja template file exists if provided
......
......@@ -106,7 +106,7 @@ async def worker(runtime: DistributedRuntime):
elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
logger.debug("init_multimodal_encode_worker completed")
elif config.multimodal_worker:
elif config.multimodal_worker or config.multimodal_encode_prefill_worker:
await init_multimodal_worker(runtime, config)
logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker:
......@@ -605,8 +605,15 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal worker component for aggregated or disaggregated mode"""
"""
Initialize multimodal worker component.
Supports two modes:
1. --multimodal-worker: Receives embeddings from separate encoder
2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)
Both can operate in aggregated (P+D) or disaggregated (P→D) mode.
"""
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()
......@@ -615,16 +622,12 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
# TODO: Support Disaggregated mode separately
client = (
await runtime.namespace(config.namespace)
.component("backend")
.endpoint("generate")
.client()
)
# For aggregated mode, no downstream client is needed
# TODO: Implement disaggregated mode with proper decode worker client
downstream_client = None
handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, client
runtime, component, engine_client, config, downstream_client
)
await handler.async_init(runtime)
......@@ -637,14 +640,15 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
handler.kv_publisher = kv_publisher
metrics_labels = [("model", config.model)]
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=metrics_labels
handler.generate,
metrics_labels=metrics_labels,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=metrics_labels
handler.clear_kv_blocks,
metrics_labels=metrics_labels,
),
)
except Exception as e:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment