Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9acaa8d1
Unverified
Commit
9acaa8d1
authored
May 27, 2025
by
Tanmay Verma
Committed by
GitHub
May 27, 2025
Browse files
feat: Add metrics and event publishers (#1192)
parent
b8272a98
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
473 additions
and
78 deletions
+473
-78
launch/dynamo-run/src/subprocess/trtllm/engine.py
launch/dynamo-run/src/subprocess/trtllm/engine.py
+52
-0
launch/dynamo-run/src/subprocess/trtllm/publishers.py
launch/dynamo-run/src/subprocess/trtllm/publishers.py
+344
-0
launch/dynamo-run/src/subprocess/trtllm_inc.py
launch/dynamo-run/src/subprocess/trtllm_inc.py
+77
-78
No files found.
launch/dynamo-run/src/subprocess/trtllm/engine.py
0 → 100644
View file @
9acaa8d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
from
contextlib
import
asynccontextmanager
from
typing
import
AsyncGenerator
,
Optional
from
tensorrt_llm
import
LLM
# Configure the root logger at DEBUG level as an import-time side effect of
# this module. basicConfig() is a no-op when the root logger already has
# handlers, so whichever module is imported first wins.
logging.basicConfig(level=logging.DEBUG)
class TensorRTLLMEngine:
    """Lazily constructed holder for a TensorRT-LLM ``LLM`` instance.

    The engine is not created in ``__init__``; call ``initialize()`` first and
    ``cleanup()`` when done. ``engine_args`` is a mapping that must contain a
    ``"model"`` entry; every other entry is forwarded to ``LLM`` as a keyword
    argument.
    """

    def __init__(self, engine_args):
        self.engine_args = engine_args
        # None until initialize() has run, and again after cleanup().
        self._llm: Optional[LLM] = None

    async def initialize(self):
        """Create the underlying ``LLM`` if it does not exist yet.

        Raises:
            KeyError: if ``engine_args`` has no ``"model"`` entry.
        """
        if self._llm is not None:
            return
        # Work on a shallow copy so the caller's mapping is left untouched and
        # the engine can be initialized again after cleanup().
        llm_kwargs = dict(self.engine_args)
        model = llm_kwargs.pop("model")
        self._llm = LLM(model=model, **llm_kwargs)

    async def cleanup(self):
        """Shut the ``LLM`` down (if any) and drop the reference to it.

        Shutdown errors are logged rather than raised so that cleanup always
        leaves the engine in the "not initialized" state.
        """
        if self._llm is None:
            return
        try:
            self._llm.shutdown()
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")
        finally:
            self._llm = None

    @property
    def llm(self):
        """Return the live ``LLM`` instance.

        Raises:
            RuntimeError: if ``initialize()`` has not been called, or
                ``cleanup()`` has already run.
        """
        if self._llm is None:
            raise RuntimeError("Engine not initialized")
        return self._llm
@asynccontextmanager
async def get_llm_engine(engine_args) -> AsyncGenerator[TensorRTLLMEngine, None]:
    """Async context manager that yields an initialized ``TensorRTLLMEngine``.

    The engine is built from ``engine_args`` and initialized before the body
    of the ``async with`` runs. Any exception raised during initialization or
    inside the body is logged and re-raised. ``cleanup()`` is awaited on every
    exit path.
    """
    wrapper = TensorRTLLMEngine(engine_args)
    try:
        try:
            await wrapper.initialize()
            yield wrapper
        except Exception as err:
            logging.error(f"Error in engine context: {err}")
            raise
    finally:
        await wrapper.cleanup()
launch/dynamo-run/src/subprocess/trtllm/publishers.py
0 → 100644
View file @
9acaa8d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
concurrent.futures
import
logging
import
threading
import
traceback
import
weakref
from
queue
import
Queue
from
typing
import
Callable
,
Optional
,
Union
from
dynamo.llm
import
KvEventPublisher
,
KvMetricsPublisher
# Configure the root logger at DEBUG level as an import-time side effect of
# this module. basicConfig() is a no-op when the root logger already has
# handlers, so this only takes effect if nothing configured logging earlier.
logging.basicConfig(level=logging.DEBUG)
class ManagedThread(threading.Thread):
    """
    Daemon thread that keeps re-running an async task on an asyncio loop.

    Each iteration calls ``task(**kwargs)``, submits the result to ``loop``
    with ``asyncio.run_coroutine_threadsafe`` and blocks until it finishes.
    The task's return value is ignored. The loop ends when ``stop()`` is
    called, the task is ``None`` (or an expired ``WeakMethod``), no event loop
    has been set, or the in-flight future is cancelled. Any other exception is
    logged, pushed onto ``error_queue`` when one was given, and the task is
    run again.
    """

    def __init__(
        self,
        task: Optional[Union[Callable[..., bool], weakref.WeakMethod]],
        error_queue: Optional[Queue] = None,
        name: Optional[str] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        super().__init__(name=name)
        self.daemon = True
        self.task, self.kwargs = task, kwargs
        self.error_queue, self.loop = error_queue, loop
        # Future of the run currently in flight; stop() cancels it.
        self._current_future: Optional[concurrent.futures.Future] = None
        self._stop_event = threading.Event()

    def set_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop that subsequent task runs are submitted to."""
        self.loop = loop

    def run(self):
        """Thread body: run the task repeatedly until asked to stop."""
        while not self._stop_event.is_set():
            target: Optional[
                Union[Callable[..., bool], weakref.WeakMethod]
            ] = self.task
            if isinstance(target, weakref.WeakMethod):
                target = target()
                if target is None:
                    # Normally, this should not happen.
                    logging.warning("WeakMethod is expired.")
                    break
            elif target is None:
                break

            event_loop = self.loop
            if event_loop is None:
                logging.error("[ManagedThread] Loop not initialized!")
                break

            try:
                pending = asyncio.run_coroutine_threadsafe(
                    target(**self.kwargs), event_loop
                )
                self._current_future = pending
                pending.result()
            except (asyncio.CancelledError, concurrent.futures.CancelledError):
                logging.debug(f"Thread {self.name} was cancelled")
                break
            except Exception as e:
                logging.error(
                    f"Error in thread {self.name}: {e}\n{traceback.format_exc()}"
                )
                if self.error_queue is not None:
                    self.error_queue.put(e)

        logging.info(f"Thread {self.name} stopped.")

    def stop(self):
        """Request termination and cancel the in-flight run, if any."""
        self._stop_event.set()
        in_flight = self._current_future
        if in_flight is not None and not in_flight.done():
            in_flight.cancel()
class Publishers:
    """
    A class to retrieve stats and kv cache events from TRTLLM engine and publish them to the metrics and events publishers.

    The constructor must run while an asyncio event loop is running, because
    ``_setup`` schedules the metrics endpoint creation with
    ``asyncio.create_task``. The constructor only prepares the two publishing
    threads; ``start_publish_threads`` starts them and ``cleanup`` stops them.
    Errors raised inside the threads are collected in ``error_queue`` and can
    be polled with ``check_error_queue``.
    """

    def __init__(self, component, engine, kv_listener, worker_id, kv_block_size):
        self.component = component
        self.engine = engine
        self.kv_listener = kv_listener
        self.worker_id = worker_id
        self.kv_block_size = kv_block_size

        # Needed by the events and metrics publishers
        self.metrics_publisher = None
        self.kv_event_publisher = None
        self.publish_kv_cache_events_thread = None
        self.publish_stats_thread = None
        # A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes.
        # It is used to prevent sending remove event to kv router since partial blocks are not stored.
        self.partial_block_hashes = set()
        self.error_queue: Queue = Queue()
        self._stop_event = threading.Event()
        # Strong reference to the endpoint-creation task. The event loop only
        # keeps weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._endpoint_task = None
        self._setup()

    async def _create_metrics_publisher_endpoint(self):
        """Create the metrics endpoint on ``self.component``."""
        logging.debug("Creating metrics publisher endpoint")
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return
        await self.metrics_publisher.create_endpoint(self.component)

    def _on_endpoint_task_done(self, task):
        """Log the outcome of the endpoint-creation task.

        Failures are only logged (not raised or queued) so that a missing
        metrics endpoint does not take request handling down with it.
        """
        if task.cancelled():
            logging.debug("metrics publisher endpoint creation was cancelled")
            return
        error = task.exception()
        if error is not None:
            logging.error(f"Failed to create metrics publisher endpoint: {error}")
            return
        logging.debug("metrics publisher endpoint created")

    def _setup(self):
        """Create both publishers and prepare (but do not start) their threads."""
        # Setup the metrics publisher
        self.metrics_publisher = KvMetricsPublisher()
        self._init_publish_metrics_thread()
        self._endpoint_task = asyncio.create_task(
            self._create_metrics_publisher_endpoint()
        )
        self._endpoint_task.add_done_callback(self._on_endpoint_task_done)

        # Setup the kv cache events publisher
        self.kv_event_publisher = KvEventPublisher(
            self.kv_listener, self.worker_id, self.kv_block_size
        )
        self._init_publish_kv_cache_events_thread()

    def _init_publish_metrics_thread(self):
        """Publish one set of dummy metrics and prepare the stats thread."""
        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return

        # Need to publish stats once so that worker can be selected.
        # Publishing some dummy values...
        self.metrics_publisher.publish(
            0,  # request_active_slots
            4,  # request_total_slots
            0,  # kv_active_block
            4,  # kv_total_blocks
            0,  # num_requests_waiting
            0.0,  # gpu_cache_usage_perc
            0.0,  # gpu_prefix_cache_hit_rate
        )

        # Prepare threads for publishing stats but don't start them yet.
        # TRTLLM needs to start generating tokens first before stats
        # can be retrieved.
        self.publish_stats_thread = ManagedThread(
            self._publish_stats_task,
            error_queue=self.error_queue,
            name="publish_stats_thread",
        )

    def _init_publish_kv_cache_events_thread(self):
        """Prepare the kv cache events thread."""
        if self.kv_event_publisher is None:
            logging.error("KV event publisher not initialized!")
            return

        # Prepare threads for publishing kv cache events but don't start them yet.
        # TRTLLM needs to start generating tokens first before kv cache events
        # can be retrieved.
        self.publish_kv_cache_events_thread = ManagedThread(
            self._publish_kv_cache_events_task,
            error_queue=self.error_queue,
            name="publish_kv_cache_events_thread",
        )

    async def _publish_stats_task(self):
        """
        Publish stats to the metrics publisher.

        Drains ``engine.llm.get_stats_async(timeout=5)`` and publishes one
        metrics update per stat record.

        Returns:
            True once the stats iterator is exhausted, False if the engine or
            the metrics publisher is missing.
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        if self.metrics_publisher is None:
            logging.error("KV metrics publisher not initialized!")
            return False

        stats = self.engine.llm.get_stats_async(timeout=5)
        async for stat in stats:
            kv_stats = stat["kvCacheStats"]
            request_active_slots = stat["numActiveRequests"]
            request_total_slots = stat["maxNumActiveRequests"]
            kv_active_block = kv_stats["usedNumBlocks"]
            kv_total_blocks = kv_stats["maxNumBlocks"]
            reused_blocks = kv_stats["reusedBlocks"]
            free_num_blocks = kv_stats["freeNumBlocks"]
            alloc_total_blocks = kv_stats["allocTotalBlocks"]
            alloc_new_blocks = kv_stats["allocNewBlocks"]
            # NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
            num_requests_waiting = (
                stat["numQueuedRequests"]
                + stat["inflightBatchingStats"]["numPausedRequests"]
            )
            # Report 0.0 instead of raising ZeroDivisionError when the engine
            # reports no KV cache blocks.
            gpu_cache_usage_perc = (
                alloc_total_blocks / kv_total_blocks if kv_total_blocks else 0.0
            )
            gpu_prefix_cache_hit_rate = kv_stats["cacheHitRate"]

            logging.debug(
                f"Publishing stats: request_active_slots: {request_active_slots}, request_total_slots: {request_total_slots}, kv_active_block: {kv_active_block}, kv_total_blocks: {kv_total_blocks}, num_requests_waiting: {num_requests_waiting}, reused_blocks: {reused_blocks}, freeNumBlocks: {free_num_blocks}, allocTotalBlocks: {alloc_total_blocks}, allocNewBlocks: {alloc_new_blocks}, gpu_cache_usage_perc: {gpu_cache_usage_perc}, gpu_prefix_cache_hit_rate: {gpu_prefix_cache_hit_rate}"
            )

            self.metrics_publisher.publish(
                request_active_slots,
                request_total_slots,
                kv_active_block,
                kv_total_blocks,
                num_requests_waiting,
                gpu_cache_usage_perc,
                gpu_prefix_cache_hit_rate,
            )

        return True

    def _publish_stored_event(self, event_id, data):
        """Publish one "stored" event; return False if a block is oversized.

        Blocks are consumed in order up to (not including) the first partial
        block, whose hash is remembered in ``partial_block_hashes``. Nothing
        is published when a block holds more than ``kv_block_size`` tokens.
        """
        parent_hash = data["parent_hash"]
        token_ids = []
        num_block_tokens = []
        block_hashes = []
        for block in data["blocks"]:
            tokens = block["tokens"]
            token_num_in_block = len(tokens)
            block_hash = block["block_hash"]
            if token_num_in_block > self.kv_block_size:
                logging.error(
                    f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self.kv_block_size}"
                )
                return False
            if token_num_in_block < self.kv_block_size:
                logging.debug(
                    f"Early stop when block {block_hash} containing {token_num_in_block} tokens not equal to kv_block_size {self.kv_block_size}"
                )
                self.partial_block_hashes.add(block_hash)
                break
            num_block_tokens.append(token_num_in_block)
            block_hashes.append(block_hash)
            token_ids.extend(int(token["token_id"]) for token in tokens)

        # Note: Currently data does not have lora_id.
        # Using 0 as default value. If later data has
        # lora_id, we need to verify if this is correct.
        lora_id = data.get("lora_id", 0)

        logging.debug(
            f"publish stored event: event_id: {event_id}, token_ids: {token_ids}, num_block_tokens: {num_block_tokens}, block_hashes: {block_hashes}, lora_id: {lora_id}, parent_hash: {parent_hash}"
        )
        self.kv_event_publisher.publish_stored(
            event_id,
            token_ids,
            num_block_tokens,
            block_hashes,
            lora_id,
            parent_hash,
        )
        return True

    def _publish_removed_event(self, event_id, data):
        """Publish one "removed" event, leaving out partial blocks."""
        block_hashes = []
        for block_hash in data["block_hashes"]:
            if block_hash in self.partial_block_hashes:
                logging.debug(
                    f"Skipping removing block hash {block_hash} since it is a partial block"
                )
                self.partial_block_hashes.remove(block_hash)
                continue
            block_hashes.append(block_hash)

        logging.debug(
            f"publish removed event: event_id: {event_id}, block_hashes: {block_hashes}"
        )
        self.kv_event_publisher.publish_removed(event_id, block_hashes)

    async def _publish_kv_cache_events_task(self):
        """
        Publish kv cache events to the events publisher.

        Drains ``engine.llm.get_kv_cache_events_async(timeout=5)`` and forwards
        "stored" and "removed" events; other event types are ignored.

        Returns:
            True once the event iterator is exhausted, False if the engine or
            the event publisher is missing or an oversized block is seen.
        """
        if self.engine is None:
            logging.error("LLM engine not initialized!")
            return False

        if self.kv_event_publisher is None:
            logging.error("KV event publisher not initialized!")
            return False

        events = self.engine.llm.get_kv_cache_events_async(timeout=5)
        async for event in events:
            logging.debug(f"KV cache event received: {event}")
            event_id = event["event_id"]
            data = event["data"]
            event_type = data["type"]
            if event_type == "stored":
                if not self._publish_stored_event(event_id, data):
                    return False
            elif event_type == "removed":
                self._publish_removed_event(event_id, data)

        return True

    def _start_thread(self, thread, started_message):
        """Bind ``thread`` to the running loop and start it if it is not alive."""
        if thread and not thread.is_alive():
            # REVISIT
            # [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
            self._stats_loop = asyncio.get_running_loop()
            thread.set_loop(self._stats_loop)
            thread.start()
            logging.debug(started_message)

    def start_publish_threads(self):
        """Start both publishing threads on the currently running event loop.

        Must be called from a coroutine or callback, since it uses
        ``asyncio.get_running_loop()``.
        """
        self._start_thread(
            self.publish_kv_cache_events_thread, "Started kv cache events thread"
        )
        self._start_thread(self.publish_stats_thread, "Started stats thread")

    def check_error_queue(self):
        """Return the oldest queued publisher error, or None if there is none."""
        if not self.error_queue.empty():
            logging.error("Error in publishers error queue")
            return self.error_queue.get()
        return None

    @staticmethod
    def _stop_thread(thread, timeout, timeout_message):
        """Stop ``thread`` if it is alive and wait up to ``timeout`` seconds."""
        if thread and thread.is_alive():
            thread.stop()
            thread.join(timeout=timeout)
            if thread.is_alive():
                logging.warning(timeout_message)

    async def cleanup(self):
        """Cleanup threads and resources"""
        self._stop_event.set()
        # Add timeout to prevent hanging
        cleanup_timeout = 5.0  # seconds

        self._stop_thread(
            self.publish_stats_thread,
            cleanup_timeout,
            "Stats thread did not stop within timeout",
        )
        self._stop_thread(
            self.publish_kv_cache_events_thread,
            cleanup_timeout,
            "KV cache events thread did not stop within timeout",
        )
launch/dynamo-run/src/subprocess/trtllm_inc.py
View file @
9acaa8d1
...
...
@@ -12,17 +12,18 @@ import argparse
import
asyncio
import
logging
import
sys
from
contextlib
import
asynccontextmanager
from
typing
import
AsyncGenerator
,
Optional
from
typing
import
Optional
import
uvloop
# Import TRTLLM and related modules
from
tensorrt_llm
import
LLM
,
LlmArgs
,
SamplingParams
from
tensorrt_llm
import
SamplingParams
from
tensorrt_llm.llmapi.llm_utils
import
update_llm_args_with_extra_options
from
tensorrt_llm.llmapi.tokenizer
import
tokenizer_factory
from
trtllm.engine
import
get_llm_engine
from
trtllm.publishers
import
Publishers
from
dynamo.llm
import
KvMetricsPublisher
,
ModelType
,
register_llm
from
dynamo.llm
import
ModelType
,
register_llm
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
# Only used if you run it manually from the command line
...
...
@@ -30,6 +31,8 @@ DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
# Qwen/Qwen3-0.6B is not supported by TRTLLM yet.
DEFAULT_MODEL
=
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default buffer size for kv cache events.
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
=
1024
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
@@ -45,6 +48,7 @@ class Config:
tensor_parallel_size
:
int
kv_block_size
:
int
extra_engine_args
:
str
publish_events_and_metrics
:
bool
class
RequestHandler
:
...
...
@@ -52,34 +56,19 @@ class RequestHandler:
Request handler for the generate endpoint
"""
def
__init__
(
self
,
component
,
engine
,
default_sampling_params
):
def
__init__
(
self
,
component
,
engine
,
default_sampling_params
,
publishers
):
self
.
engine
=
engine
self
.
component
=
component
self
.
default_sampling_params
=
default_sampling_params
self
.
metrics_publisher
=
KvMetricsPublisher
()
def
setup_kv_metrics
(
self
):
# Initially send dummy metrics to kick start,
# TRTLLM will not update stat until forward pass is triggered
self
.
metrics_publisher
.
publish
(
0
,
# request_active_slots
1024
,
# request_total_slots
0
,
# kv_active_blocks
1024
,
# kv_total_blocks
0
,
# num_requests_waiting
0.0
,
# gpu_cache_usage_perc
0.0
,
# gpu_prefix_cache_hit_rate
)
task
=
asyncio
.
create_task
(
self
.
create_metrics_publisher_endpoint
())
task
.
add_done_callback
(
lambda
_
:
logging
.
debug
(
"metrics publisher endpoint created"
)
)
async
def
create_metrics_publisher_endpoint
(
self
):
logging
.
debug
(
"Creating metrics publisher endpoint"
)
await
self
.
metrics_publisher
.
create_endpoint
(
self
.
component
)
self
.
publishers
=
publishers
self
.
first_generation
=
True
async
def
generate
(
self
,
request
):
# Check if there is an error in the publishers error queue
publishers_error
=
self
.
publishers
.
check_error_queue
()
if
publishers_error
:
raise
publishers_error
inputs
=
request
[
"token_ids"
]
sampling_params
=
self
.
default_sampling_params
...
...
@@ -98,6 +87,12 @@ class RequestHandler:
async
for
res
in
self
.
engine
.
llm
.
generate_async
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
streaming
=
True
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if
self
.
first_generation
and
self
.
publishers
:
self
.
publishers
.
start_publish_threads
()
self
.
first_generation
=
False
if
res
.
finished
:
yield
{
"finish_reason"
:
"stop"
,
"token_ids"
:
[]}
break
...
...
@@ -122,50 +117,6 @@ async def worker(runtime: DistributedRuntime):
await
init
(
runtime
,
cmd_line_args
())
class AsyncLLMEngine:
    """Lazily constructed holder for a TensorRT-LLM ``LLM`` instance.

    ``engine_args`` is a mapping that must contain a ``"model"`` entry; every
    other entry is forwarded to ``LLM`` as a keyword argument. The engine is
    created by ``initialize()`` and released by ``cleanup()``.
    """

    def __init__(self, engine_args):
        self.engine_args = engine_args
        self._llm: Optional[LLM] = None
        self._initialized = False

    async def initialize(self):
        """Create the underlying ``LLM`` unless already initialized.

        Raises:
            KeyError: if ``engine_args`` has no ``"model"`` entry.
        """
        if self._initialized:
            return
        # Copy first: popping from the caller's mapping would mutate their
        # arguments and make a second initialize() fail with KeyError.
        llm_kwargs = dict(self.engine_args)
        model = llm_kwargs.pop("model")
        self._llm = LLM(model=model, **llm_kwargs)
        self._initialized = True

    async def cleanup(self):
        """Shut the ``LLM`` down; shutdown errors are logged, not raised."""
        if not self._initialized:
            return
        try:
            self._llm.shutdown()
        except Exception as e:
            logging.error(f"Error during cleanup: {e}")
        finally:
            self._initialized = False
            # Drop the reference so the shut-down engine can be collected.
            self._llm = None

    @property
    def llm(self):
        """Return the live ``LLM`` instance.

        Raises:
            RuntimeError: if the engine is not currently initialized.
        """
        if not self._initialized:
            raise RuntimeError("Engine not initialized")
        return self._llm
@asynccontextmanager
async def get_llm_engine(engine_args: LlmArgs) -> AsyncGenerator[AsyncLLMEngine, None]:
    """Async context manager that yields an initialized ``AsyncLLMEngine``.

    Exceptions raised during initialization or inside the ``async with`` body
    are logged and re-raised. ``cleanup()`` is awaited on every exit path.
    """
    holder = AsyncLLMEngine(engine_args)
    try:
        try:
            await holder.initialize()
            yield holder
        except Exception as err:
            logging.error(f"Error in engine context: {err}")
            raise
    finally:
        await holder.cleanup()
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
):
"""
Instantiate and serve
...
...
@@ -187,8 +138,28 @@ async def init(runtime: DistributedRuntime, config: Config):
}
if
config
.
extra_engine_args
!=
""
:
arg_map
=
update_llm_args_with_extra_options
(
arg_map
,
config
.
extra_engine_args
)
if
config
.
publish_events_and_metrics
:
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config
=
None
if
"kv_cache_config"
not
in
arg_map
:
kv_cache_config
=
{}
kv_cache_config
[
"event_buffer_max_size"
]
=
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
else
:
kv_cache_config
=
arg_map
[
"kv_cache_config"
]
if
not
kv_cache_config
.
event_buffer_max_size
:
kv_cache_config
.
event_buffer_max_size
=
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
arg_map
[
"kv_cache_config"
]
=
kv_cache_config
# Only pytorch backend is supported for now to publish events and metrics.
if
"backend"
not
in
arg_map
:
arg_map
[
"backend"
]
=
"pytorch"
elif
arg_map
[
"backend"
]
!=
"pytorch"
:
logging
.
error
(
"Only pytorch backend is supported for now to publish events and metrics."
)
sys
.
exit
(
1
)
logging
.
debug
(
f
"TRTLLM engine args:
{
arg_map
}
"
)
logging
.
info
(
f
"TRTLLM engine args:
{
arg_map
}
"
)
engine_args
=
arg_map
# Populate default sampling params from the model
...
...
@@ -202,12 +173,29 @@ async def init(runtime: DistributedRuntime, config: Config):
await
register_llm
(
ModelType
.
Backend
,
endpoint
,
config
.
model_path
,
config
.
model_name
)
handler
=
RequestHandler
(
component
,
engine
,
default_sampling_params
)
handler
.
setup_kv_metrics
()
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await
endpoint
.
serve_endpoint
(
handler
.
generate
)
publishers
=
None
if
config
.
publish_events_and_metrics
:
kv_listener
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
publishers
=
Publishers
(
component
,
engine
,
kv_listener
,
int
(
endpoint
.
lease_id
()),
config
.
kv_block_size
,
)
handler
=
RequestHandler
(
component
,
engine
,
default_sampling_params
,
publishers
)
try
:
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await
endpoint
.
serve_endpoint
(
handler
.
generate
)
finally
:
if
publishers
:
await
publishers
.
cleanup
()
def
cmd_line_args
():
...
...
@@ -235,6 +223,8 @@ def cmd_line_args():
parser
.
add_argument
(
"--tensor-parallel-size"
,
type
=
int
,
default
=
1
,
help
=
"Number of GPUs to use."
)
# IMPORTANT: We should ideally not expose this to users. We should be able to
# query the block size from the TRTLLM engine.
parser
.
add_argument
(
"--kv-block-size"
,
type
=
int
,
default
=
32
,
help
=
"Size of a KV cache block."
)
...
...
@@ -244,6 +234,11 @@ def cmd_line_args():
default
=
""
,
help
=
"Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine."
,
)
parser
.
add_argument
(
"--publish-events-and-metrics"
,
action
=
"store_true"
,
help
=
"Publish events and metrics to the dynamo components."
,
)
args
=
parser
.
parse_args
()
config
=
Config
()
...
...
@@ -270,10 +265,14 @@ def cmd_line_args():
config
.
tensor_parallel_size
=
args
.
tensor_parallel_size
config
.
kv_block_size
=
args
.
kv_block_size
config
.
extra_engine_args
=
args
.
extra_engine_args
config
.
publish_events_and_metrics
=
args
.
publish_events_and_metrics
return
config
if
__name__
==
"__main__"
:
uvloop
.
install
()
asyncio
.
run
(
worker
())
try
:
asyncio
.
run
(
worker
())
except
KeyboardInterrupt
:
logging
.
info
(
"Received SIGINT, shutting down..."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment