Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
4d0380d5
Unverified
Commit
4d0380d5
authored
Feb 18, 2026
by
Qi Wang
Committed by
GitHub
Feb 18, 2026
Browse files
refactor: introduce worker factory in vLLM multimodal (#6060)
parent
638d8e68
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
355 additions
and
264 deletions
+355
-264
components/src/dynamo/vllm/args.py
components/src/dynamo/vllm/args.py
+1
-4
components/src/dynamo/vllm/backend_args.py
components/src/dynamo/vllm/backend_args.py
+1
-10
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+10
-174
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
+122
-0
components/src/dynamo/vllm/worker_factory.py
components/src/dynamo/vllm/worker_factory.py
+221
-0
docs/pages/features/multimodal/multimodal-vllm.md
docs/pages/features/multimodal/multimodal-vllm.md
+0
-1
examples/backends/vllm/launch/agg_ec_connector.sh
examples/backends/vllm/launch/agg_ec_connector.sh
+0
-75
No files found.
components/src/dynamo/vllm/args.py
View file @
4d0380d5
...
...
@@ -160,10 +160,7 @@ def update_dynamo_config_with_engine(
if
dynamo_config
.
route_to_encoder
:
dynamo_config
.
component
=
"processor"
dynamo_config
.
endpoint
=
"generate"
elif
(
dynamo_config
.
multimodal_encode_worker
or
dynamo_config
.
multimodal_encode_prefill_worker
):
elif
dynamo_config
.
multimodal_encode_worker
:
dynamo_config
.
component
=
"encoder"
dynamo_config
.
endpoint
=
"generate"
elif
dynamo_config
.
multimodal_decode_worker
:
...
...
components/src/dynamo/vllm/backend_args.py
View file @
4d0380d5
...
...
@@ -88,13 +88,6 @@ class DynamoVllmArgGroup(ArgGroup):
default
=
False
,
help
=
"Run as multimodal decode worker in disaggregated mode."
,
)
add_negatable_bool_argument
(
g
,
flag_name
=
"--multimodal-encode-prefill-worker"
,
env_var
=
"DYN_VLLM_MULTIMODAL_ENCODE_PREFILL_WORKER"
,
default
=
False
,
help
=
"Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)."
,
)
add_negatable_bool_argument
(
g
,
flag_name
=
"--enable-multimodal"
,
...
...
@@ -170,7 +163,6 @@ class DynamoVllmConfig(ConfigBase):
multimodal_encode_worker
:
bool
multimodal_worker
:
bool
multimodal_decode_worker
:
bool
multimodal_encode_prefill_worker
:
bool
enable_multimodal
:
bool
mm_prompt_template
:
str
frontend_decoding
:
bool
...
...
@@ -206,7 +198,6 @@ class DynamoVllmConfig(ConfigBase):
bool
(
self
.
multimodal_encode_worker
),
bool
(
self
.
multimodal_worker
),
bool
(
self
.
multimodal_decode_worker
),
bool
(
self
.
multimodal_encode_prefill_worker
),
]
)
...
...
@@ -215,7 +206,7 @@ class DynamoVllmConfig(ConfigBase):
if
self
.
_count_multimodal_roles
()
>
1
:
raise
ValueError
(
"Use only one of --multimodal-encode-worker, --multimodal-worker, "
"--multimodal-decode-worker
, --multimodal-encode-prefill-worker
"
"--multimodal-decode-worker"
)
def
_validate_multimodal_requires_flag
(
self
)
->
None
:
...
...
components/src/dynamo/vllm/main.py
View file @
4d0380d5
...
...
@@ -45,11 +45,7 @@ except ImportError:
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.vllm.multimodal_handlers
import
(
EncodeWorkerHandler
,
MultimodalDecodeWorkerHandler
,
MultimodalPDWorkerHandler
,
)
from
dynamo.vllm.worker_factory
import
WorkerFactory
from
.args
import
Config
,
parse_args
from
.chrek
import
get_checkpoint_config
...
...
@@ -148,18 +144,17 @@ async def worker():
)
# Route to appropriate initialization based on config flags
if
config
.
multimodal_encode_worker
:
await
init_multimodal_encode_worker
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_multimodal_encode_worker completed"
)
elif
(
config
.
multimodal_worker
or
config
.
multimodal_decode_worker
or
config
.
multimodal_encode_prefill_worker
):
await
init_multimodal_worker
(
if
WorkerFactory
.
handles
(
config
):
# Create worker factory with setup functions
factory
=
WorkerFactory
(
setup_vllm_engine_fn
=
setup_vllm_engine
,
setup_kv_event_publisher_fn
=
setup_kv_event_publisher
,
register_vllm_model_fn
=
register_vllm_model
,
)
await
factory
.
create
(
runtime
,
config
,
shutdown_event
,
pre_created_engine
=
pre_created_engine
)
logger
.
debug
(
"
init_
multimodal
_
worker completed"
)
logger
.
debug
(
"multimodal
worker completed"
)
elif
config
.
omni
:
await
init_omni
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_omni completed"
)
...
...
@@ -924,165 +919,6 @@ def get_engine_cache_info(engine: AsyncLLM):
raise
async
def
init_multimodal_encode_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""Initialize multimodal encode worker component"""
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
generate_endpoint
=
component
.
endpoint
(
config
.
endpoint
)
handler
=
EncodeWorkerHandler
(
config
.
engine_args
,
)
await
handler
.
async_init
(
runtime
)
logger
.
info
(
"Starting to serve the encode worker endpoint..."
)
try
:
await
asyncio
.
gather
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
metrics_labels
=
[
(
prometheus_names
.
labels
.
MODEL
,
config
.
model
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
model
),
],
),
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to serve encode worker endpoint:
{
e
}
"
)
raise
finally
:
handler
.
cleanup
()
async
def
init_multimodal_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
pre_created_engine
=
None
,
):
"""
Initialize multimodal worker component.
Supports three modes:
1. --multimodal-worker: Prefill+decode worker for multimodal LLM; can route
to a separate encoder (--route-to-encoder) for embeddings. Runs
aggregated (P+D) or disaggregated (P→D).
2. --multimodal-decode-worker: Decode-only worker in disaggregated (P→D)
mode.
3. --multimodal-encode-prefill-worker: Unified encode+prefill+decode in one
worker for models with integrated image encoding (e.g., Llama 4).
"""
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
generate_endpoint
=
component
.
endpoint
(
config
.
endpoint
)
clear_endpoint
=
component
.
endpoint
(
"clear_kv_blocks"
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if
pre_created_engine
is
not
None
:
(
engine_client
,
vllm_config
,
default_sampling_params
,
prometheus_temp_dir
,
_component_gauges
,
)
=
pre_created_engine
else
:
(
engine_client
,
vllm_config
,
default_sampling_params
,
prometheus_temp_dir
,
_component_gauges
,
)
=
setup_vllm_engine
(
config
)
# Set up encode worker client when routing to encoder is enabled
# (PD worker handles encode routing directly instead of a separate processor)
encode_worker_client
=
None
if
config
.
route_to_encoder
:
encode_worker_client
=
(
await
runtime
.
namespace
(
config
.
namespace
)
.
component
(
"encoder"
)
.
endpoint
(
"generate"
)
.
client
()
)
logger
.
info
(
"Waiting for Encoder Worker Instances ..."
)
await
encode_worker_client
.
wait_for_instances
()
logger
.
info
(
"Connected to encoder workers"
)
# Set up decode worker client for disaggregated mode
decode_worker_client
=
None
if
config
.
is_prefill_worker
:
# Prefill worker needs to connect to decode worker
decode_worker_client
=
(
await
runtime
.
namespace
(
config
.
namespace
)
.
component
(
"decoder"
)
.
endpoint
(
"generate"
)
.
client
()
)
await
decode_worker_client
.
wait_for_instances
()
logger
.
info
(
"Connected to decode worker for disaggregated mode"
)
# Choose handler based on worker type
if
config
.
multimodal_decode_worker
:
handler
=
MultimodalDecodeWorkerHandler
(
runtime
,
component
,
engine_client
,
config
,
shutdown_event
)
else
:
handler
=
MultimodalPDWorkerHandler
(
runtime
,
component
,
engine_client
,
config
,
encode_worker_client
,
decode_worker_client
,
shutdown_event
,
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
await
handler
.
async_init
(
runtime
)
# Set up KV event publisher for prefix caching if enabled
kv_publisher
=
setup_kv_event_publisher
(
config
,
component
,
generate_endpoint
,
vllm_config
)
if
kv_publisher
:
handler
.
kv_publisher
=
kv_publisher
# Register model with the frontend so it can route requests
model_type
=
parse_endpoint_types
(
config
.
endpoint_types
)
model_input
=
ModelInput
.
Text
if
config
.
use_vllm_tokenizer
else
ModelInput
.
Tokens
await
register_vllm_model
(
model_input
,
model_type
,
generate_endpoint
,
config
,
engine_client
,
vllm_config
,
)
metrics_labels
=
[
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
),
]
try
:
await
asyncio
.
gather
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
metrics_labels
=
metrics_labels
,
),
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
metrics_labels
=
metrics_labels
,
),
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
raise
finally
:
handler
.
cleanup
()
async
def
init_omni
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
...
...
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
0 → 100644
View file @
4d0380d5
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for worker_factory.py"""
import
asyncio
from
unittest.mock
import
AsyncMock
,
Mock
import
pytest
from
dynamo.vllm.worker_factory
import
EngineSetupResult
,
WorkerFactory
def
_make_config
(
**
overrides
)
->
Mock
:
"""Create a mock Config with all multimodal flags defaulting to False."""
defaults
=
{
"multimodal_encode_worker"
:
False
,
"multimodal_worker"
:
False
,
"multimodal_decode_worker"
:
False
,
"omni"
:
False
,
"is_prefill_worker"
:
False
,
}
defaults
.
update
(
overrides
)
return
Mock
(
**
defaults
)
class TestHandles:
    """Checks which configs WorkerFactory.handles() claims as its own."""

    def test_multimodal_encode_worker(self) -> None:
        # The encode-only role is claimed by the factory.
        assert WorkerFactory.handles(_make_config(multimodal_encode_worker=True))

    def test_multimodal_worker(self) -> None:
        # The multimodal worker role is claimed by the factory.
        assert WorkerFactory.handles(_make_config(multimodal_worker=True))

    def test_multimodal_decode_worker(self) -> None:
        # The decode-only role is claimed by the factory.
        assert WorkerFactory.handles(_make_config(multimodal_decode_worker=True))

    def test_no_multimodal_flags(self) -> None:
        # With every flag off there is nothing for the factory to build.
        assert not WorkerFactory.handles(_make_config())

    def test_omni_not_handled(self) -> None:
        # The omni flag alone does not make handles() return True.
        assert not WorkerFactory.handles(_make_config(omni=True))

    def test_prefill_only_not_handled(self) -> None:
        # is_prefill_worker alone is not one of the flags handles() checks.
        assert not WorkerFactory.handles(_make_config(is_prefill_worker=True))
class TestCreate:
    """Test WorkerFactory.create() routing."""

    @pytest.fixture
    def factory(self) -> WorkerFactory:
        """Return a factory whose two private worker builders are AsyncMocks."""
        factory = WorkerFactory(
            setup_vllm_engine_fn=Mock(),
            setup_kv_event_publisher_fn=Mock(),
            register_vllm_model_fn=AsyncMock(),
        )
        # Stub the builders so create() only exercises its dispatch logic.
        factory._create_multimodal_encode_worker = AsyncMock()  # type: ignore[assignment]
        factory._create_multimodal_worker = AsyncMock()  # type: ignore[assignment]
        return factory

    @pytest.mark.asyncio
    async def test_routes_to_multimodal_encode(self, factory: WorkerFactory) -> None:
        config = _make_config(multimodal_encode_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_encode_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_routes_to_multimodal_worker(self, factory: WorkerFactory) -> None:
        config = _make_config(multimodal_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_routes_multimodal_decode_worker(
        self, factory: WorkerFactory
    ) -> None:
        config = _make_config(multimodal_decode_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_passes_pre_created_engine(self, factory: WorkerFactory) -> None:
        config = _make_config(multimodal_worker=True)
        runtime = Mock()
        shutdown_event = asyncio.Event()
        # Five elements: the same arity _create_multimodal_worker unpacks from
        # a pre-created engine, so the fixture has the real tuple shape.
        pre_created_engine: EngineSetupResult = (
            Mock(),  # engine_client
            Mock(),  # vllm_config
            Mock(),  # default_sampling_params
            "/tmp/prometheus",  # prometheus_temp_dir
            Mock(),  # component gauges
        )

        await factory.create(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )

        factory._create_multimodal_worker.assert_called_once_with(  # type: ignore[union-attr]
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )

    @pytest.mark.asyncio
    async def test_raises_when_no_multimodal_flag(
        self, factory: WorkerFactory
    ) -> None:
        # No role flag set: create() must refuse rather than silently no-op.
        config = _make_config()
        with pytest.raises(ValueError, match="no multimodal worker type set"):
            await factory.create(Mock(), config, asyncio.Event())
components/src/dynamo/vllm/worker_factory.py
0 → 100644
View file @
4d0380d5
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Worker initialization factory for vLLM workers."""
import
asyncio
import
logging
from
collections.abc
import
Awaitable
,
Callable
from
typing
import
Any
,
Optional
from
dynamo.common.utils.endpoint_types
import
parse_endpoint_types
from
dynamo.llm
import
ModelInput
from
dynamo.runtime
import
DistributedRuntime
from
.args
import
Config
from
.multimodal_handlers
import
(
EncodeWorkerHandler
,
MultimodalDecodeWorkerHandler
,
MultimodalPDWorkerHandler
,
)
logger = logging.getLogger(__name__)

# Shape of the value produced by the engine-setup callable, and of a
# pre-created engine handed to the factory:
#   (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir,
#    component_gauges)
# It has five elements because WorkerFactory._create_multimodal_worker unpacks
# exactly five values from it; a shorter tuple raises ValueError there.
EngineSetupResult = tuple[Any, Any, Any, Any, Any]

# Signatures of the callables injected into WorkerFactory. Parameters are left
# open (`...`) because only the return shapes are constrained here.
SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]]
RegisterVllmModelFn = Callable[..., Awaitable[None]]
class WorkerFactory:
    """Builds and serves the multimodal vLLM worker selected by the config.

    The engine-setup, KV-event-publisher and model-registration callables are
    injected through the constructor and stored on the instance; the worker
    builders invoke them from there.
    """

    def __init__(
        self,
        setup_vllm_engine_fn: SetupVllmEngineFn,
        setup_kv_event_publisher_fn: SetupKvEventPublisherFn,
        register_vllm_model_fn: RegisterVllmModelFn,
    ):
        """Store the three injected setup callables on the instance."""
        self.setup_vllm_engine = setup_vllm_engine_fn
        self.setup_kv_event_publisher = setup_kv_event_publisher_fn
        self.register_vllm_model = register_vllm_model_fn

    @staticmethod
    def handles(config: Config) -> bool:
        """Tell whether any of the three multimodal role flags is set."""
        if config.multimodal_encode_worker:
            return True
        if config.multimodal_worker:
            return True
        return bool(config.multimodal_decode_worker)

    async def create(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """Dispatch to the worker builder that matches the config flags.

        The encode-worker flag is checked first; the multimodal-worker and
        decode-worker flags share one builder, which also receives
        ``pre_created_engine``.

        Raises:
            ValueError: if none of the multimodal role flags is set.
        """
        if config.multimodal_encode_worker:
            await self._create_multimodal_encode_worker(
                runtime, config, shutdown_event
            )
            return

        if config.multimodal_worker or config.multimodal_decode_worker:
            await self._create_multimodal_worker(
                runtime, config, shutdown_event, pre_created_engine=pre_created_engine
            )
            return

        raise ValueError(
            "WorkerFactory.create() called but no multimodal worker type set in config"
        )

    @staticmethod
    async def _generate_client(
        runtime: DistributedRuntime, config: Config, component_name: str
    ) -> Any:
        """Open a client for the "generate" endpoint of ``component_name``."""
        endpoint = (
            runtime.namespace(config.namespace)
            .component(component_name)
            .endpoint("generate")
        )
        return await endpoint.client()

    async def _create_multimodal_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """
        Set up and serve a multimodal worker component.

        Worker flavours:
        - --multimodal-worker: served through MultimodalPDWorkerHandler, which
          is given the optional encoder and decoder clients
        - --multimodal-decode-worker: served through
          MultimodalDecodeWorkerHandler

        Optional peers:
        - config.route_to_encoder: connect to the "encoder" component first
        - config.is_prefill_worker: connect to the "decoder" component first
        """
        component = runtime.namespace(config.namespace).component(config.component)
        generate_endpoint = component.endpoint(config.endpoint)
        clear_endpoint = component.endpoint("clear_kv_blocks")

        # Reuse the caller's engine when one was supplied, otherwise build one.
        if pre_created_engine is None:
            engine_bundle = self.setup_vllm_engine(config)
        else:
            engine_bundle = pre_created_engine
        # Third and fifth elements are not used by this builder.
        engine_client, vllm_config, _, prometheus_temp_dir, _ = engine_bundle

        # Encoder client is only needed when requests are routed to an encoder.
        encode_worker_client = None
        if config.route_to_encoder:
            encode_worker_client = await self._generate_client(
                runtime, config, "encoder"
            )
            logger.info("Waiting for Encoder Worker Instances ...")
            await encode_worker_client.wait_for_instances()
            logger.info("Connected to encoder workers")

        # A prefill worker needs a client for the decode worker.
        decode_worker_client = None
        if config.is_prefill_worker:
            decode_worker_client = await self._generate_client(
                runtime, config, "decoder"
            )
            await decode_worker_client.wait_for_instances()
            logger.info("Connected to decode worker for disaggregated mode")

        # Pick the handler class from the worker type.
        if config.multimodal_decode_worker:
            handler = MultimodalDecodeWorkerHandler(
                runtime, component, engine_client, config, shutdown_event
            )
        else:
            handler = MultimodalPDWorkerHandler(
                runtime,
                component,
                engine_client,
                config,
                encode_worker_client,
                decode_worker_client,
                shutdown_event,
            )
        handler.add_temp_dir(prometheus_temp_dir)

        await handler.async_init(runtime)

        # Attach a KV event publisher only when the setup callable returns one.
        if publisher := self.setup_kv_event_publisher(
            config, component, generate_endpoint, vllm_config
        ):
            handler.kv_publisher = publisher

        # Register the model before serving.
        model_type = parse_endpoint_types(config.endpoint_types)
        if config.use_vllm_tokenizer:
            model_input = ModelInput.Text
        else:
            model_input = ModelInput.Tokens

        await self.register_vllm_model(
            model_input,
            model_type,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
        )

        served_name = config.served_model_name or config.model
        metrics_labels = [("model", served_name)]

        try:
            await asyncio.gather(
                generate_endpoint.serve_endpoint(
                    handler.generate,
                    metrics_labels=metrics_labels,
                ),
                clear_endpoint.serve_endpoint(
                    handler.clear_kv_blocks,
                    metrics_labels=metrics_labels,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to serve endpoints: {e}")
            raise
        finally:
            handler.cleanup()

    async def _create_multimodal_encode_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
    ) -> None:
        """Set up and serve the standalone multimodal encode worker.

        ``shutdown_event`` is accepted for signature parity with the other
        builder but is not used here.
        """
        namespace = runtime.namespace(config.namespace)
        generate_endpoint = namespace.component(config.component).endpoint(
            config.endpoint
        )

        handler = EncodeWorkerHandler(config.engine_args)
        await handler.async_init(runtime)

        logger.info("Starting to serve the encode worker endpoint...")

        encode_labels = [("model", config.model)]
        try:
            await asyncio.gather(
                generate_endpoint.serve_endpoint(
                    handler.generate, metrics_labels=encode_labels
                ),
            )
        except Exception as e:
            logger.error(f"Failed to serve encode worker endpoint: {e}")
            raise
        finally:
            handler.cleanup()
docs/pages/features/multimodal/multimodal-vllm.md
View file @
4d0380d5
...
...
@@ -49,7 +49,6 @@ vLLM supports all multimodal deployment patterns. See [Architecture Patterns](RE
| PD Worker |
`--multimodal-worker`
| Prefill + Decode |
| Prefill Worker |
`--multimodal-worker --is-prefill-worker`
| Prefill only |
| Decode Worker |
`--multimodal-decode-worker`
| Decode only |
| Encode+Prefill Worker |
`--multimodal-encode-prefill-worker --is-prefill-worker`
| Combined (Llama 4) |
## Use the Latest Release
...
...
examples/backends/vllm/launch/agg_ec_connector.sh
deleted
100755 → 0
View file @
638d8e68
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set
-e
trap
'echo Cleaning up...; kill 0'
EXIT
# Default values
MODEL_NAME
=
"llava-hf/llava-1.5-7b-hf"
EC_CONNECTOR_BACKEND
=
"DynamoEcConnector"
# Parse command line arguments
while
[[
$#
-gt
0
]]
;
do
case
$1
in
--model
)
MODEL_NAME
=
$2
shift
2
;;
-h
|
--help
)
echo
"Usage:
$0
[OPTIONS]"
echo
""
echo
"Aggregated multimodal serving with ECConnector (ec_both mode)"
echo
""
echo
"This script launches:"
echo
" - Frontend server"
echo
" - Aggregated multimodal worker (ec_both: produces and consumes encoder cache)"
echo
""
echo
"Options:"
echo
" --model <model_name> Specify the VLM model to use (default:
$MODEL_NAME
)"
echo
" -h, --help Show this help message"
echo
""
echo
"Examples:"
echo
"
$0
"
echo
"
$0
--model llava-hf/llava-1.5-7b-hf"
echo
""
exit
0
;;
*
)
echo
"Unknown option:
$1
"
echo
"Use --help for usage information"
exit
1
;;
esac
done
echo
"=================================================="
echo
"Aggregated Multimodal Serving (ECConnector ec_both)"
echo
"=================================================="
echo
"Model:
$MODEL_NAME
"
echo
"ECConnector Backend:
$EC_CONNECTOR_BACKEND
"
echo
"=================================================="
# GPU assignment (override via environment variable)
DYN_WORKER_GPU
=
${
DYN_WORKER_GPU
:-
0
}
# GPU memory utilization
DYN_GPU_MEM
=
${
DYN_GPU_MEM
:-
0
.85
}
# Start frontend
echo
"Starting frontend..."
python
-m
dynamo.frontend &
# Start aggregated multimodal worker (ec_both: produces and consumes encoder cache)
echo
"Starting aggregated multimodal worker (ec_both) on GPU
$DYN_WORKER_GPU
(mem:
$DYN_GPU_MEM
)..."
CUDA_VISIBLE_DEVICES
=
$DYN_WORKER_GPU
python
-m
dynamo.vllm
\
--multimodal-worker
\
--enable-multimodal
\
--model
$MODEL_NAME
\
--enable-mm-embeds
\
--connector
none
\
--enforce-eager
\
--gpu-memory-utilization
$DYN_GPU_MEM
\
--ec-transfer-config
"{
\"
ec_connector
\"
:
\"
$EC_CONNECTOR_BACKEND
\"
,
\"
ec_role
\"
:
\"
ec_both
\"
}"
&
# Wait for all background processes to complete
wait
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment