Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
f99b78f0
"lib/bindings/vscode:/vscode.git/clone" did not exist on "177d662f86099602465dfeede907f5709608fdc6"
Unverified
Commit
f99b78f0
authored
Mar 05, 2026
by
jh-nv
Committed by
GitHub
Mar 06, 2026
Browse files
chore: add mypy typing for trtllm (#6860)
parent
8cef50c6
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
87 additions
and
61 deletions
+87
-61
components/src/dynamo/trtllm/args.py
components/src/dynamo/trtllm/args.py
+1
-0
components/src/dynamo/trtllm/backend_args.py
components/src/dynamo/trtllm/backend_args.py
+2
-1
components/src/dynamo/trtllm/encode_helper.py
components/src/dynamo/trtllm/encode_helper.py
+8
-7
components/src/dynamo/trtllm/engine.py
components/src/dynamo/trtllm/engine.py
+9
-8
components/src/dynamo/trtllm/health_check.py
components/src/dynamo/trtllm/health_check.py
+2
-1
components/src/dynamo/trtllm/logits_processing/adapter.py
components/src/dynamo/trtllm/logits_processing/adapter.py
+2
-2
components/src/dynamo/trtllm/multimodal_processor.py
components/src/dynamo/trtllm/multimodal_processor.py
+6
-1
components/src/dynamo/trtllm/publisher.py
components/src/dynamo/trtllm/publisher.py
+35
-31
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
.../src/dynamo/trtllm/request_handlers/aggregated_handler.py
+4
-1
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+6
-5
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+10
-3
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
...nents/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
+2
-1
No files found.
components/src/dynamo/trtllm/args.py
View file @
f99b78f0
...
@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
...
@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class
Config
(
DynamoRuntimeConfig
,
DynamoTrtllmConfig
):
class
Config
(
DynamoRuntimeConfig
,
DynamoTrtllmConfig
):
component
:
str
component
:
str
use_kv_events
:
bool
use_kv_events
:
bool
connector
:
list
[
str
]
# Redeclare for mypy (inherited from DynamoRuntimeConfig)
def
validate
(
self
)
->
None
:
def
validate
(
self
)
->
None
:
DynamoRuntimeConfig
.
validate
(
self
)
DynamoRuntimeConfig
.
validate
(
self
)
...
...
components/src/dynamo/trtllm/backend_args.py
View file @
f99b78f0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
"""Dynamo TRT-LLM backend configuration ArgGroup."""
"""Dynamo TRT-LLM backend configuration ArgGroup."""
import
argparse
from
typing
import
Optional
from
typing
import
Optional
from
tensorrt_llm.llmapi
import
BuildConfig
from
tensorrt_llm.llmapi
import
BuildConfig
...
@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
...
@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class
DynamoTrtllmArgGroup
(
ArgGroup
):
class
DynamoTrtllmArgGroup
(
ArgGroup
):
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
def
add_arguments
(
self
,
parser
)
->
None
:
def
add_arguments
(
self
,
parser
:
argparse
.
ArgumentParser
)
->
None
:
parser
.
add_argument
(
parser
.
add_argument
(
"--version"
,
"--version"
,
action
=
"version"
,
action
=
"version"
,
...
...
components/src/dynamo/trtllm/encode_helper.py
View file @
f99b78f0
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
asyncio
import
asyncio
import
logging
import
logging
import
threading
import
threading
from
collections.abc
import
AsyncGenerator
from
dataclasses
import
asdict
from
dataclasses
import
asdict
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Union
...
@@ -377,13 +378,13 @@ class EncodeHelper:
...
@@ -377,13 +378,13 @@ class EncodeHelper:
@
staticmethod
@
staticmethod
async
def
process_encode_request
(
async
def
process_encode_request
(
request
:
Dict
[
str
,
Any
],
request
:
Dict
[
str
,
Any
],
multimodal_processor
,
multimodal_processor
:
Any
,
connector
:
Optional
[
nixl_connect
.
Connector
],
connector
:
Optional
[
nixl_connect
.
Connector
],
tokenizer
=
None
,
tokenizer
:
Any
=
None
,
model_dir
=
None
,
model_dir
:
Optional
[
str
]
=
None
,
model_type
=
None
,
model_type
:
Optional
[
str
]
=
None
,
engine
=
None
,
engine
:
Any
=
None
,
):
)
->
AsyncGenerator
[
dict
,
None
]
:
"""
"""
Process an ENCODE-mode request. Dispatches to the appropriate flow.
Process an ENCODE-mode request. Dispatches to the appropriate flow.
...
@@ -447,7 +448,7 @@ class EncodeHelper:
...
@@ -447,7 +448,7 @@ class EncodeHelper:
# if the model's tokenizer_config chat template emits them).
# if the model's tokenizer_config chat template emits them).
token_ids
=
request
.
get
(
"token_ids"
)
token_ids
=
request
.
get
(
"token_ids"
)
async
for
response
in
EncodeHelper
.
_process_full_epd_flow
(
async
for
response
in
EncodeHelper
.
_process_full_epd_flow
(
token_ids
,
token_ids
,
# type: ignore
image_urls
,
image_urls
,
tokenizer
,
tokenizer
,
model_dir
,
model_dir
,
...
...
components/src/dynamo/trtllm/engine.py
View file @
f99b78f0
...
@@ -4,8 +4,9 @@
...
@@ -4,8 +4,9 @@
import
enum
import
enum
import
logging
import
logging
import
time
import
time
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
typing
import
AsyncGenerator
,
Optional
from
typing
import
Any
,
Optional
from
tensorrt_llm
import
LLM
,
MultimodalEncoder
from
tensorrt_llm
import
LLM
,
MultimodalEncoder
from
tensorrt_llm.llmapi.llm
import
BaseLLM
from
tensorrt_llm.llmapi.llm
import
BaseLLM
...
@@ -31,9 +32,9 @@ class Backend(str, enum.Enum):
...
@@ -31,9 +32,9 @@ class Backend(str, enum.Enum):
class
TensorRTLLMEngine
:
class
TensorRTLLMEngine
:
def
__init__
(
def
__init__
(
self
,
self
,
engine_args
,
engine_args
:
dict
[
str
,
Any
]
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
):
)
->
None
:
self
.
_llm
:
Optional
[
LLM
]
=
None
self
.
_llm
:
Optional
[
LLM
]
=
None
self
.
disaggregation_mode
=
(
self
.
disaggregation_mode
=
(
disaggregation_mode
disaggregation_mode
...
@@ -63,7 +64,7 @@ class TensorRTLLMEngine:
...
@@ -63,7 +64,7 @@ class TensorRTLLMEngine:
"""Whether the multimodal encoder LLM is initialized."""
"""Whether the multimodal encoder LLM is initialized."""
return
self
.
_llm
is
not
None
return
self
.
_llm
is
not
None
async
def
initialize
(
self
):
async
def
initialize
(
self
)
->
None
:
if
not
self
.
_llm
:
if
not
self
.
_llm
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
ENCODE
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
ENCODE
:
# Initialize the multimodal encoder for full EPD
# Initialize the multimodal encoder for full EPD
...
@@ -75,7 +76,7 @@ class TensorRTLLMEngine:
...
@@ -75,7 +76,7 @@ class TensorRTLLMEngine:
# Skip MultimodalEncoder for architectures that handle vision
# Skip MultimodalEncoder for architectures that handle vision
# encoding inside the main model (e.g. Llama4).
# encoding inside the main model (e.g. Llama4).
if
self
.
_is_unsupported_encoder_arch
(
model
):
if
self
.
_is_unsupported_encoder_arch
(
model
):
# type: ignore
return
return
max_batch_size
=
self
.
engine_args
.
get
(
"max_batch_size"
,
1
)
max_batch_size
=
self
.
engine_args
.
get
(
"max_batch_size"
,
1
)
...
@@ -93,7 +94,7 @@ class TensorRTLLMEngine:
...
@@ -93,7 +94,7 @@ class TensorRTLLMEngine:
# (model path, backend settings, KV cache config, disaggregation settings, etc.)
# (model path, backend settings, KV cache config, disaggregation settings, etc.)
self
.
_llm
=
self
.
_llm_cls
(
**
self
.
engine_args
)
self
.
_llm
=
self
.
_llm_cls
(
**
self
.
engine_args
)
async
def
cleanup
(
self
):
async
def
cleanup
(
self
)
->
None
:
if
self
.
_llm
:
if
self
.
_llm
:
try
:
try
:
self
.
_llm
.
shutdown
()
self
.
_llm
.
shutdown
()
...
@@ -166,9 +167,9 @@ class TensorRTLLMEngine:
...
@@ -166,9 +167,9 @@ class TensorRTLLMEngine:
@
asynccontextmanager
@
asynccontextmanager
async
def
get_llm_engine
(
async
def
get_llm_engine
(
engine_args
,
engine_args
:
dict
[
str
,
Any
]
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
disaggregation_mode
:
Optional
[
DisaggregationMode
]
=
None
,
component_gauges
=
None
,
component_gauges
:
Any
=
None
,
)
->
AsyncGenerator
[
TensorRTLLMEngine
,
None
]:
)
->
AsyncGenerator
[
TensorRTLLMEngine
,
None
]:
"""Get TensorRT-LLM engine instance with load time tracking.
"""Get TensorRT-LLM engine instance with load time tracking.
...
...
components/src/dynamo/trtllm/health_check.py
View file @
f99b78f0
...
@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends.
...
@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends.
"""
"""
import
logging
import
logging
from
typing
import
Any
from
dynamo.health_check
import
HealthCheckPayload
from
dynamo.health_check
import
HealthCheckPayload
...
@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
...
@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class.
Provides TRT-LLM defaults and inherits environment override support from base class.
"""
"""
def
__init__
(
self
,
tokenizer
=
None
)
:
def
__init__
(
self
,
tokenizer
:
Any
=
None
)
->
None
:
"""
"""
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
...
...
components/src/dynamo/trtllm/logits_processing/adapter.py
View file @
f99b78f0
...
@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
...
@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
req_ids
:
int
,
req_ids
:
int
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
ids
:
List
[
List
[
int
]],
ids
:
List
[
List
[
int
]],
stream_ptr
,
stream_ptr
:
Optional
[
int
]
,
client_id
:
Optional
[
int
]
=
None
,
client_id
:
Optional
[
int
]
=
None
,
):
)
->
None
:
"""
"""
TensorRT-LLM logits processor interface.
TensorRT-LLM logits processor interface.
...
...
components/src/dynamo/trtllm/multimodal_processor.py
View file @
f99b78f0
...
@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol):
...
@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol):
the tokenizer's decode method not being found on a generic 'object' type.
the tokenizer's decode method not being found on a generic 'object' type.
"""
"""
def
decode
(
self
,
token_ids
:
List
[
int
])
->
str
:
def
decode
(
self
,
token_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
True
,
clean_up_tokenization_spaces
:
bool
=
True
,
)
->
str
:
...
...
...
...
components/src/dynamo/trtllm/publisher.py
View file @
f99b78f0
...
@@ -26,9 +26,10 @@ import threading
...
@@ -26,9 +26,10 @@ import threading
import
time
import
time
import
traceback
import
traceback
import
weakref
import
weakref
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
queue
import
Queue
from
queue
import
Queue
from
typing
import
Awaitable
,
Callable
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Awaitable
,
Callable
,
Dict
,
Optional
,
Union
import
msgpack
import
msgpack
import
zmq
import
zmq
...
@@ -87,7 +88,7 @@ class ZmqKvEventPublisher:
...
@@ -87,7 +88,7 @@ class ZmqKvEventPublisher:
Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
"""
"""
def
__init__
(
self
,
zmq_endpoint
:
str
,
kv_block_size
:
int
,
topic
:
str
=
""
):
def
__init__
(
self
,
zmq_endpoint
:
str
,
kv_block_size
:
int
,
topic
:
str
=
""
)
->
None
:
"""
"""
Initialize ZMQ publisher.
Initialize ZMQ publisher.
...
@@ -120,7 +121,7 @@ class ZmqKvEventPublisher:
...
@@ -120,7 +121,7 @@ class ZmqKvEventPublisher:
block_mm_infos
:
Optional
[
list
[
dict
|
None
]]
=
None
,
block_mm_infos
:
Optional
[
list
[
dict
|
None
]]
=
None
,
attention_dp_rank
:
int
=
0
,
attention_dp_rank
:
int
=
0
,
lora_name
:
Optional
[
str
]
=
None
,
lora_name
:
Optional
[
str
]
=
None
,
):
)
->
None
:
"""Publish a BlockStored event.
"""Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter.
Note: event_id is managed internally via self.sequence counter.
...
@@ -133,7 +134,7 @@ class ZmqKvEventPublisher:
...
@@ -133,7 +134,7 @@ class ZmqKvEventPublisher:
# Create event in the same format as vLLM's ZmqEventPublisher:
# Create event in the same format as vLLM's ZmqEventPublisher:
# All blocks should have the same size (kv_block_size)
# All blocks should have the same size (kv_block_size)
event
=
{
event
:
dict
[
str
,
Any
]
=
{
"type"
:
"BlockStored"
,
"type"
:
"BlockStored"
,
"block_hashes"
:
block_hashes_signed
,
"block_hashes"
:
block_hashes_signed
,
"parent_block_hash"
:
parent_hash_signed
,
"parent_block_hash"
:
parent_hash_signed
,
...
@@ -149,7 +150,9 @@ class ZmqKvEventPublisher:
...
@@ -149,7 +150,9 @@ class ZmqKvEventPublisher:
self
.
_publish_event
(
event
,
attention_dp_rank
)
self
.
_publish_event
(
event
,
attention_dp_rank
)
def
publish_removed
(
self
,
block_hashes
:
list
[
int
],
attention_dp_rank
:
int
=
0
):
def
publish_removed
(
self
,
block_hashes
:
list
[
int
],
attention_dp_rank
:
int
=
0
)
->
None
:
"""Publish a BlockRemoved event.
"""Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
Note: event_id is managed internally via self.sequence counter.
...
@@ -164,7 +167,7 @@ class ZmqKvEventPublisher:
...
@@ -164,7 +167,7 @@ class ZmqKvEventPublisher:
self
.
_publish_event
(
event
,
attention_dp_rank
)
self
.
_publish_event
(
event
,
attention_dp_rank
)
def
publish_all_cleared
(
self
):
def
publish_all_cleared
(
self
)
->
None
:
"""Publish an AllBlocksCleared event."""
"""Publish an AllBlocksCleared event."""
event
=
{
"type"
:
"AllBlocksCleared"
}
event
=
{
"type"
:
"AllBlocksCleared"
}
self
.
_publish_event
(
event
)
self
.
_publish_event
(
event
)
...
@@ -197,7 +200,7 @@ class ZmqKvEventPublisher:
...
@@ -197,7 +200,7 @@ class ZmqKvEventPublisher:
except
Exception
as
e
:
except
Exception
as
e
:
logging
.
error
(
f
"Failed to publish ZMQ event:
{
e
}
"
,
exc_info
=
True
)
logging
.
error
(
f
"Failed to publish ZMQ event:
{
e
}
"
,
exc_info
=
True
)
def
shutdown
(
self
):
def
shutdown
(
self
)
->
None
:
"""Shutdown the ZMQ publisher."""
"""Shutdown the ZMQ publisher."""
if
self
.
socket
:
if
self
.
socket
:
self
.
socket
.
close
()
self
.
socket
.
close
()
...
@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread):
...
@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread):
self
.
_stop_event
=
threading
.
Event
()
self
.
_stop_event
=
threading
.
Event
()
def
set_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
):
def
set_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
)
->
None
:
self
.
loop
=
loop
self
.
loop
=
loop
def
run
(
self
):
def
run
(
self
)
->
None
:
while
not
self
.
_stop_event
.
is_set
():
while
not
self
.
_stop_event
.
is_set
():
task
:
Optional
[
task
:
Optional
[
Union
[
Callable
[...,
Awaitable
[
bool
]],
weakref
.
WeakMethod
]
Union
[
Callable
[...,
Awaitable
[
bool
]],
weakref
.
WeakMethod
]
...
@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread):
...
@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread):
logging
.
info
(
f
"Thread
{
self
.
name
}
stopped."
)
logging
.
info
(
f
"Thread
{
self
.
name
}
stopped."
)
def
stop
(
self
):
def
stop
(
self
)
->
None
:
self
.
_stop_event
.
set
()
self
.
_stop_event
.
set
()
if
self
.
_current_future
and
not
self
.
_current_future
.
done
():
if
self
.
_current_future
and
not
self
.
_current_future
.
done
():
self
.
_current_future
.
cancel
()
self
.
_current_future
.
cancel
()
...
@@ -297,16 +300,16 @@ class Publisher:
...
@@ -297,16 +300,16 @@ class Publisher:
def
__init__
(
def
__init__
(
self
,
self
,
endpoint
,
endpoint
:
Any
,
engine
,
engine
:
Any
,
worker_id
,
worker_id
:
Any
,
kv_block_size
,
kv_block_size
:
int
,
metrics_labels
,
metrics_labels
:
Any
,
component_gauges
:
LLMBackendMetrics
,
component_gauges
:
LLMBackendMetrics
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
enable_local_indexer
:
bool
=
False
,
enable_local_indexer
:
bool
=
False
,
metrics_collector
=
None
,
metrics_collector
:
Any
=
None
,
):
)
->
None
:
self
.
endpoint
=
endpoint
self
.
endpoint
=
endpoint
self
.
engine
=
engine
self
.
engine
=
engine
self
.
worker_id
=
worker_id
self
.
worker_id
=
worker_id
...
@@ -324,7 +327,7 @@ class Publisher:
...
@@ -324,7 +327,7 @@ class Publisher:
self
.
processing_initial_created_events
=
True
self
.
processing_initial_created_events
=
True
# Needed by the events and metrics publishers
# Needed by the events and metrics publishers
self
.
metrics_publisher
=
None
self
.
metrics_publisher
:
Optional
[
WorkerMetricsPublisher
]
=
None
self
.
kv_event_publishers
:
Optional
[
self
.
kv_event_publishers
:
Optional
[
Dict
[
int
,
KvEventPublisher
]
Dict
[
int
,
KvEventPublisher
]
]
=
None
# One per attention_dp_rank
]
=
None
# One per attention_dp_rank
...
@@ -359,7 +362,7 @@ class Publisher:
...
@@ -359,7 +362,7 @@ class Publisher:
return
return
await
self
.
metrics_publisher
.
create_endpoint
(
self
.
endpoint
)
await
self
.
metrics_publisher
.
create_endpoint
(
self
.
endpoint
)
def
initialize
(
self
):
def
initialize
(
self
)
->
None
:
# Setup the metrics publisher
# Setup the metrics publisher
self
.
metrics_publisher
=
WorkerMetricsPublisher
()
self
.
metrics_publisher
=
WorkerMetricsPublisher
()
self
.
_init_publish_metrics_thread
()
self
.
_init_publish_metrics_thread
()
...
@@ -474,6 +477,7 @@ class Publisher:
...
@@ -474,6 +477,7 @@ class Publisher:
kv_total_blocks
=
stat
[
"kvCacheStats"
][
"maxNumBlocks"
]
kv_total_blocks
=
stat
[
"kvCacheStats"
][
"maxNumBlocks"
]
logging
.
debug
(
f
"Publishing stats: kv_active_blocks:
{
kv_active_blocks
}
"
)
logging
.
debug
(
f
"Publishing stats: kv_active_blocks:
{
kv_active_blocks
}
"
)
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert
self
.
metrics_publisher
is
not
None
self
.
metrics_publisher
.
publish
(
None
,
kv_active_blocks
)
self
.
metrics_publisher
.
publish
(
None
,
kv_active_blocks
)
# Publish Prometheus metrics
# Publish Prometheus metrics
...
@@ -680,7 +684,7 @@ class Publisher:
...
@@ -680,7 +684,7 @@ class Publisher:
elif
data
[
"type"
]
==
"created"
and
self
.
processing_initial_created_events
:
elif
data
[
"type"
]
==
"created"
and
self
.
processing_initial_created_events
:
self
.
update_max_window_size
(
event
)
self
.
update_max_window_size
(
event
)
def
start
(
self
):
def
start
(
self
)
->
None
:
if
(
if
(
self
.
publish_kv_cache_events_thread
self
.
publish_kv_cache_events_thread
and
not
self
.
publish_kv_cache_events_thread
.
is_alive
()
and
not
self
.
publish_kv_cache_events_thread
.
is_alive
()
...
@@ -698,13 +702,13 @@ class Publisher:
...
@@ -698,13 +702,13 @@ class Publisher:
self
.
publish_stats_thread
.
start
()
self
.
publish_stats_thread
.
start
()
logging
.
debug
(
"Started stats thread"
)
logging
.
debug
(
"Started stats thread"
)
def
check_error_queue
(
self
):
def
check_error_queue
(
self
)
->
Optional
[
Exception
]
:
if
not
self
.
error_queue
.
empty
():
if
not
self
.
error_queue
.
empty
():
logging
.
error
(
"Error in publishers error queue"
)
logging
.
error
(
"Error in publishers error queue"
)
return
self
.
error_queue
.
get
()
return
self
.
error_queue
.
get
()
return
None
return
None
async
def
cleanup
(
self
):
async
def
cleanup
(
self
)
->
None
:
"""Cleanup threads and resources"""
"""Cleanup threads and resources"""
self
.
_stop_event
.
set
()
self
.
_stop_event
.
set
()
# Add timeout to prevent hanging
# Add timeout to prevent hanging
...
@@ -729,7 +733,7 @@ class Publisher:
...
@@ -729,7 +733,7 @@ class Publisher:
if
self
.
zmq_kv_event_publisher
:
if
self
.
zmq_kv_event_publisher
:
self
.
zmq_kv_event_publisher
.
shutdown
()
self
.
zmq_kv_event_publisher
.
shutdown
()
def
update_max_window_size
(
self
,
event
)
:
def
update_max_window_size
(
self
,
event
:
dict
)
->
None
:
if
"window_size"
in
event
:
if
"window_size"
in
event
:
window_size
=
event
[
"window_size"
]
window_size
=
event
[
"window_size"
]
if
self
.
max_window_size
is
None
or
window_size
>
self
.
max_window_size
:
if
self
.
max_window_size
is
None
or
window_size
>
self
.
max_window_size
:
...
@@ -744,7 +748,7 @@ class Publisher:
...
@@ -744,7 +748,7 @@ class Publisher:
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# so we can use the "created" event to identify the max_window_size of the global
# so we can use the "created" event to identify the max_window_size of the global
# attention layer in the model engine.
# attention layer in the model engine.
def
should_drop_event
(
self
,
event
)
:
def
should_drop_event
(
self
,
event
:
dict
)
->
bool
:
# There are two cases for KV event filtering:
# There are two cases for KV event filtering:
#
#
# 1. If "window_size" is NOT in the KV event:
# 1. If "window_size" is NOT in the KV event:
...
@@ -768,16 +772,16 @@ class Publisher:
...
@@ -768,16 +772,16 @@ class Publisher:
@
asynccontextmanager
@
asynccontextmanager
async
def
get_publisher
(
async
def
get_publisher
(
endpoint
,
endpoint
:
Any
,
engine
,
engine
:
Any
,
worker_id
,
worker_id
:
Any
,
kv_block_size
,
kv_block_size
:
int
,
metrics_labels
,
metrics_labels
:
Any
,
component_gauges
:
LLMBackendMetrics
,
component_gauges
:
LLMBackendMetrics
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
zmq_endpoint
:
Optional
[
str
]
=
None
,
enable_local_indexer
:
bool
=
False
,
enable_local_indexer
:
bool
=
False
,
metrics_collector
=
None
,
metrics_collector
:
Any
=
None
,
):
)
->
AsyncGenerator
[
Publisher
,
None
]
:
publisher
=
Publisher
(
publisher
=
Publisher
(
endpoint
,
endpoint
,
engine
,
engine
,
...
...
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
View file @
f99b78f0
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import
logging
import
logging
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
typing
import
Optional
from
dynamo._core
import
Context
from
dynamo._core
import
Context
...
@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase):
...
@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
_encoder_cache
=
encoder_cache
self
.
_encoder_cache
=
encoder_cache
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""Generate response, optionally using remote encoder for multimodal."""
"""Generate response, optionally using remote encoder for multimodal."""
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
f99b78f0
...
@@ -18,9 +18,10 @@ import dataclasses
...
@@ -18,9 +18,10 @@ import dataclasses
import
logging
import
logging
import
os
import
os
import
re
import
re
from
collections.abc
import
AsyncGenerator
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Any
,
AsyncGenerator
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
torch
import
torch
from
tensorrt_llm.executor.result
import
GenerationResult
from
tensorrt_llm.executor.result
import
GenerationResult
...
@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler):
...
@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler):
self
.
shutdown_event
=
config
.
shutdown_event
self
.
shutdown_event
=
config
.
shutdown_event
self
.
disable_request_abort
=
config
.
disable_request_abort
self
.
disable_request_abort
=
config
.
disable_request_abort
def
check_error
(
self
,
result
:
dict
):
def
check_error
(
self
,
result
:
dict
)
->
bool
:
"""
"""
Check if there is an error in the result.
Check if there is an error in the result.
"""
"""
...
@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler):
...
@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler):
Raise GeneratorExit if shutdown event is triggered.
Raise GeneratorExit if shutdown event is triggered.
"""
"""
try
:
try
:
cancellation_triggers
=
[
cancellation_triggers
:
list
[
asyncio
.
Future
[
Any
]]
=
[
context
.
async_killed_or_stopped
(),
# Request cancellation
context
.
async_killed_or_stopped
(),
# Request cancellation
]
]
# Shutdown cancellation
# Shutdown cancellation
...
@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler):
...
@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler):
Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
"""
"""
disaggregated_params
=
None
disaggregated_params
=
None
epd_metadata
=
{}
epd_metadata
:
dict
[
str
,
Any
]
=
{}
# PREFILL mode: setup context_only params
# PREFILL mode: setup context_only params
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
if
self
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
...
@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler):
...
@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler):
context
:
Context
,
context
:
Context
,
embeddings
:
Optional
[
Union
[
torch
.
Tensor
,
dict
]]
=
None
,
embeddings
:
Optional
[
Union
[
torch
.
Tensor
,
dict
]]
=
None
,
ep_disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
,
ep_disaggregated_params
:
Optional
[
DisaggregatedParams
]
=
None
,
):
)
->
AsyncGenerator
[
dict
,
None
]
:
"""
"""
Generate responses based on the disaggregation mode in the request.
Generate responses based on the disaggregation mode in the request.
...
...
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
f99b78f0
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
logging
import
logging
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
typing
import
Optional
from
dynamo._core
import
Context
from
dynamo._core
import
Context
...
@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase):
...
@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase):
self
.
model_type
=
self
.
multimodal_processor
.
model_type
self
.
model_type
=
self
.
multimodal_processor
.
model_type
self
.
tokenizer
=
self
.
multimodal_processor
.
tokenizer
self
.
tokenizer
=
self
.
multimodal_processor
.
tokenizer
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
logging
.
debug
(
f
"New Request ID:
{
context
.
id
()
}
"
)
logging
.
debug
(
f
"New Request ID:
{
context
.
id
()
}
"
)
if
self
.
multimodal_processor
is
None
:
if
self
.
multimodal_processor
is
None
:
logging
.
error
(
"encode handler: no multimodal_processor configured"
)
logging
.
error
(
"encode handler: no multimodal_processor configured"
)
...
@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase):
...
@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase):
encode_response
,
self
.
connector
encode_response
,
self
.
connector
)
)
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""
"""
Prefill worker: process prompt and return disaggregated_params.
Prefill worker: process prompt and return disaggregated_params.
Frontend routes to decode workers automatically.
Frontend routes to decode workers automatically.
...
@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase):
...
@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase):
def
__init__
(
self
,
config
:
RequestHandlerConfig
):
def
__init__
(
self
,
config
:
RequestHandlerConfig
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
):
async
def
generate
(
self
,
request
:
dict
,
context
:
Context
)
->
AsyncGenerator
[
dict
,
None
]:
"""
"""
Decode worker: generate tokens using disaggregated_params from prefill.
Decode worker: generate tokens using disaggregated_params from prefill.
If disaggregated_params is present, prefill was done. Otherwise generate normally.
If disaggregated_params is present, prefill was done. Otherwise generate normally.
...
...
components/src/dynamo/trtllm/tests/test_trtllm_handler_base.py
View file @
f99b78f0
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
asyncio
import
asyncio
import
re
as
re_mod
import
re
as
re_mod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
from
unittest
import
mock
from
unittest
import
mock
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
...
@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice:
...
@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice:
def
test_empty_choice_ignored
(
self
):
def
test_empty_choice_ignored
(
self
):
"""Empty choice list should not produce a regex."""
"""Empty choice list should not produce a regex."""
sampling_params
=
MockSamplingParams
()
sampling_params
=
MockSamplingParams
()
request
=
{
request
:
dict
[
str
,
Any
]
=
{
"sampling_options"
:
{
"sampling_options"
:
{
"guided_decoding"
:
{
"guided_decoding"
:
{
"choice"
:
[],
"choice"
:
[],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment