Unverified Commit 22d910a5 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: support for agg llama4 mulimodal (#3984)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent f2a3c638
...@@ -11,11 +11,10 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" ...@@ -11,11 +11,10 @@ MODEL_NAME="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
python -m dynamo.frontend --http-port=8000 & python -m dynamo.frontend --http-port=8000 &
# run processor # run processor
python3 components/processor.py --model $MODEL_NAME --prompt-template "<|image|>\n<prompt>" & python -m dynamo.vllm --multimodal-processor --model $MODEL_NAME --mm-prompt-template "<|image|>\n<prompt>" &
# LLama 4 doesn't support image embedding input, so the prefill worker will also # Llama 4 doesn't support image embedding input, so use encode+prefill worker
# handle image encoding. # that handles image encoding inline
# run EP/D workers python -m dynamo.vllm --multimodal-encode-prefill-worker --model $MODEL_NAME --tensor-parallel-size=8 --max-model-len=208960 --gpu-memory-utilization 0.80 &
python3 components/worker.py --model $MODEL_NAME --worker-type encode_prefill --tensor-parallel-size=8 --max-model-len=208960 &
# Wait for all background processes to complete # Wait for all background processes to complete
wait wait
...@@ -69,6 +69,7 @@ class Config: ...@@ -69,6 +69,7 @@ class Config:
multimodal_processor: bool = False multimodal_processor: bool = False
multimodal_encode_worker: bool = False multimodal_encode_worker: bool = False
multimodal_worker: bool = False multimodal_worker: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:" mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
# dump config to file # dump config to file
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
...@@ -169,6 +170,11 @@ def parse_args() -> Config: ...@@ -169,6 +170,11 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data", help="Run as multimodal worker component for LLM inference with multimodal data",
) )
parser.add_argument(
"--multimodal-encode-prefill-worker",
action="store_true",
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)",
)
parser.add_argument( parser.add_argument(
"--mm-prompt-template", "--mm-prompt-template",
type=str, type=str,
...@@ -212,10 +218,11 @@ def parse_args() -> Config: ...@@ -212,10 +218,11 @@ def parse_args() -> Config:
int(bool(args.multimodal_processor)) int(bool(args.multimodal_processor))
+ int(bool(args.multimodal_encode_worker)) + int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker)) + int(bool(args.multimodal_worker))
+ int(bool(args.multimodal_encode_prefill_worker))
) )
if mm_flags > 1: if mm_flags > 1:
raise ValueError( raise ValueError(
"Use only one of --multimodal-processor, --multimodal-encode-worker, or --multimodal-worker" "Use only one of --multimodal-processor, --multimodal-encode-worker, --multimodal-worker, or --multimodal-encode-prefill-worker"
) )
# Set component and endpoint based on worker type # Set component and endpoint based on worker type
...@@ -225,6 +232,9 @@ def parse_args() -> Config: ...@@ -225,6 +232,9 @@ def parse_args() -> Config:
elif args.multimodal_encode_worker: elif args.multimodal_encode_worker:
config.component = "encoder" config.component = "encoder"
config.endpoint = "generate" config.endpoint = "generate"
elif args.multimodal_encode_prefill_worker:
config.component = "encoder"
config.endpoint = "generate"
elif args.multimodal_worker and args.is_prefill_worker: elif args.multimodal_worker and args.is_prefill_worker:
config.component = "prefill" config.component = "prefill"
config.endpoint = "generate" config.endpoint = "generate"
...@@ -248,6 +258,7 @@ def parse_args() -> Config: ...@@ -248,6 +258,7 @@ def parse_args() -> Config:
config.multimodal_processor = args.multimodal_processor config.multimodal_processor = args.multimodal_processor
config.multimodal_encode_worker = args.multimodal_encode_worker config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker config.multimodal_worker = args.multimodal_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
......
...@@ -106,7 +106,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -106,7 +106,7 @@ async def worker(runtime: DistributedRuntime):
elif config.multimodal_encode_worker: elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config) await init_multimodal_encode_worker(runtime, config)
logger.debug("init_multimodal_encode_worker completed") logger.debug("init_multimodal_encode_worker completed")
elif config.multimodal_worker: elif config.multimodal_worker or config.multimodal_encode_prefill_worker:
await init_multimodal_worker(runtime, config) await init_multimodal_worker(runtime, config)
logger.debug("init_multimodal_worker completed") logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker: elif config.is_prefill_worker:
...@@ -605,8 +605,15 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con ...@@ -605,8 +605,15 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal worker component for aggregated or disaggregated mode""" """
Initialize multimodal worker component.
Supports two modes:
1. --multimodal-worker: Receives embeddings from separate encoder
2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4)
Both can operate in aggregated (P+D) or disaggregated (P→D) mode.
"""
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
await component.create_service() await component.create_service()
...@@ -615,16 +622,12 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -615,16 +622,12 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config) engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
# TODO: Support Disaggregated mode separately # For aggregated mode, no downstream client is needed
client = ( # TODO: Implement disaggregated mode with proper decode worker client
await runtime.namespace(config.namespace) downstream_client = None
.component("backend")
.endpoint("generate")
.client()
)
handler = MultimodalPDWorkerHandler( handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, client runtime, component, engine_client, config, downstream_client
) )
await handler.async_init(runtime) await handler.async_init(runtime)
...@@ -637,14 +640,15 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): ...@@ -637,14 +640,15 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
handler.kv_publisher = kv_publisher handler.kv_publisher = kv_publisher
metrics_labels = [("model", config.model)] metrics_labels = [("model", config.model)]
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, metrics_labels=metrics_labels handler.generate,
metrics_labels=metrics_labels,
), ),
clear_endpoint.serve_endpoint( clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=metrics_labels handler.clear_kv_blocks,
metrics_labels=metrics_labels,
), ),
) )
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment