Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2831bfec
Unverified
Commit
2831bfec
authored
Mar 06, 2026
by
jh-nv
Committed by
GitHub
Mar 06, 2026
Browse files
chore: add mypy typing to vllm (#6858)
parent
0f01e724
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
86 additions
and
56 deletions
+86
-56
components/src/dynamo/vllm/engine_monitor.py
components/src/dynamo/vllm/engine_monitor.py
+1
-1
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+2
-2
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+47
-30
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
.../dynamo/vllm/multimodal_handlers/encode_worker_handler.py
+12
-8
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+2
-2
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
+2
-2
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
.../src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
+3
-0
components/src/dynamo/vllm/omni/base_handler.py
components/src/dynamo/vllm/omni/base_handler.py
+1
-1
components/src/dynamo/vllm/omni/omni_handler.py
components/src/dynamo/vllm/omni/omni_handler.py
+2
-0
components/src/dynamo/vllm/publisher.py
components/src/dynamo/vllm/publisher.py
+7
-7
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
...imodal_handlers/test_vllm_multimodal_pd_worker_handler.py
+1
-1
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
+1
-1
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+5
-1
No files found.
components/src/dynamo/vllm/engine_monitor.py
View file @
2831bfec
...
@@ -29,7 +29,7 @@ class VllmEngineMonitor:
...
@@ -29,7 +29,7 @@ class VllmEngineMonitor:
self
,
self
,
runtime
:
DistributedRuntime
,
runtime
:
DistributedRuntime
,
engine_client
:
AsyncLLM
,
engine_client
:
AsyncLLM
,
shutdown_event
:
asyncio
.
Event
=
None
,
shutdown_event
:
asyncio
.
Event
|
None
=
None
,
):
):
if
not
isinstance
(
runtime
,
DistributedRuntime
):
if
not
isinstance
(
runtime
,
DistributedRuntime
):
raise
ValueError
(
raise
ValueError
(
...
...
components/src/dynamo/vllm/handlers.py
View file @
2831bfec
...
@@ -261,7 +261,7 @@ def build_sampling_params_openai(
...
@@ -261,7 +261,7 @@ def build_sampling_params_openai(
return
sampling_params
return
sampling_params
def
get_dp_range_for_worker
(
vllm_config
:
VllmConfig
)
->
range
:
def
get_dp_range_for_worker
(
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
]
:
"""
"""
Get the global DP rank range that this worker is responsible for based on vLLM config.
Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
...
@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC):
...
@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC):
self
.
enable_multimodal
=
enable_multimodal
self
.
enable_multimodal
=
enable_multimodal
self
.
enable_frontend_decoding
=
enable_frontend_decoding
self
.
enable_frontend_decoding
=
enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
# NIXL connector for frontend decoding - lazy initialized
self
.
_nixl_connector
=
None
self
.
_nixl_connector
:
nixl_connect
.
Connector
|
None
=
None
self
.
_nixl_connector_lock
=
asyncio
.
Lock
()
self
.
_nixl_connector_lock
=
asyncio
.
Lock
()
# LoRA tracking: name -> LoRAInfo(id, path)
# LoRA tracking: name -> LoRAInfo(id, path)
self
.
loaded_loras
:
dict
[
str
,
LoRAInfo
]
=
{}
self
.
loaded_loras
:
dict
[
str
,
LoRAInfo
]
=
{}
...
...
components/src/dynamo/vllm/main.py
View file @
2831bfec
...
@@ -7,7 +7,7 @@ import logging
...
@@ -7,7 +7,7 @@ import logging
import
os
import
os
import
tempfile
import
tempfile
import
time
import
time
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
uvloop
import
uvloop
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
...
@@ -37,17 +37,6 @@ from dynamo.llm import (
...
@@ -37,17 +37,6 @@ from dynamo.llm import (
fetch_model
,
fetch_model
,
register_model
,
register_model
,
)
)
# Optional imports for frontend decoding support
try
:
from
dynamo.llm
import
MediaDecoder
,
MediaFetcher
MEDIA_DECODER_AVAILABLE
=
True
except
ImportError
:
MediaDecoder
=
None
MediaFetcher
=
None
MEDIA_DECODER_AVAILABLE
=
False
from
dynamo.runtime
import
DistributedRuntime
,
Endpoint
from
dynamo.runtime
import
DistributedRuntime
,
Endpoint
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.vllm.worker_factory
import
WorkerFactory
from
dynamo.vllm.worker_factory
import
WorkerFactory
...
@@ -63,6 +52,18 @@ from .health_check import (
...
@@ -63,6 +52,18 @@ from .health_check import (
)
)
from
.publisher
import
DYNAMO_COMPONENT_REGISTRY
,
StatLoggerFactory
from
.publisher
import
DYNAMO_COMPONENT_REGISTRY
,
StatLoggerFactory
# Optional imports for frontend decoding support
MediaDecoder
:
type
|
None
=
None
MediaFetcher
:
type
|
None
=
None
try
:
from
dynamo.llm
import
MediaDecoder
,
MediaFetcher
MEDIA_DECODER_AVAILABLE
=
True
except
ImportError
:
MediaDecoder
=
None
MediaFetcher
=
None
MEDIA_DECODER_AVAILABLE
=
False
configure_dynamo_logging
()
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
shutdown_endpoints
:
list
=
[]
shutdown_endpoints
:
list
=
[]
...
@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None:
...
@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None:
run_headless
(
args
)
run_headless
(
args
)
async
def
worker
():
async
def
worker
()
->
None
:
config
=
parse_args
()
config
=
parse_args
()
dump_config
(
config
.
dump_config_to
,
config
)
dump_config
(
config
.
dump_config_to
,
config
)
...
@@ -198,7 +199,9 @@ async def worker():
...
@@ -198,7 +199,9 @@ async def worker():
logger
.
debug
(
"Worker function completed, exiting..."
)
logger
.
debug
(
"Worker function completed, exiting..."
)
def
setup_metrics_collection
(
config
:
Config
,
generate_endpoint
,
logger
):
def
setup_metrics_collection
(
config
:
Config
,
generate_endpoint
:
Endpoint
,
logger
:
logging
.
Logger
)
->
None
:
"""Set up metrics collection for vLLM and LMCache metrics.
"""Set up metrics collection for vLLM and LMCache metrics.
In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
...
@@ -294,8 +297,9 @@ def setup_kv_event_publisher(
...
@@ -294,8 +297,9 @@ def setup_kv_event_publisher(
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
consolidator_enabled
:
bool
=
False
,
consolidator_enabled
:
bool
=
False
,
consolidator_port
:
Optional
[
int
]
=
5558
,
consolidator_port
:
Optional
[
int
]
=
5558
,
)
->
Optional
[
KvEventPublisher
]:
)
->
Optional
[
list
[
KvEventPublisher
]
]
:
"""
"""
list[KvEventPublisher] | None
Set up KV event publishers for prefix caching if enabled.
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
Args:
...
@@ -365,7 +369,9 @@ def setup_kv_event_publisher(
...
@@ -365,7 +369,9 @@ def setup_kv_event_publisher(
return
kv_publishers
if
kv_publishers
else
None
return
kv_publishers
if
kv_publishers
else
None
def
setup_vllm_engine
(
config
,
stat_logger
=
None
):
def
setup_vllm_engine
(
config
:
Config
,
stat_logger
:
Optional
[
StatLoggerFactory
]
=
None
)
->
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]:
# vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
# vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly.
# ourselves to avoid this and handle cleanup properly.
...
@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None):
...
@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None):
async
def
register_vllm_model
(
async
def
register_vllm_model
(
model_input
:
ModelInput
,
model_input
:
ModelInput
,
model_type
:
ModelType
,
model_type
:
ModelType
,
generate_endpoint
,
generate_endpoint
:
Endpoint
,
config
:
Config
,
config
:
Config
,
engine_client
:
AsyncLLM
,
engine_client
:
AsyncLLM
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
):
)
->
None
:
"""
"""
Helper function to register a vLLM model with runtime configuration.
Helper function to register a vLLM model with runtime configuration.
...
@@ -563,6 +569,7 @@ async def register_vllm_model(
...
@@ -563,6 +569,7 @@ async def register_vllm_model(
"--frontend-decoding requires MediaDecoder support. "
"--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
)
)
assert
MediaDecoder
is
not
None
and
MediaFetcher
is
not
None
media_decoder
=
MediaDecoder
()
media_decoder
=
MediaDecoder
()
media_decoder
.
enable_image
({
"limits"
:
{
"max_alloc"
:
128
*
1024
*
1024
}})
media_decoder
.
enable_image
({
"limits"
:
{
"max_alloc"
:
128
*
1024
*
1024
}})
# media_decoder.enable_video({})
# media_decoder.enable_video({})
...
@@ -590,8 +597,10 @@ async def init_prefill(
...
@@ -590,8 +597,10 @@ async def init_prefill(
runtime
:
DistributedRuntime
,
runtime
:
DistributedRuntime
,
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
checkpoint_restore_engine
=
None
,
checkpoint_restore_engine
:
Optional
[
):
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]
]
=
None
,
)
->
None
:
"""
"""
Instantiate and serve
Instantiate and serve
"""
"""
...
@@ -690,7 +699,7 @@ async def init_prefill(
...
@@ -690,7 +699,7 @@ async def init_prefill(
# (long-term reason): prefill engine should pull from a global queue so there is
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
# only a few in-flight requests that can be quickly finished
generate_endpoint
.
serve_endpoint
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
handler
.
generate
,
# type: ignore
graceful_shutdown
=
True
,
graceful_shutdown
=
True
,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels
=
[
metrics_labels
=
[
...
@@ -706,10 +715,16 @@ async def init_prefill(
...
@@ -706,10 +715,16 @@ async def init_prefill(
health_check_payload
=
health_check_payload
,
health_check_payload
=
health_check_payload
,
),
),
clear_endpoint
.
serve_endpoint
(
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
handler
.
clear_kv_blocks
,
# type: ignore
metrics_labels
=
[
metrics_labels
=
[
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
),
(
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
),
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
],
),
),
)
)
...
@@ -726,8 +741,10 @@ async def init(
...
@@ -726,8 +741,10 @@ async def init(
runtime
:
DistributedRuntime
,
runtime
:
DistributedRuntime
,
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
checkpoint_restore_engine
=
None
,
checkpoint_restore_engine
:
Optional
[
):
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]
]
=
None
,
)
->
None
:
"""
"""
Instantiate and serve
Instantiate and serve
"""
"""
...
@@ -886,7 +903,7 @@ async def init(
...
@@ -886,7 +903,7 @@ async def init(
# for decode, we want to transfer the in-flight requests to other decode engines,
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
# because waiting them to finish can take a long time for long OSLs
generate_endpoint
.
serve_endpoint
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
handler
.
generate
,
# type: ignore
graceful_shutdown
=
True
,
graceful_shutdown
=
True
,
metrics_labels
=
model_metrics_labels
,
metrics_labels
=
model_metrics_labels
,
health_check_payload
=
health_check_payload
,
health_check_payload
=
health_check_payload
,
...
@@ -926,7 +943,7 @@ async def init(
...
@@ -926,7 +943,7 @@ async def init(
handler
.
cleanup
()
handler
.
cleanup
()
def
get_engine_cache_info
(
engine
:
AsyncLLM
):
def
get_engine_cache_info
(
engine
:
AsyncLLM
)
->
dict
[
str
,
Any
]
:
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
try
:
try
:
...
@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM):
...
@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM):
async
def
init_omni
(
async
def
init_omni
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
)
->
None
:
"""Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
"""Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Supports text-to-text, text-to-image, and text-to-video generation
Supports text-to-text, text-to-image, and text-to-video generation
...
@@ -1034,7 +1051,7 @@ async def init_omni(
...
@@ -1034,7 +1051,7 @@ async def init_omni(
handler
.
cleanup
()
handler
.
cleanup
()
def
main
():
def
main
()
->
None
:
uvloop
.
run
(
worker
())
uvloop
.
run
(
worker
())
...
...
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
View file @
2831bfec
...
@@ -6,7 +6,7 @@ import logging
...
@@ -6,7 +6,7 @@ import logging
import
os
import
os
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
AsyncIterator
from
typing
import
Any
,
AsyncIterator
import
torch
import
torch
from
transformers
import
AutoImageProcessor
from
transformers
import
AutoImageProcessor
...
@@ -80,7 +80,7 @@ class EncodeWorkerHandler:
...
@@ -80,7 +80,7 @@ class EncodeWorkerHandler:
self
.
_connector
:
connect
.
Connector
|
None
=
None
self
.
_connector
:
connect
.
Connector
|
None
=
None
self
.
_accumulated_time
=
0.0
self
.
_accumulated_time
=
0.0
self
.
_processed_requests
=
0
self
.
_processed_requests
=
0
self
.
readables
=
[]
self
.
readables
:
list
[
Any
]
=
[]
self
.
embedding_cache
=
EmbeddingCache
()
if
ENABLE_ENCODER_CACHE
else
None
self
.
embedding_cache
=
EmbeddingCache
()
if
ENABLE_ENCODER_CACHE
else
None
if
embedding_transfer_mode
==
EmbeddingTransferMode
.
LOCAL
:
if
embedding_transfer_mode
==
EmbeddingTransferMode
.
LOCAL
:
self
.
embedding_sender
=
LocalEmbeddingSender
()
self
.
embedding_sender
=
LocalEmbeddingSender
()
...
@@ -93,7 +93,7 @@ class EncodeWorkerHandler:
...
@@ -93,7 +93,7 @@ class EncodeWorkerHandler:
f
"Invalid embedding transfer mode:
{
embedding_transfer_mode
}
"
f
"Invalid embedding transfer mode:
{
embedding_transfer_mode
}
"
)
)
self
.
send_complete_queue
=
asyncio
.
Queue
()
self
.
send_complete_queue
:
asyncio
.
Queue
[
tuple
[
Any
,
Any
]]
=
asyncio
.
Queue
()
self
.
send_complete_checker_task
=
asyncio
.
create_task
(
self
.
send_complete_checker_task
=
asyncio
.
create_task
(
self
.
check_complete
(
self
.
send_complete_queue
)
self
.
check_complete
(
self
.
send_complete_queue
)
)
)
...
@@ -150,7 +150,9 @@ class EncodeWorkerHandler:
...
@@ -150,7 +150,9 @@ class EncodeWorkerHandler:
with
_nvtx
.
annotate
(
"mm:enc:cache_check"
,
color
=
"cyan"
):
with
_nvtx
.
annotate
(
"mm:enc:cache_check"
,
color
=
"cyan"
):
# Before batch process images, check cache first
# Before batch process images, check cache first
need_encode_indexes
=
[]
need_encode_indexes
=
[]
embedding_lists
=
[
None
]
*
len
(
request
.
multimodal_inputs
)
embedding_lists
:
list
[
EmbeddingItem
|
None
]
=
[
None
]
*
len
(
request
.
multimodal_inputs
)
for
idx
in
range
(
len
(
request
.
multimodal_inputs
)):
for
idx
in
range
(
len
(
request
.
multimodal_inputs
)):
if
not
request
.
multimodal_inputs
[
idx
].
multimodal_input
.
image_url
:
if
not
request
.
multimodal_inputs
[
idx
].
multimodal_input
.
image_url
:
raise
ValueError
(
"image_url is required for the encode worker."
)
raise
ValueError
(
"image_url is required for the encode worker."
)
...
@@ -251,16 +253,16 @@ class EncodeWorkerHandler:
...
@@ -251,16 +253,16 @@ class EncodeWorkerHandler:
for
split_idx
,
(
list_idx
,
key
)
in
enumerate
(
need_encode_indexes
):
for
split_idx
,
(
list_idx
,
key
)
in
enumerate
(
need_encode_indexes
):
embedding_lists
[
list_idx
]
=
EmbeddingItem
(
embedding_lists
[
list_idx
]
=
EmbeddingItem
(
key
,
key
,
[
image_grid_thw
[
split_idx
]]
if
image_grid_thw
else
None
,
[
image_grid_thw
[
split_idx
]]
if
image_grid_thw
else
[]
,
splitted_embeddings
[
split_idx
].
unsqueeze
(
0
),
splitted_embeddings
[
split_idx
].
unsqueeze
(
0
),
)
)
# Cache the computed value for future use
# Cache the computed value for future use
if
self
.
embedding_cache
is
not
None
:
if
self
.
embedding_cache
is
not
None
:
self
.
embedding_cache
.
set
(
self
.
embedding_cache
.
set
(
embedding_lists
[
list_idx
].
key
,
embedding_lists
[
list_idx
].
key
,
# type: ignore
(
(
embedding_lists
[
list_idx
].
image_grid_thw
,
embedding_lists
[
list_idx
].
image_grid_thw
,
# type: ignore
embedding_lists
[
list_idx
].
embeddings
,
embedding_lists
[
list_idx
].
embeddings
,
# type: ignore
),
),
)
)
...
@@ -275,6 +277,7 @@ class EncodeWorkerHandler:
...
@@ -275,6 +277,7 @@ class EncodeWorkerHandler:
)
)
)
)
for
embedding_item
in
embedding_lists
for
embedding_item
in
embedding_lists
if
embedding_item
is
not
None
]
]
transfer_requests
=
await
asyncio
.
gather
(
*
send_tasks
)
transfer_requests
=
await
asyncio
.
gather
(
*
send_tasks
)
...
@@ -282,6 +285,7 @@ class EncodeWorkerHandler:
...
@@ -282,6 +285,7 @@ class EncodeWorkerHandler:
for
idx
,
item
in
enumerate
(
zip
(
embedding_lists
,
transfer_requests
)):
for
idx
,
item
in
enumerate
(
zip
(
embedding_lists
,
transfer_requests
)):
embedding_item
,
transfer_request
=
item
embedding_item
,
transfer_request
=
item
assert
embedding_item
is
not
None
logger
.
debug
(
logger
.
debug
(
f
"
{
embedding_item
.
embeddings
.
shape
}
prepared for transfer."
f
"
{
embedding_item
.
embeddings
.
shape
}
prepared for transfer."
)
)
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
View file @
2831bfec
...
@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
num_output_tokens_so_far
=
0
num_output_tokens_so_far
=
0
async
for
(
async
for
(
decode_response
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore
[union-attr]
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore
request
.
model_dump_json
()
request
.
model_dump_json
()
):
):
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore
[attr-defined]
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore
yield
self
.
_format_engine_output
(
output
,
num_output_tokens_so_far
)
yield
self
.
_format_engine_output
(
output
,
num_output_tokens_so_far
)
if
output
.
outputs
:
if
output
.
outputs
:
num_output_tokens_so_far
=
len
(
output
.
outputs
[
0
].
token_ids
)
num_output_tokens_so_far
=
len
(
output
.
outputs
[
0
].
token_ids
)
...
...
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
View file @
2831bfec
...
@@ -64,8 +64,8 @@ def get_qwen_image_features(
...
@@ -64,8 +64,8 @@ def get_qwen_image_features(
if
grid_thw
is
None
:
if
grid_thw
is
None
:
raise
ValueError
(
"grid_thw is not provided"
)
raise
ValueError
(
"grid_thw is not provided"
)
grid_thw
=
grid_thw
.
tolist
()
grid_thw
=
grid_thw
.
tolist
()
image_
embed
s
=
vision_encoder
(
pixel_values
,
grid_thw
=
grid_thw
)
image_
feature
s
=
vision_encoder
(
pixel_values
,
grid_thw
=
grid_thw
)
return
image_
embed
s
return
image_
feature
s
pixel_values
=
image_embeds
[
"pixel_values"
].
to
(
vision_encoder
.
device
)
pixel_values
=
image_embeds
[
"pixel_values"
].
to
(
vision_encoder
.
device
)
...
...
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
View file @
2831bfec
...
@@ -257,8 +257,10 @@ async def _fetch_embeddings(
...
@@ -257,8 +257,10 @@ async def _fetch_embeddings(
)
)
# ── 3. Update cache (no-op when cache is None) ──────────────
# ── 3. Update cache (no-op when cache is None) ──────────────
for
(
idx
,
_url
,
key
),
group
in
zip
(
to_fetch
,
groups
,
strict
=
True
):
for
(
idx
,
_url
,
key
),
group
in
zip
(
to_fetch
,
groups
,
strict
=
True
):
if
cache
is
not
None
and
key
is
not
None
:
if
cache
is
not
None
and
key
is
not
None
:
assert
group
.
loaded_embedding
is
not
None
cache
.
set
(
cache
.
set
(
key
,
key
,
CachedEmbedding
(
CachedEmbedding
(
...
@@ -301,6 +303,7 @@ async def load_multimodal_embeddings(
...
@@ -301,6 +303,7 @@ async def load_multimodal_embeddings(
multi_modal_data
:
Dict
[
str
,
Any
]
=
defaultdict
(
list
)
multi_modal_data
:
Dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
group
in
groups
:
for
group
in
groups
:
assert
group
.
loaded_embedding
is
not
None
_accumulate_embeddings
(
_accumulate_embeddings
(
multi_modal_data
,
multi_modal_data
,
model
,
model
,
...
...
components/src/dynamo/vllm/omni/base_handler.py
View file @
2831bfec
...
@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler):
...
@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id
=
context
.
id
()
request_id
=
context
.
id
()
logger
.
debug
(
f
"Omni Request ID:
{
request_id
}
"
)
logger
.
debug
(
f
"Omni Request ID:
{
request_id
}
"
)
async
for
chunk
in
self
.
_generate_openai_mode
(
request
,
context
,
request_id
):
async
for
chunk
in
self
.
_generate_openai_mode
(
request
,
context
,
request_id
):
# type: ignore
yield
chunk
yield
chunk
async
def
_generate_openai_mode
(
async
def
_generate_openai_mode
(
...
...
components/src/dynamo/vllm/omni/omni_handler.py
View file @
2831bfec
...
@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler):
...
@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler):
output
=
NvImagesResponse
(
created
=
int
(
time
.
time
()),
data
=
image_data_list
)
output
=
NvImagesResponse
(
created
=
int
(
time
.
time
()),
data
=
image_data_list
)
return
output
.
model_dump
()
return
output
.
model_dump
()
else
:
return
None
async
def
_format_video_chunk
(
async
def
_format_video_chunk
(
self
,
self
,
...
...
components/src/dynamo/vllm/publisher.py
View file @
2831bfec
...
@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
...
@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
raise
raise
# TODO: Remove this and pass as metadata through shared storage
# TODO: Remove this and pass as metadata through shared storage
def
set_num_gpu_block
(
self
,
num_blocks
)
:
def
set_num_gpu_block
(
self
,
num_blocks
:
int
)
->
None
:
self
.
num_gpu_block
=
num_blocks
self
.
num_gpu_block
=
num_blocks
def
record
(
def
record
(
...
@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
...
@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
scheduler_stats
:
SchedulerStats
,
scheduler_stats
:
SchedulerStats
,
iteration_stats
:
Optional
[
IterationStats
],
iteration_stats
:
Optional
[
IterationStats
],
engine_idx
:
int
=
0
,
engine_idx
:
int
=
0
,
*
args
,
*
args
:
object
,
**
kwargs
,
**
kwargs
:
object
,
):
)
->
None
:
active_decode_blocks
=
int
(
self
.
num_gpu_block
*
scheduler_stats
.
kv_cache_usage
)
active_decode_blocks
=
int
(
self
.
num_gpu_block
*
scheduler_stats
.
kv_cache_usage
)
self
.
inner
.
publish
(
self
.
dp_rank
,
active_decode_blocks
)
self
.
inner
.
publish
(
self
.
dp_rank
,
active_decode_blocks
)
...
@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
...
@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
dp_rank_str
,
scheduler_stats
.
kv_cache_usage
dp_rank_str
,
scheduler_stats
.
kv_cache_usage
)
)
def
init_publish
(
self
):
def
init_publish
(
self
)
->
None
:
self
.
inner
.
publish
(
self
.
dp_rank
,
0
)
self
.
inner
.
publish
(
self
.
dp_rank
,
0
)
dp_rank_str
=
str
(
self
.
dp_rank
)
dp_rank_str
=
str
(
self
.
dp_rank
)
self
.
component_gauges
.
set_total_blocks
(
dp_rank_str
,
0
)
self
.
component_gauges
.
set_total_blocks
(
dp_rank_str
,
0
)
...
@@ -112,10 +112,10 @@ class StatLoggerFactory:
...
@@ -112,10 +112,10 @@ class StatLoggerFactory:
return
self
.
create_stat_logger
(
dp_rank
=
dp_rank
)
return
self
.
create_stat_logger
(
dp_rank
=
dp_rank
)
# TODO Remove once we publish metadata to shared storage
# TODO Remove once we publish metadata to shared storage
def
set_num_gpu_blocks_all
(
self
,
num_blocks
)
:
def
set_num_gpu_blocks_all
(
self
,
num_blocks
:
int
)
->
None
:
if
self
.
created_logger
:
if
self
.
created_logger
:
self
.
created_logger
.
set_num_gpu_block
(
num_blocks
)
self
.
created_logger
.
set_num_gpu_block
(
num_blocks
)
def
init_publish
(
self
):
def
init_publish
(
self
)
->
None
:
if
self
.
created_logger
:
if
self
.
created_logger
:
self
.
created_logger
.
init_publish
()
self
.
created_logger
.
init_publish
()
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
View file @
2831bfec
...
@@ -180,7 +180,7 @@ class TestLoadMultimodalData:
...
@@ -180,7 +180,7 @@ class TestLoadMultimodalData:
mock_client
=
MagicMock
()
mock_client
=
MagicMock
()
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
fake_mm_data
=
defaultdict
(
list
,
{
"image"
:
torch
.
randn
(
1
,
10
)})
fake_mm_data
=
defaultdict
(
list
,
{
"image"
:
torch
.
randn
(
1
,
10
)})
# type: ignore
with
patch
.
object
(
with
patch
.
object
(
mod
,
mod
,
"load_multimodal_embeddings"
,
"load_multimodal_embeddings"
,
...
...
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
View file @
2831bfec
...
@@ -29,7 +29,7 @@ def mock_handler():
...
@@ -29,7 +29,7 @@ def mock_handler():
pass
pass
handler
=
MockHandler
()
handler
=
MockHandler
()
handler
.
_decode_prompt_embeds
=
BaseWorkerHandler
.
_decode_prompt_embeds
.
__get__
(
handler
.
_decode_prompt_embeds
=
BaseWorkerHandler
.
_decode_prompt_embeds
.
__get__
(
# type: ignore
handler
handler
)
)
return
handler
return
handler
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
2831bfec
...
@@ -981,7 +981,7 @@ class ModelType:
...
@@ -981,7 +981,7 @@ class ModelType:
Audios: ModelType
Audios: ModelType
Videos: ModelType
Videos: ModelType
def __or__(self, other:
"
ModelType
"
) ->
"
ModelType
"
:
def __or__(self, other: ModelType) -> ModelType:
...
...
def supports_chat(self) -> bool:
def supports_chat(self) -> bool:
...
@@ -1091,6 +1091,8 @@ async def register_model(
...
@@ -1091,6 +1091,8 @@ async def register_model(
runtime_config: Optional[ModelRuntimeConfig] = None,
runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None,
user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None,
custom_template_path: Optional[str] = None,
media_decoder: Optional[MediaDecoder] = None,
media_fetcher: Optional[MediaFetcher] = None,
lora_name: Optional[str] = None,
lora_name: Optional[str] = None,
base_model_path: Optional[str] = None,
base_model_path: Optional[str] = None,
) -> None:
) -> None:
...
@@ -1649,6 +1651,8 @@ class PlannerDecision:
...
@@ -1649,6 +1651,8 @@ class PlannerDecision:
-1 in any of those fields mean not set, usually because planner hasn't decided anything yet.
-1 in any of those fields mean not set, usually because planner hasn't decided anything yet.
Call VirtualConnectorClient.complete(event) when action is completed.
Call VirtualConnectorClient.complete(event) when action is completed.
"""
"""
num_prefill_workers: int
num_decode_workers: int
...
...
class VirtualConnectorCoordinator:
class VirtualConnectorCoordinator:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment