Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
a9b74dc2
Unverified
Commit
a9b74dc2
authored
Jan 26, 2026
by
jh-nv
Committed by
GitHub
Jan 26, 2026
Browse files
feat: request migration for trtllm (#5599)
parent
66c36996
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
21 deletions
+65
-21
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+11
-4
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+54
-17
No files found.
components/src/dynamo/trtllm/main.py
View file @
a9b74dc2
...
@@ -71,8 +71,9 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
...
@@ -71,8 +71,9 @@ DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
configure_dynamo_logging
()
configure_dynamo_logging
()
async
def
graceful_shutdown
(
runtime
):
async
def
graceful_shutdown
(
runtime
,
shutdown_event
):
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
shutdown_event
.
set
()
runtime
.
shutdown
()
runtime
.
shutdown
()
logging
.
info
(
"DistributedRuntime shutdown complete"
)
logging
.
info
(
"DistributedRuntime shutdown complete"
)
...
@@ -128,6 +129,9 @@ async def worker():
...
@@ -128,6 +129,9 @@ async def worker():
config
=
cmd_line_args
()
config
=
cmd_line_args
()
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
# Create shutdown event
shutdown_event
=
asyncio
.
Event
()
# Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics)
# Enable NATS based on use_kv_events flag (derived from publish_events_and_metrics)
runtime
=
DistributedRuntime
(
runtime
=
DistributedRuntime
(
loop
,
config
.
store_kv
,
config
.
request_plane
,
config
.
use_kv_events
loop
,
config
.
store_kv
,
config
.
request_plane
,
config
.
use_kv_events
...
@@ -136,17 +140,19 @@ async def worker():
...
@@ -136,17 +140,19 @@ async def worker():
# Set up signal handler for graceful shutdown
# Set up signal handler for graceful shutdown
def
signal_handler
():
def
signal_handler
():
# Schedule the shutdown coroutine instead of calling it directly
# Schedule the shutdown coroutine instead of calling it directly
asyncio
.
create_task
(
graceful_shutdown
(
runtime
))
asyncio
.
create_task
(
graceful_shutdown
(
runtime
,
shutdown_event
))
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
loop
.
add_signal_handler
(
sig
,
signal_handler
)
loop
.
add_signal_handler
(
sig
,
signal_handler
)
logging
.
info
(
"Signal handlers set up for graceful shutdown"
)
logging
.
info
(
"Signal handlers set up for graceful shutdown"
)
await
init
(
runtime
,
config
)
await
init
(
runtime
,
config
,
shutdown_event
)
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Instantiate and serve
Instantiate and serve
"""
"""
...
@@ -425,6 +431,7 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -425,6 +431,7 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime
=
runtime
,
# Pass runtime for graceful shutdown
runtime
=
runtime
,
# Pass runtime for graceful shutdown
metrics_collector
=
metrics_collector
,
metrics_collector
=
metrics_collector
,
kv_block_size
=
config
.
kv_block_size
,
kv_block_size
=
config
.
kv_block_size
,
shutdown_event
=
shutdown_event
,
)
)
# Register the model with runtime config
# Register the model with runtime config
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
a9b74dc2
...
@@ -67,6 +67,7 @@ class RequestHandlerConfig:
...
@@ -67,6 +67,7 @@ class RequestHandlerConfig:
]
=
None
# DistributedRuntime reference for graceful shutdown
]
=
None
# DistributedRuntime reference for graceful shutdown
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
metrics_collector
:
Optional
[
Any
]
=
None
# TensorRT-LLM MetricsCollector
kv_block_size
:
int
=
32
kv_block_size
:
int
=
32
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
class
HandlerBase
:
class
HandlerBase
:
...
@@ -88,6 +89,7 @@ class HandlerBase:
...
@@ -88,6 +89,7 @@ class HandlerBase:
# Store runtime reference for graceful shutdown
# Store runtime reference for graceful shutdown
self
.
runtime
=
config
.
runtime
self
.
runtime
=
config
.
runtime
self
.
kv_block_size
:
int
=
config
.
kv_block_size
self
.
kv_block_size
:
int
=
config
.
kv_block_size
self
.
shutdown_event
=
config
.
shutdown_event
def
check_error
(
self
,
result
:
dict
):
def
check_error
(
self
,
result
:
dict
):
"""
"""
...
@@ -170,18 +172,49 @@ class HandlerBase:
...
@@ -170,18 +172,49 @@ class HandlerBase:
return
log_probs
if
log_probs
else
None
,
top_logprobs
if
top_logprobs
else
None
return
log_probs
if
log_probs
else
None
,
top_logprobs
if
top_logprobs
else
None
async
def
_handle_cancellation
(
async
def
_handle_cancellation
_and_shutdown
(
self
,
generation_result
:
GenerationResult
,
context
:
Context
self
,
generation_result
:
GenerationResult
,
context
:
Context
):
):
"""Background task to handle cancellation by monitoring context state."""
"""
Background task to handle cancellation and shutdown by monitoring both signals.
Aborts the generation when either signal fires. Raises GeneratorExit if shutdown was the trigger; otherwise returns None.
"""
try
:
try
:
# Wait asynchronously for cancellation signal instead of polling
cancellation_task
=
context
.
async_killed_or_stopped
()
await
context
.
async_killed_or_stopped
()
# Build list of futures/tasks to wait for
wait_for
=
[
cancellation_task
]
shutdown_task
=
None
if
self
.
shutdown_event
:
# Create task for shutdown monitoring and add to wait list
shutdown_task
=
asyncio
.
create_task
(
self
.
shutdown_event
.
wait
())
wait_for
.
append
(
shutdown_task
)
# Wait for whichever happens first
done
,
pending
=
await
asyncio
.
wait
(
wait_for
,
return_when
=
asyncio
.
FIRST_COMPLETED
,
)
# Cancel the pending task/future
for
task
in
pending
:
task
.
cancel
()
try
:
await
task
except
asyncio
.
CancelledError
:
pass
# Abort the generation
# Abort the generation
generation_result
.
abort
()
generation_result
.
abort
()
logging
.
debug
(
f
"Aborted Request ID:
{
context
.
id
()
}
"
)
logging
.
debug
(
f
"Aborted Request ID:
{
context
.
id
()
}
"
)
# If the shutdown event (rather than cancellation) fired, raise GeneratorExit
if
shutdown_task
and
shutdown_task
in
done
:
raise
GeneratorExit
(
"Engine was shut down during generation."
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
# Task was cancelled, which is expected when generation completes
# Task was cancelled, which is expected when generation completes
normally
pass
pass
@
asynccontextmanager
@
asynccontextmanager
...
@@ -189,28 +222,32 @@ class HandlerBase:
...
@@ -189,28 +222,32 @@ class HandlerBase:
self
,
generation_result
:
GenerationResult
,
context
:
Context
self
,
generation_result
:
GenerationResult
,
context
:
Context
)
->
AsyncGenerator
[
asyncio
.
Task
,
None
]:
)
->
AsyncGenerator
[
asyncio
.
Task
,
None
]:
"""
"""
Context manager for monitoring request cancellation.
Context manager for monitoring request cancellation and shutdown.
Automatically creates a background task to monitor for cancellation
and shutdown events, cleaning it up when the context exits.
Automatically creates a background task to monitor for cancellation and
If shutdown event was triggered, raises GeneratorExit on exit.
cleans it up when the context exits.
Yields:
Yields:
asyncio.Task: The
cancellation
monitoring task
asyncio.Task: The monitoring task
"""
"""
cancellation
_task
=
asyncio
.
create_task
(
monitor
_task
=
asyncio
.
create_task
(
self
.
_handle_cancellation
(
generation_result
,
context
)
self
.
_handle_cancellation
_and_shutdown
(
generation_result
,
context
)
)
)
try
:
try
:
yield
cancellation
_task
yield
monitor
_task
finally
:
finally
:
# Clean up the background
cancellation
task
# Clean up the background
monitoring
task
if
not
cancellation
_task
.
done
():
if
not
monitor
_task
.
done
():
cancellation
_task
.
cancel
()
monitor
_task
.
cancel
()
try
:
try
:
await
cancellation
_task
await
monitor
_task
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
pass
pass
else
:
monitor_task
.
result
()
def
_decode_disaggregated_params_from_prefill
(
def
_decode_disaggregated_params_from_prefill
(
self
,
prefill_result
:
dict
self
,
prefill_result
:
dict
...
@@ -653,7 +690,7 @@ class HandlerBase:
...
@@ -653,7 +690,7 @@ class HandlerBase:
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
)
)
# Use the context manager to handle cancellation monitoring
# Use the context manager to handle cancellation
and shutdown
monitoring
async
with
self
.
_cancellation_monitor
(
generation_result
,
context
):
async
with
self
.
_cancellation_monitor
(
generation_result
,
context
):
async
for
res
in
generation_result
:
async
for
res
in
generation_result
:
# TRTLLM engine needs to start generating tokens first before stats
# TRTLLM engine needs to start generating tokens first before stats
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment