Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f99b78f0
Unverified
Commit
f99b78f0
authored
Mar 05, 2026
by
jh-nv
Committed by
GitHub
Mar 06, 2026
Browse files
chore: add mypy typing for trtllm (#6860)
parent
8cef50c6
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
87 additions
and
61 deletions
+87
-61
components/src/dynamo/trtllm/args.py
components/src/dynamo/trtllm/args.py
+1
-0
components/src/dynamo/trtllm/backend_args.py
components/src/dynamo/trtllm/backend_args.py
+2
-1
components/src/dynamo/trtllm/encode_helper.py
components/src/dynamo/trtllm/encode_helper.py
+8
-7
components/src/dynamo/trtllm/engine.py
components/src/dynamo/trtllm/engine.py
+9
-8
components/src/dynamo/trtllm/health_check.py
components/src/dynamo/trtllm/health_check.py
+2
-1
components/src/dynamo/trtllm/logits_processing/adapter.py
components/src/dynamo/trtllm/logits_processing/adapter.py
+2
-2
components/src/dynamo/trtllm/multimodal_processor.py
components/src/dynamo/trtllm/multimodal_processor.py
+6
-1
components/src/dynamo/trtllm/publisher.py
components/src/dynamo/trtllm/publisher.py
+35
-31
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
.../src/dynamo/trtllm/request_handlers/aggregated_handler.py
+4
-1
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+6
-5
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+10
-3
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
...nents/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
+2
-1
No files found.
components/src/dynamo/trtllm/args.py
View file @
f99b78f0
...
...
@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class
Config
(
DynamoRuntimeConfig
,
DynamoTrtllmConfig
):
component
:
str
use_kv_events
:
bool
connector
:
list
[
str
]
# Redeclare for mypy (inherited from DynamoRuntimeConfig)
def
validate
(
self
)
->
None
:
DynamoRuntimeConfig
.
validate
(
self
)
...
...
components/src/dynamo/trtllm/backend_args.py
View file @
f99b78f0
...
...
@@ -3,6 +3,7 @@
"""Dynamo TRT-LLM backend configuration ArgGroup."""
import
argparse
from
typing
import
Optional
from
tensorrt_llm.llmapi
import
BuildConfig
...
...
@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class
DynamoTrtllmArgGroup
(
ArgGroup
):
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
def
add_arguments
(
self
,
parser
)
->
None
:
def
add_arguments
(
self
,
parser
:
argparse
.
ArgumentParser
)
->
None
:
parser
.
add_argument
(
"--version"
,
action
=
"version"
,
...
...
components/src/dynamo/trtllm/encode_helper.py
View file @
f99b78f0
...
...
@@ -4,6 +4,7 @@
import
asyncio
import
logging
import
threading
from
collections.abc
import
AsyncGenerator
from
dataclasses
import
asdict
from
typing
import
Any
,
Dict
,
Optional
,
Union
...
...
@@ -377,13 +378,13 @@ class EncodeHelper:
@
staticmethod
async
def
process_encode_request
(
request
:
Dict
[
str
,
Any
],
multimodal_processor
,
multimodal_processor
:
Any
,
connector
:
Optional
[
nixl_connect
.
Connector
],
tokenizer
=
None
,
model_dir
=
None
,
model_type
=
None
,
engine
=
None
,
):
tokenizer
:
Any
=
None
,
model_dir
:
Optional
[
str
]
=
None
,
model_type
:
Optional
[
str
]
=
None
,
engine
:
Any
=
None
,
)
->
AsyncGenerator
[
dict
,
None
]
:
"""
Process an ENCODE-mode request. Dispatches to the appropriate flow.
...
...
@@ -447,7 +448,7 @@ class EncodeHelper:
# if the model's tokenizer_config chat template emits them).
token_ids
=
request
.
get
(
"token_ids"
)
async
for
response
in
EncodeHelper
.
_process_full_epd_flow
(
token_ids
,
token_ids
,
# type: ignore
image_urls
,
tokenizer
,
model_dir
,
...
...
components/src/dynamo/trtllm/engine.py
View file @
f99b78f0
...
...
@@ -4,8 +4,9 @@
import
enum
import
logging
import
time
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
typing
import
A
syncGenerator
,
Optional
from
typing
import
A
ny
,
Optional
from
tensorrt_llm
import
LLM
,
MultimodalEncoder
from
tensorrt_llm.llmapi.llm
import
BaseLLM
...
...
@@ -31,9 +32,9 @@ class Backend(str, enum.Enum):
class
TensorRTLLMEngine
:
def
__init__
(
self
,
engine_args
,
engine_args
:
dict
[
str
,
Any
]
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
):
)
->
None
:
self
.
_llm
:
Optional
[
LLM
]
=
None
self
.
disaggregation_mode
=
(
disaggregation_mode
...
...
@@ -63,7 +64,7 @@ class TensorRTLLMEngine:
"""Whether the multimodal encoder LLM is initialized."""
return
self
.
_llm
is
not
None
async
def
initialize
(
self
):
async
def
initialize
(
self
)
->
None
:
if
not
self
.
_llm
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
ENCODE
:
# Initialize the multimodal encoder for full EPD
...
...
@@ -75,7 +76,7 @@ class TensorRTLLMEngine:
# Skip MultimodalEncoder for architectures that handle vision
# encoding inside the main model (e.g. Llama4).
if
self
.
_is_unsupported_encoder_arch
(
model
):
if
self
.
_is_unsupported_encoder_arch
(
model
):
# type: ignore
return
max_batch_size
=
self
.
engine_args
.
get
(
"max_batch_size"
,
1
)
...
...
@@ -93,7 +94,7 @@ class TensorRTLLMEngine:
# (model path, backend settings, KV cache config, disaggregation settings, etc.)
self
.
_llm
=
self
.
_llm_cls
(
**
self
.
engine_args
)
async
def
cleanup
(
self
):
async
def
cleanup
(
self
)
->
None
:
if
self
.
_llm
:
try
:
self
.
_llm
.
shutdown
()
...
...
@@ -166,9 +167,9 @@ class TensorRTLLMEngine:
@
asynccontextmanager
async
def
get_llm_engine
(
engine_args
,
engine_args
:
dict
[
str
,
Any
]
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
component_gauges
=
None
,
component_gauges
:
Any
=
None
,
)
->
AsyncGenerator
[
TensorRTLLMEngine
,
None
]:
"""Get TensorRT-LLM engine instance with load time tracking.
...
...
components/src/dynamo/trtllm/health_check.py
View file @
f99b78f0
...
...
@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends.
"""
import
logging
from
typing
import
Any
from
dynamo.health_check
import
HealthCheckPayload
...
...
@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class.
"""
def
__init__
(
self
,
tokenizer
=
None
)
:
def
__init__
(
self
,
tokenizer
:
Any
=
None
)
->
None
:
"""
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
...
...
components/src/dynamo/trtllm/logits_processing/adapter.py
View file @
f99b78f0
...
...
@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
req_ids
:
int
,
logits
:
torch
.
Tensor
,
ids
:
List
[
List
[
int
]],
stream_ptr
,
stream_ptr
:
Optional
[
int
]
,
client_id
:
Optional
[
int
]
=
None
,
):
)
->
None
:
"""
TensorRT-LLM logits processor interface.
...
...
components/src/dynamo/trtllm/multimodal_processor.py
View file @
f99b78f0
...
...
@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol):
the tokenizer's decode method not being found on a generic 'object' type.
"""
def
decode
(
self
,
token_ids
:
List
[
int
])
->
str
:
def
decode
(
self
,
token_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
True
,
clean_up_tokenization_spaces
:
bool
=
True
,
)
->
str
:
...
...
...
components/src/dynamo/trtllm/publisher.py
View file @
f99b78f0
...
...
@@ -26,9 +26,10 @@ import threading
import
time
import
traceback
import
weakref
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
queue
import
Queue
from
typing
import
Awaitable
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Awaitable
,
Callable
,
Dict
,
Optional
,
Union
import
msgpack
import
zmq
...
...
@@ -87,7 +88,7 @@ class ZmqKvEventPublisher:
Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
"""
def
__init__
(
self
,
zmq_endpoint
:
str
,
kv_block_size
:
int
,
topic
:
str
=
""
):
def
__init__
(
self
,
zmq_endpoint
:
str
,
kv_block_size
:
int
,
topic
:
str
=
""
)
->
None
:
"""
Initialize ZMQ publisher.
...
...
@@ -120,7 +121,7 @@ class ZmqKvEventPublisher:
block_mm_infos
:
Optional
[
list
[
dict
|
None
]]
=
None
,
attention_dp_rank
:
int
=
0
,
lora_name
:
Optional
[
str
]
=
None
,
):
)
->
None
:
"""Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter.
...
...
@@ -133,7 +134,7 @@ class ZmqKvEventPublisher:
# Create event in the same format as vLLM's ZmqEventPublisher:
# All blocks should have the same size (kv_block_size)
event
=
{
event
:
dict
[
str
,
Any
]
=
{
"type"
:
"BlockStored"
,
"block_hashes"
:
block_hashes_signed
,
"parent_block_hash"
:
parent_hash_signed
,
...
...
@@ -149,7 +150,9 @@ class ZmqKvEventPublisher:
self
.
_publish_event
(
event
,
attention_dp_rank
)
def
publish_removed
(
self
,
block_hashes
:
list
[
int
],
attention_dp_rank
:
int
=
0
):
def
publish_removed
(
self
,
block_hashes
:
list
[
int
],
attention_dp_rank
:
int
=
0
)
->
None
:
"""Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
...
...
@@ -164,7 +167,7 @@ class ZmqKvEventPublisher:
self
.
_publish_event
(
event
,
attention_dp_rank
)
def
publish_all_cleared
(
self
):
def
publish_all_cleared
(
self
)
->
None
:
"""Publish an AllBlocksCleared event."""
event
=
{
"type"
:
"AllBlocksCleared"
}
self
.
_publish_event
(
event
)
...
...
@@ -197,7 +200,7 @@ class ZmqKvEventPublisher:
except
Exception
as
e
:
logging
.
error
(
f
"Failed to publish ZMQ event:
{
e
}
"
,
exc_info
=
True
)
def
shutdown
(
self
):
def
shutdown
(
self
)
->
None
:
"""Shutdown the ZMQ publisher."""
if
self
.
socket
:
self
.
socket
.
close
()
...
...
@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread):
self
.
_stop_event
=
threading
.
Event
()
def
set_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
):
def
set_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
)
->
None
:
self
.
loop
=
loop
def
run
(
self
):
def
run
(
self
)
->
None
:
while
not
self
.
_stop_event
.
is_set
():
task
:
Optional
[
Union
[
Callable
[...,
Awaitable
[
bool
]],
weakref
.
WeakMethod
]
...
...
@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread):
logging
.
info
(
f
"Thread
{
self
.
name
}
stopped."
)
def
stop
(
self
):
def
stop
(
self
)
->
None
:
self
.
_stop_event
.
set
()
if
self
.
_current_future
and
not
self
.
_current_future
.
done
():
self
.
_current_future
.
cancel
()
...
...
@@ -297,16 +300,16 @@ class Publisher:
def
__init__
(
self
,
endpoint
,
engine
,
worker_id
,
kv_block_size
,
metrics_labels
,
endpoint
:
Any
,
engine
:
Any
,
worker_id
:
Any
,
kv_block_size
:
int
,
metrics_labels
:
Any
,
component_gauges
:
LLMBackendMetrics
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
enable_local_indexer
:
bool
=
False
,
metrics_collector
=
None
,
):
metrics_collector
:
Any
=
None
,
)
->
None
:
self
.
endpoint
=
endpoint
self
.
engine
=
engine
self
.
worker_id
=
worker_id
...
...
@@ -324,7 +327,7 @@ class Publisher:
self
.
processing_initial_created_events
=
True
# Needed by the events and metrics publishers
self
.
metrics_publisher
=
None
self
.
metrics_publisher
:
Optional
[
WorkerMetricsPublisher
]
=
None
self
.
kv_event_publishers
:
Optional
[
Dict
[
int
,
KvEventPublisher
]
]
=
None
# One per attention_dp_rank
...
...
@@ -359,7 +362,7 @@ class Publisher:
return
await
self
.
metrics_publisher
.
create_endpoint
(
self
.
endpoint
)
def
initialize
(
self
):
def
initialize
(
self
)
->
None
:
# Setup the metrics publisher
self
.
metrics_publisher
=
WorkerMetricsPublisher
()
self
.
_init_publish_metrics_thread
()
...
...
@@ -474,6 +477,7 @@ class Publisher:
kv_total_blocks
=
stat
[
"kvCacheStats"
][
"maxNumBlocks"
]
logging
.
debug
(
f
"Publishing stats: kv_active_blocks:
{
kv_active_blocks
}
"
)
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert
self
.
metrics_publisher
is
not
None
self
.
metrics_publisher
.
publish
(
None
,
kv_active_blocks
)
# Publish Prometheus metrics
...
...
@@ -680,7 +684,7 @@ class Publisher:
elif
data
[
"type"
]
==
"created"
and
self
.
processing_initial_created_events
:
self
.
update_max_window_size
(
event
)
def
start
(
self
):
def
start
(
self
)
->
None
:
if
(
self
.
publish_kv_cache_events_thread
and
not
self
.
publish_kv_cache_events_thread
.
is_alive
()
...
...
@@ -698,13 +702,13 @@ class Publisher:
self
.
publish_stats_thread
.
start
()
logging
.
debug
(
"Started stats thread"
)
def
check_error_queue
(
self
):
def
check_error_queue
(
self
)
->
Optional
[
Exception
]
:
if
not
self
.
error_queue
.
empty
():
logging
.
error
(
"Error in publishers error queue"
)
return
self
.
error_queue
.
get
()
return
None
async
def
cleanup
(
self
):
async
def
cleanup
(
self
)
->
None
:
"""Cleanup threads and resources"""
self
.
_stop_event
.
set
()
# Add timeout to prevent hanging
...
...
@@ -729,7 +733,7 @@ class Publisher:
if
self
.
zmq_kv_event_publisher
:
self
.
zmq_kv_event_publisher
.
shutdown
()
def
update_max_window_size
(
self
,
event
)
:
def
update_max_window_size
(
self
,
event
:
dict
)
->
None
:
if
"window_size"
in
event
:
window_size
=
event
[
"window_size"
]
if
self
.
max_window_size
is
None
or
window_size
>
self
.
max_window_size
:
...
...
@@ -744,7 +748,7 @@ class Publisher:
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# so we can use the "created" event to identify the max_window_size of the global
# attention layer in the model engine.
def
should_drop_event
(
self
,
event
)
:
def
should_drop_event
(
self
,
event
:
dict
)
->
bool
:
# There are two cases for KV event filtering:
#
# 1. If "window_size" is NOT in the KV event:
...
...
@@ -768,16 +772,16 @@ class Publisher:
@
asynccontextmanager
async
def
get_publisher
(
endpoint
,
engine
,
worker_id
,
kv_block_size
,
metrics_labels
,
endpoint
:
Any
,
engine
:
Any
,
worker_id
:
Any
,
kv_block_size
:
int
,
metrics_labels
:
Any
,
component_gauges
:
LLMBackendMetrics
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
enable_local_indexer
:
bool
=
False
,
metrics_collector
=
None
,
):
metrics_collector
:
Any
=
None
,
)
->
AsyncGenerator
[
Publisher
,
None
]
:
publisher
=
Publisher
(
endpoint
,
engine
,
...
...
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
View file @
f99b78f0
...
...
@@ -4,6 +4,7 @@
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import
logging
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
dynamo._core
import
Context
...
...
@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase):
super
().
__init__
(
config
)
self
.
_encoder_cache
=
encoder_cache
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""Generate response, optionally using remote encoder for multimodal."""
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
f99b78f0
...
...
@@ -18,9 +18,10 @@ import dataclasses
import
logging
import
os
import
re
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Any
,
AsyncGenerator
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
torch
from
tensorrt_llm.executor.result
import
GenerationResult
...
...
@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler):
self
.
shutdown_event
=
config
.
shutdown_event
self
.
disable_request_abort
=
config
.
disable_request_abort
def
check_error
(
self
,
result
:
dict
):
def
check_error
(
self
,
result
:
dict
)
->
bool
:
"""
Check if there is an error in the result.
"""
...
...
@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler):
Raise GeneratorExit if shutdown event is triggered.
"""
try
:
cancellation_triggers
=
[
cancellation_triggers
:
list
[
asyncio
.
Future
[
Any
]]
=
[
context
.
async_killed_or_stopped
(),
# Request cancellation
]
# Shutdown cancellation
...
...
@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler):
Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
"""
disaggregated_params
=
None
epd_metadata
=
{}
epd_metadata
:
dict
[
str
,
Any
]
=
{}
# PREFILL mode: setup context_only params
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
...
...
@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler):
context
:
Context
,
embeddings
:
Optional
[
Union
[
torch
.
Tensor
,
dict
]]
=
None
,
ep_disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
,
):
)
->
AsyncGenerator
[
dict
,
None
]
:
"""
Generate responses based on the disaggregation mode in the request.
...
...
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
f99b78f0
...
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import
logging
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
dynamo._core
import
Context
...
...
@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase):
self
.
model_type
=
self
.
multimodal_processor
.
model_type
self
.
tokenizer
=
self
.
multimodal_processor
.
tokenizer
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
logging
.
debug
(
f
"New Request ID:
{
context
.
id
()
}
"
)
if
self
.
multimodal_processor
is
None
:
logging
.
error
(
"encode handler: no multimodal_processor configured"
)
...
...
@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase):
encode_response
,
self
.
connector
)
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""
Prefill worker: process prompt and return disaggregated_params.
Frontend routes to decode workers automatically.
...
...
@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase):
def
__init__
(
self
,
config
:
RequestHandlerConfig
):
super
().
__init__
(
config
)
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""
Decode worker: generate tokens using disaggregated_params from prefill.
If disaggregated_params is present, prefill was done. Otherwise generate normally.
...
...
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
View file @
f99b78f0
...
...
@@ -4,6 +4,7 @@
import
asyncio
import
re
as
re_mod
from
dataclasses
import
dataclass
from
typing
import
Any
from
unittest
import
mock
from
unittest.mock
import
MagicMock
...
...
@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice:
def
test_empty_choice_ignored
(
self
):
"""Empty choice list should not produce a regex."""
sampling_params
=
MockSamplingParams
()
request
=
{
request
:
dict
[
str
,
Any
]
=
{
"sampling_options"
:
{
"guided_decoding"
:
{
"choice"
:
[],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment