Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2831bfec
"examples/backends/trtllm/vscode:/vscode.git/clone" did not exist on "3c7ed61d05d153075b46d3a5e278c7d72c76f8e1"
Unverified
Commit
2831bfec
authored
Mar 06, 2026
by
jh-nv
Committed by
GitHub
Mar 06, 2026
Browse files
chore: add mypy typing to vllm (#6858)
parent
0f01e724
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
86 additions
and
56 deletions
+86
-56
components/src/dynamo/vllm/engine_monitor.py
components/src/dynamo/vllm/engine_monitor.py
+1
-1
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+2
-2
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+47
-30
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
.../dynamo/vllm/multimodal_handlers/encode_worker_handler.py
+12
-8
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+2
-2
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
+2
-2
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
.../src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
+3
-0
components/src/dynamo/vllm/omni/base_handler.py
components/src/dynamo/vllm/omni/base_handler.py
+1
-1
components/src/dynamo/vllm/omni/omni_handler.py
components/src/dynamo/vllm/omni/omni_handler.py
+2
-0
components/src/dynamo/vllm/publisher.py
components/src/dynamo/vllm/publisher.py
+7
-7
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
...imodal_handlers/test_vllm_multimodal_pd_worker_handler.py
+1
-1
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
+1
-1
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+5
-1
No files found.
components/src/dynamo/vllm/engine_monitor.py
View file @
2831bfec
...
...
@@ -29,7 +29,7 @@ class VllmEngineMonitor:
self
,
runtime
:
DistributedRuntime
,
engine_client
:
AsyncLLM
,
shutdown_event
:
asyncio
.
Event
=
None
,
shutdown_event
:
asyncio
.
Event
|
None
=
None
,
):
if
not
isinstance
(
runtime
,
DistributedRuntime
):
raise
ValueError
(
...
...
components/src/dynamo/vllm/handlers.py
View file @
2831bfec
...
...
@@ -261,7 +261,7 @@ def build_sampling_params_openai(
return
sampling_params
def
get_dp_range_for_worker
(
vllm_config
:
VllmConfig
)
->
range
:
def
get_dp_range_for_worker
(
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
]
:
"""
Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
...
...
@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC):
self
.
enable_multimodal
=
enable_multimodal
self
.
enable_frontend_decoding
=
enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
self
.
_nixl_connector
=
None
self
.
_nixl_connector
:
nixl_connect
.
Connector
|
None
=
None
self
.
_nixl_connector_lock
=
asyncio
.
Lock
()
# LoRA tracking: name -> LoRAInfo(id, path)
self
.
loaded_loras
:
dict
[
str
,
LoRAInfo
]
=
{}
...
...
components/src/dynamo/vllm/main.py
View file @
2831bfec
...
...
@@ -7,7 +7,7 @@ import logging
import
os
import
tempfile
import
time
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
uvloop
from
prometheus_client
import
REGISTRY
,
CollectorRegistry
,
multiprocess
...
...
@@ -37,17 +37,6 @@ from dynamo.llm import (
fetch_model
,
register_model
,
)
# Optional imports for frontend decoding support
try
:
from
dynamo.llm
import
MediaDecoder
,
MediaFetcher
MEDIA_DECODER_AVAILABLE
=
True
except
ImportError
:
MediaDecoder
=
None
MediaFetcher
=
None
MEDIA_DECODER_AVAILABLE
=
False
from
dynamo.runtime
import
DistributedRuntime
,
Endpoint
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.vllm.worker_factory
import
WorkerFactory
...
...
@@ -63,6 +52,18 @@ from .health_check import (
)
from
.publisher
import
DYNAMO_COMPONENT_REGISTRY
,
StatLoggerFactory
# Optional imports for frontend decoding support
MediaDecoder
:
type
|
None
=
None
MediaFetcher
:
type
|
None
=
None
try
:
from
dynamo.llm
import
MediaDecoder
,
MediaFetcher
MEDIA_DECODER_AVAILABLE
=
True
except
ImportError
:
MediaDecoder
=
None
MediaFetcher
=
None
MEDIA_DECODER_AVAILABLE
=
False
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
shutdown_endpoints
:
list
=
[]
...
...
@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None:
run_headless
(
args
)
async
def
worker
():
async
def
worker
()
->
None
:
config
=
parse_args
()
dump_config
(
config
.
dump_config_to
,
config
)
...
...
@@ -198,7 +199,9 @@ async def worker():
logger
.
debug
(
"Worker function completed, exiting..."
)
def
setup_metrics_collection
(
config
:
Config
,
generate_endpoint
,
logger
):
def
setup_metrics_collection
(
config
:
Config
,
generate_endpoint
:
Endpoint
,
logger
:
logging
.
Logger
)
->
None
:
"""Set up metrics collection for vLLM and LMCache metrics.
In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
...
...
@@ -294,8 +297,9 @@ def setup_kv_event_publisher(
vllm_config
:
VllmConfig
,
consolidator_enabled
:
bool
=
False
,
consolidator_port
:
Optional
[
int
]
=
5558
,
)
->
Optional
[
KvEventPublisher
]:
)
->
Optional
[
list
[
KvEventPublisher
]
]
:
"""
list[KvEventPublisher] | None
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
...
...
@@ -365,7 +369,9 @@ def setup_kv_event_publisher(
return
kv_publishers
if
kv_publishers
else
None
def
setup_vllm_engine
(
config
,
stat_logger
=
None
):
def
setup_vllm_engine
(
config
:
Config
,
stat_logger
:
Optional
[
StatLoggerFactory
]
=
None
)
->
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]:
# vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes a TemporaryDirectory object
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly.
...
...
@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None):
async
def
register_vllm_model
(
model_input
:
ModelInput
,
model_type
:
ModelType
,
generate_endpoint
,
generate_endpoint
:
Endpoint
,
config
:
Config
,
engine_client
:
AsyncLLM
,
vllm_config
:
VllmConfig
,
):
)
->
None
:
"""
Helper function to register a vLLM model with runtime configuration.
...
...
@@ -563,6 +569,7 @@ async def register_vllm_model(
"--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
)
assert
MediaDecoder
is
not
None
and
MediaFetcher
is
not
None
media_decoder
=
MediaDecoder
()
media_decoder
.
enable_image
({
"limits"
:
{
"max_alloc"
:
128
*
1024
*
1024
}})
# media_decoder.enable_video({})
...
...
@@ -590,8 +597,10 @@ async def init_prefill(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
checkpoint_restore_engine
=
None
,
):
checkpoint_restore_engine
:
Optional
[
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]
]
=
None
,
)
->
None
:
"""
Instantiate and serve
"""
...
...
@@ -690,7 +699,7 @@ async def init_prefill(
# (long-term reason): prefill engine should pull from a global queue so there are
# only a few in-flight requests that can be quickly finished
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
handler
.
generate
,
# type: ignore
graceful_shutdown
=
True
,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels
=
[
...
...
@@ -706,10 +715,16 @@ async def init_prefill(
health_check_payload
=
health_check_payload
,
),
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
handler
.
clear_kv_blocks
,
# type: ignore
metrics_labels
=
[
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
),
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
),
)
...
...
@@ -726,8 +741,10 @@ async def init(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
checkpoint_restore_engine
=
None
,
):
checkpoint_restore_engine
:
Optional
[
tuple
[
AsyncLLM
,
VllmConfig
,
Any
,
Any
,
LLMBackendMetrics
]
]
=
None
,
)
->
None
:
"""
Instantiate and serve
"""
...
...
@@ -886,7 +903,7 @@ async def init(
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long OSLs
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
handler
.
generate
,
# type: ignore
graceful_shutdown
=
True
,
metrics_labels
=
model_metrics_labels
,
health_check_payload
=
health_check_payload
,
...
...
@@ -926,7 +943,7 @@ async def init(
handler
.
cleanup
()
def
get_engine_cache_info
(
engine
:
AsyncLLM
):
def
get_engine_cache_info
(
engine
:
AsyncLLM
)
->
dict
[
str
,
Any
]
:
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
try
:
...
...
@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM):
async
def
init_omni
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
)
->
None
:
"""Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Supports text-to-text, text-to-image, and text-to-video generation
...
...
@@ -1034,7 +1051,7 @@ async def init_omni(
handler
.
cleanup
()
def
main
():
def
main
()
->
None
:
uvloop
.
run
(
worker
())
...
...
components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py
View file @
2831bfec
...
...
@@ -6,7 +6,7 @@ import logging
import
os
import
time
from
dataclasses
import
dataclass
from
typing
import
AsyncIterator
from
typing
import
Any
,
AsyncIterator
import
torch
from
transformers
import
AutoImageProcessor
...
...
@@ -80,7 +80,7 @@ class EncodeWorkerHandler:
self
.
_connector
:
connect
.
Connector
|
None
=
None
self
.
_accumulated_time
=
0.0
self
.
_processed_requests
=
0
self
.
readables
=
[]
self
.
readables
:
list
[
Any
]
=
[]
self
.
embedding_cache
=
EmbeddingCache
()
if
ENABLE_ENCODER_CACHE
else
None
if
embedding_transfer_mode
==
EmbeddingTransferMode
.
LOCAL
:
self
.
embedding_sender
=
LocalEmbeddingSender
()
...
...
@@ -93,7 +93,7 @@ class EncodeWorkerHandler:
f
"Invalid embedding transfer mode:
{
embedding_transfer_mode
}
"
)
self
.
send_complete_queue
=
asyncio
.
Queue
()
self
.
send_complete_queue
:
asyncio
.
Queue
[
tuple
[
Any
,
Any
]]
=
asyncio
.
Queue
()
self
.
send_complete_checker_task
=
asyncio
.
create_task
(
self
.
check_complete
(
self
.
send_complete_queue
)
)
...
...
@@ -150,7 +150,9 @@ class EncodeWorkerHandler:
with
_nvtx
.
annotate
(
"mm:enc:cache_check"
,
color
=
"cyan"
):
# Before batch-processing images, check the cache first
need_encode_indexes
=
[]
embedding_lists
=
[
None
]
*
len
(
request
.
multimodal_inputs
)
embedding_lists
:
list
[
EmbeddingItem
|
None
]
=
[
None
]
*
len
(
request
.
multimodal_inputs
)
for
idx
in
range
(
len
(
request
.
multimodal_inputs
)):
if
not
request
.
multimodal_inputs
[
idx
].
multimodal_input
.
image_url
:
raise
ValueError
(
"image_url is required for the encode worker."
)
...
...
@@ -251,16 +253,16 @@ class EncodeWorkerHandler:
for
split_idx
,
(
list_idx
,
key
)
in
enumerate
(
need_encode_indexes
):
embedding_lists
[
list_idx
]
=
EmbeddingItem
(
key
,
[
image_grid_thw
[
split_idx
]]
if
image_grid_thw
else
None
,
[
image_grid_thw
[
split_idx
]]
if
image_grid_thw
else
[]
,
splitted_embeddings
[
split_idx
].
unsqueeze
(
0
),
)
# Cache the computed value for future use
if
self
.
embedding_cache
is
not
None
:
self
.
embedding_cache
.
set
(
embedding_lists
[
list_idx
].
key
,
embedding_lists
[
list_idx
].
key
,
# type: ignore
(
embedding_lists
[
list_idx
].
image_grid_thw
,
embedding_lists
[
list_idx
].
embeddings
,
embedding_lists
[
list_idx
].
image_grid_thw
,
# type: ignore
embedding_lists
[
list_idx
].
embeddings
,
# type: ignore
),
)
...
...
@@ -275,6 +277,7 @@ class EncodeWorkerHandler:
)
)
for
embedding_item
in
embedding_lists
if
embedding_item
is
not
None
]
transfer_requests
=
await
asyncio
.
gather
(
*
send_tasks
)
...
...
@@ -282,6 +285,7 @@ class EncodeWorkerHandler:
for
idx
,
item
in
enumerate
(
zip
(
embedding_lists
,
transfer_requests
)):
embedding_item
,
transfer_request
=
item
assert
embedding_item
is
not
None
logger
.
debug
(
f
"
{
embedding_item
.
embeddings
.
shape
}
prepared for transfer."
)
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
View file @
2831bfec
...
...
@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
num_output_tokens_so_far
=
0
async
for
(
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore
[union-attr]
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore
request
.
model_dump_json
()
):
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore
[attr-defined]
output
=
MyRequestOutput
.
model_validate_json
(
decode_response
.
data
())
# type: ignore
yield
self
.
_format_engine_output
(
output
,
num_output_tokens_so_far
)
if
output
.
outputs
:
num_output_tokens_so_far
=
len
(
output
.
outputs
[
0
].
token_ids
)
...
...
components/src/dynamo/vllm/multimodal_utils/encode_utils.py
View file @
2831bfec
...
...
@@ -64,8 +64,8 @@ def get_qwen_image_features(
if
grid_thw
is
None
:
raise
ValueError
(
"grid_thw is not provided"
)
grid_thw
=
grid_thw
.
tolist
()
image_
embed
s
=
vision_encoder
(
pixel_values
,
grid_thw
=
grid_thw
)
return
image_
embed
s
image_
feature
s
=
vision_encoder
(
pixel_values
,
grid_thw
=
grid_thw
)
return
image_
feature
s
pixel_values
=
image_embeds
[
"pixel_values"
].
to
(
vision_encoder
.
device
)
...
...
components/src/dynamo/vllm/multimodal_utils/prefill_worker_utils.py
View file @
2831bfec
...
...
@@ -257,8 +257,10 @@ async def _fetch_embeddings(
)
# ── 3. Update cache (no-op when cache is None) ──────────────
for
(
idx
,
_url
,
key
),
group
in
zip
(
to_fetch
,
groups
,
strict
=
True
):
if
cache
is
not
None
and
key
is
not
None
:
assert
group
.
loaded_embedding
is
not
None
cache
.
set
(
key
,
CachedEmbedding
(
...
...
@@ -301,6 +303,7 @@ async def load_multimodal_embeddings(
multi_modal_data
:
Dict
[
str
,
Any
]
=
defaultdict
(
list
)
for
group
in
groups
:
assert
group
.
loaded_embedding
is
not
None
_accumulate_embeddings
(
multi_modal_data
,
model
,
...
...
components/src/dynamo/vllm/omni/base_handler.py
View file @
2831bfec
...
...
@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id
=
context
.
id
()
logger
.
debug
(
f
"Omni Request ID:
{
request_id
}
"
)
async
for
chunk
in
self
.
_generate_openai_mode
(
request
,
context
,
request_id
):
async
for
chunk
in
self
.
_generate_openai_mode
(
request
,
context
,
request_id
):
# type: ignore
yield
chunk
async
def
_generate_openai_mode
(
...
...
components/src/dynamo/vllm/omni/omni_handler.py
View file @
2831bfec
...
...
@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler):
output
=
NvImagesResponse
(
created
=
int
(
time
.
time
()),
data
=
image_data_list
)
return
output
.
model_dump
()
else
:
return
None
async
def
_format_video_chunk
(
self
,
...
...
components/src/dynamo/vllm/publisher.py
View file @
2831bfec
...
...
@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
raise
# TODO: Remove this and pass as metadata through shared storage
def
set_num_gpu_block
(
self
,
num_blocks
)
:
def
set_num_gpu_block
(
self
,
num_blocks
:
int
)
->
None
:
self
.
num_gpu_block
=
num_blocks
def
record
(
...
...
@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
scheduler_stats
:
SchedulerStats
,
iteration_stats
:
Optional
[
IterationStats
],
engine_idx
:
int
=
0
,
*
args
,
**
kwargs
,
):
*
args
:
object
,
**
kwargs
:
object
,
)
->
None
:
active_decode_blocks
=
int
(
self
.
num_gpu_block
*
scheduler_stats
.
kv_cache_usage
)
self
.
inner
.
publish
(
self
.
dp_rank
,
active_decode_blocks
)
...
...
@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
dp_rank_str
,
scheduler_stats
.
kv_cache_usage
)
def
init_publish
(
self
):
def
init_publish
(
self
)
->
None
:
self
.
inner
.
publish
(
self
.
dp_rank
,
0
)
dp_rank_str
=
str
(
self
.
dp_rank
)
self
.
component_gauges
.
set_total_blocks
(
dp_rank_str
,
0
)
...
...
@@ -112,10 +112,10 @@ class StatLoggerFactory:
return
self
.
create_stat_logger
(
dp_rank
=
dp_rank
)
# TODO Remove once we publish metadata to shared storage
def
set_num_gpu_blocks_all
(
self
,
num_blocks
)
:
def
set_num_gpu_blocks_all
(
self
,
num_blocks
:
int
)
->
None
:
if
self
.
created_logger
:
self
.
created_logger
.
set_num_gpu_block
(
num_blocks
)
def
init_publish
(
self
):
def
init_publish
(
self
)
->
None
:
if
self
.
created_logger
:
self
.
created_logger
.
init_publish
()
components/src/dynamo/vllm/tests/multimodal_handlers/test_vllm_multimodal_pd_worker_handler.py
View file @
2831bfec
...
...
@@ -180,7 +180,7 @@ class TestLoadMultimodalData:
mock_client
=
MagicMock
()
handler
=
_make_handler
(
encode_worker_client
=
mock_client
)
fake_mm_data
=
defaultdict
(
list
,
{
"image"
:
torch
.
randn
(
1
,
10
)})
fake_mm_data
=
defaultdict
(
list
,
{
"image"
:
torch
.
randn
(
1
,
10
)})
# type: ignore
with
patch
.
object
(
mod
,
"load_multimodal_embeddings"
,
...
...
components/src/dynamo/vllm/tests/test_vllm_prompt_embeds.py
View file @
2831bfec
...
...
@@ -29,7 +29,7 @@ def mock_handler():
pass
handler
=
MockHandler
()
handler
.
_decode_prompt_embeds
=
BaseWorkerHandler
.
_decode_prompt_embeds
.
__get__
(
handler
.
_decode_prompt_embeds
=
BaseWorkerHandler
.
_decode_prompt_embeds
.
__get__
(
# type: ignore
handler
)
return
handler
...
...
lib/bindings/python/src/dynamo/_core.pyi
View file @
2831bfec
...
...
@@ -981,7 +981,7 @@ class ModelType:
Audios: ModelType
Videos: ModelType
def __or__(self, other:
"
ModelType
"
) ->
"
ModelType
"
:
def __or__(self, other: ModelType) -> ModelType:
...
def supports_chat(self) -> bool:
...
...
@@ -1091,6 +1091,8 @@ async def register_model(
runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None,
media_decoder: Optional[MediaDecoder] = None,
media_fetcher: Optional[MediaFetcher] = None,
lora_name: Optional[str] = None,
base_model_path: Optional[str] = None,
) -> None:
...
...
@@ -1649,6 +1651,8 @@ class PlannerDecision:
-1 in any of those fields means not set, usually because the planner hasn't decided anything yet.
Call VirtualConnectorClient.complete(event) when action is completed.
"""
num_prefill_workers: int
num_decode_workers: int
...
class VirtualConnectorCoordinator:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment