Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
03d976c7
Unverified
Commit
03d976c7
authored
Jun 26, 2025
by
Tanmay Verma
Committed by
GitHub
Jun 26, 2025
Browse files
refactor: Refactor the TRTLLM example components and improve UI (#1654)
Signed-off-by:
Tanmay Verma
<
tanmayv@nvidia.com
>
parent
8a2d6529
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
596 additions
and
1085 deletions
+596
-1085
examples/tensorrt_llm/README.md
examples/tensorrt_llm/README.md
+2
-2
examples/tensorrt_llm/common/base_engine.py
examples/tensorrt_llm/common/base_engine.py
+307
-521
examples/tensorrt_llm/common/parser.py
examples/tensorrt_llm/common/parser.py
+15
-150
examples/tensorrt_llm/common/utils.py
examples/tensorrt_llm/common/utils.py
+0
-108
examples/tensorrt_llm/components/prefill_worker.py
examples/tensorrt_llm/components/prefill_worker.py
+27
-26
examples/tensorrt_llm/components/worker.py
examples/tensorrt_llm/components/worker.py
+49
-52
examples/tensorrt_llm/configs/agg.yaml
examples/tensorrt_llm/configs/agg.yaml
+6
-1
examples/tensorrt_llm/configs/agg_router.yaml
examples/tensorrt_llm/configs/agg_router.yaml
+8
-2
examples/tensorrt_llm/configs/deepseek_r1/agg.yaml
examples/tensorrt_llm/configs/deepseek_r1/agg.yaml
+6
-1
examples/tensorrt_llm/configs/deepseek_r1/disagg.yaml
examples/tensorrt_llm/configs/deepseek_r1/disagg.yaml
+13
-10
examples/tensorrt_llm/configs/deepseek_r1/disagg_llm_api_config.yaml
...nsorrt_llm/configs/deepseek_r1/disagg_llm_api_config.yaml
+0
-88
examples/tensorrt_llm/configs/deepseek_r1/engine_configs/agg_config.yaml
...rt_llm/configs/deepseek_r1/engine_configs/agg_config.yaml
+0
-6
examples/tensorrt_llm/configs/deepseek_r1/engine_configs/decode_config.yaml
...llm/configs/deepseek_r1/engine_configs/decode_config.yaml
+31
-19
examples/tensorrt_llm/configs/deepseek_r1/engine_configs/prefill_config.yaml
...lm/configs/deepseek_r1/engine_configs/prefill_config.yaml
+17
-19
examples/tensorrt_llm/configs/deepseek_r1/mtp/engine_configs/agg_config.yaml
...lm/configs/deepseek_r1/mtp/engine_configs/agg_config.yaml
+0
-1
examples/tensorrt_llm/configs/deepseek_r1/mtp/engine_configs/decode_config.yaml
...configs/deepseek_r1/mtp/engine_configs/decode_config.yaml
+53
-0
examples/tensorrt_llm/configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml
...onfigs/deepseek_r1/mtp/engine_configs/prefill_config.yaml
+37
-0
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_agg.yaml
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_agg.yaml
+8
-1
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_disagg.yaml
...ples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_disagg.yaml
+17
-6
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml
...lm/configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml
+0
-72
No files found.
examples/tensorrt_llm/README.md
View file @
03d976c7
...
@@ -110,7 +110,7 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
...
@@ -110,7 +110,7 @@ dynamo serve graphs.agg:Frontend -f ./configs/agg.yaml
#### Aggregated serving with KV Routing
#### Aggregated serving with KV Routing
```
bash
```
bash
cd
/workspace/examples/tensorrt_llm
cd
/workspace/examples/tensorrt_llm
dynamo serve graphs.agg
_router
:Frontend
-f
./configs/agg_router.yaml
dynamo serve graphs.agg:Frontend
-f
./configs/agg_router.yaml
```
```
#### Disaggregated serving
#### Disaggregated serving
...
@@ -122,7 +122,7 @@ dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
...
@@ -122,7 +122,7 @@ dynamo serve graphs.disagg:Frontend -f ./configs/disagg.yaml
#### Disaggregated serving with KV Routing
#### Disaggregated serving with KV Routing
```
bash
```
bash
cd
/workspace/examples/tensorrt_llm
cd
/workspace/examples/tensorrt_llm
dynamo serve graphs.disagg
_router
:Frontend
-f
./configs/disagg_router.yaml
dynamo serve graphs.disagg:Frontend
-f
./configs/disagg_router.yaml
```
```
#### Aggregated serving with Multi-Token Prediction (MTP) and DeepSeek R1
#### Aggregated serving with Multi-Token Prediction (MTP) and DeepSeek R1
...
...
examples/tensorrt_llm/common/base_engine.py
View file @
03d976c7
...
@@ -12,588 +12,374 @@
...
@@ -12,588 +12,374 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
asyncio
import
copy
import
logging
import
logging
import
os
from
dataclasses
import
dataclass
import
signal
import
threading
from
contextlib
import
asynccontextmanager
from
enum
import
Enum
from
queue
import
Queue
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
from
common.parser
import
LLMAPIConfig
from
common.protocol
import
DisaggregatedTypeConverter
,
TRTLLMWorkerRequest
from
common.protocol
import
DisaggregatedTypeConverter
from
tensorrt_llm
import
SamplingParams
from
common.utils
import
ManagedThread
,
ServerType
from
tensorrt_llm.llmapi.llm_utils
import
update_llm_args_with_extra_options
from
tensorrt_llm.executor
import
CppExecutorError
from
tensorrt_llm.llmapi
import
LLM
,
SamplingParams
from
tensorrt_llm.llmapi.disagg_utils
import
(
CtxGenServerConfig
,
parse_disagg_config_file
,
)
from
tensorrt_llm.llmapi.tokenizer
import
tokenizer_factory
from
tensorrt_llm.llmapi.tokenizer
import
tokenizer_factory
from
tensorrt_llm.serve.openai_protocol
import
DisaggregatedParams
from
tensorrt_llm.serve.openai_protocol
import
(
DisaggregatedParams
as
OAIDisaggregatedParams
,
)
from
dynamo.llm
import
KvEventPublisher
,
WorkerMetricsP
ublisher
from
dynamo.llm
import
get_tensorrtllm_engine
,
get_tensorrtllm_p
ublisher
from
dynamo.
sdk
import
dynamo_context
from
dynamo.
runtime
import
DistributedRuntime
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
DEBUG
)
logger
.
setLevel
(
logging
.
DEBUG
)
# Default buffer size for kv cache events.
class
DisaggRequestType
(
Enum
):
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
=
1024
CONTEXT_ONLY
=
"context_only"
GENERATION_ONLY
=
"generation_only"
def
update_args_from_disagg_config
(
engine_config
:
LLMAPIConfig
,
server_config
:
CtxGenServerConfig
):
# Update the LLM API config with the disaggregated config
# Allows for different configs for context and generation servers
engine_config
.
extra_args
.
update
(
**
server_config
.
other_args
)
engine_config
.
update_sub_configs
(
server_config
.
other_args
)
return
engine_config
def
_to_signed_i64
(
value
:
int
|
None
)
->
int
|
None
:
def
parse_endpoint
(
endpoint
:
str
)
->
tuple
[
str
,
str
,
str
]:
"""Convert a Python int to signed 64-bit range by two's complement."""
endpoint_str
=
endpoint
.
replace
(
"dyn://"
,
""
,
1
)
if
value
is
None
:
endpoint_parts
=
endpoint_str
.
split
(
"."
)
return
None
if
len
(
endpoint_parts
)
!=
3
:
raise
ValueError
(
if
value
>=
2
**
63
:
f
"Invalid endpoint format: '
{
endpoint
}
'. "
return
value
-
2
**
64
"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
if
value
<
-
(
2
**
63
):
)
return
((
value
+
2
**
63
)
%
2
**
64
)
-
2
**
63
return
value
def
get_sampling_params
(
sampling_params_dict
,
default_sampling_params
):
return
(
endpoint_parts
[
0
],
endpoint_parts
[
1
],
endpoint_parts
[
2
])
sampling_params
=
copy
.
deepcopy
(
default_sampling_params
)
for
key
,
value
in
sampling_params_dict
.
items
():
if
value
is
None
:
@
dataclass
continue
class
BaseEngineConfig
:
if
hasattr
(
sampling_params
,
key
):
"""Base engine configuration"""
setattr
(
sampling_params
,
key
,
value
)
return
sampling_params
namespace
:
str
component
:
str
endpoint
:
str
model_path
:
str
served_model_name
:
Optional
[
str
]
=
None
kv_block_size
:
int
=
32
extra_engine_args
:
str
=
""
publish_events_and_metrics
:
bool
=
False
disaggregation_mode
:
str
=
"prefill_and_decode"
remote_prefill_endpoint
:
Optional
[
str
]
=
None
lease_id
:
int
=
0
def
__str__
(
self
)
->
str
:
return
(
f
"Config(namespace=
{
self
.
namespace
}
, "
f
"component=
{
self
.
component
}
, "
f
"endpoint=
{
self
.
endpoint
}
, "
f
"model_path=
{
self
.
model_path
}
, "
f
"served_model_name=
{
self
.
served_model_name
}
, "
f
"kv_block_size=
{
self
.
kv_block_size
}
, "
f
"extra_engine_args=
{
self
.
extra_engine_args
}
, "
f
"publish_events_and_metrics=
{
self
.
publish_events_and_metrics
}
, "
f
"disaggregation_mode=
{
self
.
disaggregation_mode
}
, "
f
"remote_prefill_endpoint=
{
self
.
remote_prefill_endpoint
}
, "
f
"lease_id=
{
self
.
lease_id
}
)"
)
class
BaseTensorrtLLMEngine
:
class
BaseTensorrtLLMEngine
:
def
__init__
(
def
__init__
(
self
,
self
,
namespace_str
:
str
=
"dynamo"
,
config
:
BaseEngineConfig
,
component_str
:
str
=
"tensorrt-llm"
,
worker_id
:
Optional
[
str
]
=
None
,
engine_config
:
LLMAPIConfig
=
None
,
remote_prefill
:
bool
=
False
,
min_workers
:
int
=
0
,
disagg_config_file
:
Optional
[
str
]
=
None
,
block_size
:
int
=
32
,
router
:
str
=
"round_robin"
,
server_type
:
ServerType
=
ServerType
.
GEN
,
):
):
self
.
_namespace_str
=
namespace_str
self
.
_config
=
config
self
.
_component_str
=
component_str
self
.
_worker_id
=
worker_id
self
.
_remote_prefill
=
remote_prefill
self
.
_min_workers
=
0
self
.
_kv_block_size
=
block_size
self
.
_router
=
router
self
.
_server_type
=
server_type
self
.
_prefill_client
=
None
self
.
_prefill_client
=
None
self
.
_error_queue
:
Queue
=
Queue
()
self
.
_llm_engine
=
None
self
.
_kv_metrics_publisher
=
None
self
.
_llm_engine_context
=
None
self
.
_llm_publisher
=
None
if
self
.
_remote_prefill
or
self
.
_server_type
==
ServerType
.
CTX
:
self
.
_llm_publisher_context
=
None
self
.
_min_workers
=
min_workers
self
.
_runtime
=
None
if
disagg_config_file
is
None
or
not
os
.
path
.
exists
(
disagg_config_file
):
self
.
_first_generation
=
True
raise
ValueError
(
# Initialize default sampling params
"llmapi_disaggregated_config file does not exist or not provided"
self
.
default_sampling_params
=
SamplingParams
()
)
disagg_config
=
parse_disagg_config_file
(
disagg_config_file
)
async
def
initialize
(
self
,
runtime
:
DistributedRuntime
):
server_config
:
CtxGenServerConfig
=
None
"""Initialize the engine and prefill client if needed"""
self
.
_runtime
=
runtime
for
config
in
disagg_config
.
server_configs
:
# Select the first context server config
# Convert model path to Path object if it's a local path, otherwise keep as string
if
config
.
type
==
server_type
.
value
:
model_path
=
str
(
self
.
_config
.
model_path
)
server_config
=
config
break
# Initialize the LLM engine
engine_args
:
dict
[
str
,
Any
]
=
{
if
server_config
is
None
:
"model"
:
model_path
,
server_type_str
=
(
"tensor_parallel_size"
:
1
,
"generation"
if
server_type
==
ServerType
.
GEN
else
"context"
"backend"
:
"pytorch"
,
)
"skip_tokenizer_init"
:
True
,
raise
ValueError
(
}
f
"No
{
server_type_str
}
server config found. Please check the disaggregated config file."
)
if
self
.
_config
.
extra_engine_args
:
# TODO: Support extra engine args from json file as well.
engine_config
=
update_args_from_disagg_config
(
engine_config
,
server_config
)
engine_args
=
update_llm_args_with_extra_options
(
engine_args
,
self
.
_config
.
extra_engine_args
if
router
==
"kv"
:
self
.
_publish_stats
=
True
self
.
_publish_events
=
True
else
:
self
.
_publish_stats
=
False
self
.
_publish_events
=
False
if
self
.
_publish_stats
:
self
.
_kv_metrics_publisher
=
WorkerMetricsPublisher
()
if
self
.
_publish_events
:
if
self
.
_worker_id
is
None
:
raise
ValueError
(
"Worker ID is None!"
)
runtime
=
dynamo_context
[
"runtime"
]
kv_listener
=
runtime
.
namespace
(
self
.
_namespace_str
).
component
(
self
.
_component_str
)
)
self
.
_kv_event_publisher
=
KvEventPublisher
(
# Update the model path in the config to the model path used by the engine.
kv_listener
,
int
(
self
.
_worker_id
),
self
.
_kv_block_size
self
.
_config
.
model_path
=
str
(
engine_args
[
"model"
])
if
not
self
.
_config
.
model_path
:
raise
ValueError
(
"Model specification is required. Present neither in the config nor in the extra engine args."
)
)
logger
.
info
(
"KvEventPublisher is initialized"
)
self
.
_engine_config
=
engine_config
def
_init_engine
(
self
):
logger
.
info
(
"Initializing engine"
)
# Run the engine in a separate thread running the AsyncIO event loop.
self
.
_llm_engine
:
Optional
[
Any
]
=
None
self
.
_llm_engine_start_cv
=
threading
.
Condition
()
self
.
_llm_engine_shutdown_event
=
asyncio
.
Event
()
self
.
_event_thread
=
threading
.
Thread
(
target
=
asyncio
.
run
,
args
=
(
self
.
_run_llm_engine
(),)
)
# Populate default sampling params from the model
# Populate default sampling params from the model
tokenizer
=
tokenizer_factory
(
self
.
_engine_config
.
model_name
)
tokenizer
=
tokenizer_factory
(
self
.
_config
.
model_path
)
self
.
_default_sampling_params
=
SamplingParams
()
self
.
default_sampling_params
=
SamplingParams
()
self
.
_default_sampling_params
.
_setup
(
tokenizer
)
self
.
default_sampling_params
.
_setup
(
tokenizer
)
self
.
_default_sampling_params
.
stop
=
None
self
.
default_sampling_params
.
stop
=
None
self
.
publish_kv_cache_events_thread
=
None
if
self
.
_config
.
publish_events_and_metrics
:
self
.
publish_stats_thread
=
None
# 'event_buffer_max_size' is required to enable TRTLLM to publish kv cache events.
kv_cache_config
:
dict
[
str
,
Any
]
|
Any
=
None
self
.
_event_thread
.
start
()
if
"kv_cache_config"
not
in
engine_args
:
with
self
.
_llm_engine_start_cv
:
kv_cache_config
=
{}
while
self
.
_llm_engine
is
None
:
kv_cache_config
[
self
.
_llm_engine_start_cv
.
wait
()
"event_buffer_max_size"
]
=
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
# The 'threading.Thread()' will not raise the exception here should the engine
else
:
# failed to start, so the exception is passed back via the engine variable.
kv_cache_config
=
engine_args
[
"kv_cache_config"
]
if
isinstance
(
self
.
_llm_engine
,
Exception
):
if
(
e
=
self
.
_llm_engine
hasattr
(
kv_cache_config
,
"event_buffer_max_size"
)
logger
.
error
(
f
"Failed to start engine:
{
e
}
"
)
and
not
kv_cache_config
.
event_buffer_max_size
if
self
.
_event_thread
is
not
None
:
):
self
.
_event_thread
.
join
()
kv_cache_config
.
event_buffer_max_size
=
(
self
.
_event_thread
=
None
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
raise
e
)
elif
(
try
:
isinstance
(
kv_cache_config
,
dict
)
if
self
.
_publish_stats
:
and
"event_buffer_max_size"
not
in
kv_cache_config
self
.
_init_publish_metrics_thread
()
):
except
Exception
as
e
:
kv_cache_config
[
logger
.
error
(
f
"Failed to initialize publish metrics threads:
{
e
}
"
)
"event_buffer_max_size"
raise
e
]
=
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE
engine_args
[
"kv_cache_config"
]
=
kv_cache_config
try
:
if
self
.
_publish_events
:
# Enable iter perf stats by default if we are publishing events and metrics.
self
.
_init_publish_kv_cache_events_thread
()
if
not
engine_args
.
get
(
"enable_iter_perf_stats"
):
except
Exception
as
e
:
engine_args
[
"enable_iter_perf_stats"
]
=
True
logger
.
error
(
f
"Failed to initialize publish events threads:
{
e
}
"
)
raise
e
# Only pytorch backend is supported for now to publish events and metrics.
if
engine_args
.
get
(
"backend"
)
!=
"pytorch"
:
def
_init_publish_metrics_thread
(
self
):
logging
.
error
(
# Need to publish stats once so that worker can be selected.
"Only pytorch backend is supported for now to publish events and metrics."
# Publishing some dummy values...
)
request_active_slots
=
0
raise
RuntimeError
(
request_total_slots
=
4
"Only pytorch backend is supported for now to publish events and metrics. Hence, KV router is not supported."
kv_active_block
=
0
)
kv_total_blocks
=
4
num_requests_waiting
=
0
gpu_cache_usage_perc
=
0.0
gpu_prefix_cache_hit_rate
=
0.0
num_requests_waiting
=
0
gpu_cache_usage_perc
=
0.0
gpu_prefix_cache_hit_rate
=
0.0
if
self
.
_kv_metrics_publisher
is
None
:
logger
.
error
(
"KV metrics publisher not initialized!"
)
return
self
.
_kv_metrics_publisher
.
publish
(
request_active_slots
,
request_total_slots
,
kv_active_block
,
kv_total_blocks
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
,
)
# Prepare threads for publishing stats but don't start them yet.
logging
.
info
(
f
"TRTLLM engine args:
{
engine_args
}
"
)
# TRTLLM needs to start generating tokens first before stats
# can be retrieved.
self
.
publish_stats_thread
=
ManagedThread
(
self
.
publish_stats_task
,
error_queue
=
self
.
_error_queue
,
name
=
"publish_stats_thread"
,
)
def
_init_publish_kv_cache_events_thread
(
self
):
# Get the engine using the asynccontextmanager
if
self
.
_kv_event_publisher
is
None
:
self
.
_llm_engine_context
=
get_tensorrtllm_engine
(
engine_args
)
logger
.
error
(
"KV event publisher not initialized!"
)
if
self
.
_llm_engine_context
is
not
None
:
return
self
.
_llm_engine
=
await
self
.
_llm_engine_context
.
__aenter__
()
else
:
# A set to store the block hash of partial block (i.e. block containing less than kv_block_size tokens) hashes.
raise
RuntimeError
(
"Failed to create LLM engine context"
)
# It is used to prevent sending remove event to kv router since partial blocks are not stored.
self
.
_partial_block_hashes
=
set
()
# Prepare threads for publishing kv cache events but don't start them yet.
# TRTLLM needs to start generating tokens first before kv cache events
# can be retrieved.
self
.
publish_kv_cache_events_thread
=
ManagedThread
(
self
.
publish_kv_cache_events_task
,
error_queue
=
self
.
_error_queue
,
name
=
"publish_kv_cache_events_thread"
,
)
async
def
publish_stats_task
(
self
):
if
(
"""
self
.
_config
.
publish_events_and_metrics
Publish stats to the metrics publisher.
and
self
.
_config
.
disaggregation_mode
!=
"prefill"
"""
):
if
self
.
_llm_engine
is
None
:
kv_listener
=
runtime
.
namespace
(
self
.
_config
.
namespace
).
component
(
logger
.
error
(
"LLM engine not initialized!"
)
self
.
_config
.
component
return
if
self
.
_kv_metrics_publisher
is
None
:
logger
.
error
(
"KV metrics publisher not initialized!"
)
return
False
stats
=
self
.
_llm_engine
.
get_stats_async
(
timeout
=
5
)
async
for
stat
in
stats
:
request_active_slots
=
stat
[
"numActiveRequests"
]
request_total_slots
=
stat
[
"maxNumActiveRequests"
]
kv_active_block
=
stat
[
"kvCacheStats"
][
"usedNumBlocks"
]
kv_total_blocks
=
stat
[
"kvCacheStats"
][
"maxNumBlocks"
]
reused_blocks
=
stat
[
"kvCacheStats"
][
"reusedBlocks"
]
freeNumBlocks
=
stat
[
"kvCacheStats"
][
"freeNumBlocks"
]
allocTotalBlocks
=
stat
[
"kvCacheStats"
][
"allocTotalBlocks"
]
allocNewBlocks
=
stat
[
"kvCacheStats"
][
"allocNewBlocks"
]
# NOTE: num paused requests is always 0 when using guarantee no evict scheduler (default).
num_requests_waiting
=
(
stat
[
"numQueuedRequests"
]
+
stat
[
"inflightBatchingStats"
][
"numPausedRequests"
]
)
)
gpu_cache_usage_perc
=
allocTotalBlocks
/
kv_total_blocks
self
.
_llm_publisher_context
=
get_tensorrtllm_publisher
(
gpu_prefix_cache_hit_rate
=
stat
[
"kvCacheStats"
][
"cacheHitRate"
]
kv_listener
,
self
.
_llm_engine
,
logger
.
debug
(
kv_listener
,
f
"Publishing stats: request_active_slots:
{
request_active_slots
}
, request_total_slots:
{
request_total_slots
}
, kv_active_block:
{
kv_active_block
}
, kv_total_blocks:
{
kv_total_blocks
}
, num_requests_waiting:
{
num_requests_waiting
}
, reused_blocks:
{
reused_blocks
}
, freeNumBlocks:
{
freeNumBlocks
}
, allocTotalBlocks:
{
allocTotalBlocks
}
, allocNewBlocks:
{
allocNewBlocks
}
, gpu_cache_usage_perc:
{
gpu_cache_usage_perc
}
, gpu_prefix_cache_hit_rate:
{
gpu_prefix_cache_hit_rate
}
"
self
.
_config
.
lease_id
,
self
.
_config
.
kv_block_size
,
)
)
if
self
.
_llm_publisher_context
is
not
None
:
self
.
_kv_metrics_publisher
.
publish
(
self
.
_llm_publisher
=
await
self
.
_llm_publisher_context
.
__aenter__
()
request_active_slots
,
else
:
request_total_slots
,
raise
RuntimeError
(
"Failed to create LLM publisher context"
)
kv_active_block
,
kv_total_blocks
,
# Initialize prefill client if in decode mode
num_requests_waiting
,
if
self
.
_config
.
disaggregation_mode
==
"decode"
:
gpu_cache_usage_perc
,
if
self
.
_config
.
remote_prefill_endpoint
is
None
:
gpu_prefix_cache_hit_rate
,
raise
ValueError
(
"remote_prefill_endpoint is required for decode mode"
)
logging
.
info
(
f
"Initializing remote prefill client for endpoint:
{
self
.
_config
.
remote_prefill_endpoint
}
"
)
)
(
return
True
parsed_namespace
,
parsed_component_name
,
async
def
publish_kv_cache_events_task
(
self
):
parsed_endpoint_name
,
"""
)
=
parse_endpoint
(
self
.
_config
.
remote_prefill_endpoint
)
Publish kv cache events to the events publisher.
if
self
.
_runtime
is
not
None
:
"""
self
.
_prefill_client
=
(
if
self
.
_llm_engine
is
None
:
await
self
.
_runtime
.
namespace
(
parsed_namespace
)
logger
.
error
(
"LLM engine not initialized!"
)
.
component
(
parsed_component_name
)
return
.
endpoint
(
parsed_endpoint_name
)
.
client
()
events
=
self
.
_llm_engine
.
get_kv_cache_events_async
(
timeout
=
5
)
async
for
event
in
events
:
event_id
=
event
[
"event_id"
]
data
=
event
[
"data"
]
if
data
[
"type"
]
==
"stored"
:
parent_hash
=
_to_signed_i64
(
data
[
"parent_hash"
])
token_ids
=
[]
num_block_tokens
=
[]
block_hashes
=
[]
for
block
in
data
[
"blocks"
]:
token_num_in_block
=
len
(
block
[
"tokens"
])
block_hash
=
_to_signed_i64
(
block
[
"block_hash"
])
if
token_num_in_block
>
self
.
_kv_block_size
:
logger
.
error
(
f
"Block
{
block_hash
}
contains
{
token_num_in_block
}
tokens, which is greater than kv_block_size
{
self
.
_kv_block_size
}
"
)
return
if
token_num_in_block
<
self
.
_kv_block_size
:
logger
.
debug
(
f
"Early stop when block
{
block_hash
}
containing
{
token_num_in_block
}
tokens not equal to kv_block_size
{
self
.
_kv_block_size
}
"
)
self
.
_partial_block_hashes
.
add
(
block_hash
)
break
num_block_tokens
.
append
(
token_num_in_block
)
block_hashes
.
append
(
block_hash
)
for
token
in
block
[
"tokens"
]:
token_ids
.
append
(
int
(
token
[
"token_id"
]))
# Note: Currently data does not have lora_id.
# Using 0 as default value. If later data has
# lora_id, we need to verify if this is correct.
lora_id
=
data
.
get
(
"lora_id"
,
0
)
logger
.
debug
(
f
"publish stored event: event_id:
{
event_id
}
, token_ids:
{
token_ids
}
, num_block_tokens:
{
num_block_tokens
}
, block_hashes:
{
block_hashes
}
, lora_id:
{
lora_id
}
, parent_hash:
{
parent_hash
}
"
)
self
.
_kv_event_publisher
.
publish_stored
(
event_id
,
token_ids
,
num_block_tokens
,
block_hashes
,
lora_id
,
parent_hash
,
)
)
elif
data
[
"type"
]
==
"removed"
:
else
:
block_hashes
=
[]
raise
RuntimeError
(
"Runtime not initialized"
)
for
block_hash
in
data
[
"block_hashes"
]:
block_hash
=
_to_signed_i64
(
block_hash
)
if
block_hash
in
self
.
_partial_block_hashes
:
logger
.
debug
(
f
"Skipping removing block hash
{
block_hash
}
since it is a partial block"
)
self
.
_partial_block_hashes
.
remove
(
block_hash
)
continue
block_hashes
.
append
(
block_hash
)
logger
.
debug
(
f
"publish removed event: event_id:
{
event_id
}
, block_hashes:
{
block_hashes
}
"
)
self
.
_kv_event_publisher
.
publish_removed
(
event_id
,
block_hashes
)
return
True
def
_start_threads
(
self
):
async
def
cleanup
(
self
):
if
(
"""Cleanup resources"""
self
.
publish_kv_cache_events_thread
if
self
.
_llm_publisher_context
:
and
not
self
.
publish_kv_cache_events_thread
.
is_alive
()
):
# [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
self
.
_stats_loop
=
asyncio
.
get_running_loop
()
self
.
publish_kv_cache_events_thread
.
set_loop
(
self
.
_stats_loop
)
self
.
publish_kv_cache_events_thread
.
start
()
logger
.
debug
(
"Started kv cache events thread"
)
if
self
.
publish_stats_thread
and
not
self
.
publish_stats_thread
.
is_alive
():
self
.
_stats_loop
=
asyncio
.
get_running_loop
()
self
.
publish_stats_thread
.
set_loop
(
self
.
_stats_loop
)
self
.
publish_stats_thread
.
start
()
logger
.
debug
(
"Started stats thread"
)
async
def
_run_llm_engine
(
self
):
# Counter to keep track of ongoing request counts.
self
.
_ongoing_request_count
=
0
@
asynccontextmanager
async
def
async_llm_wrapper
():
# Create LLM in a thread to avoid blocking
loop
=
asyncio
.
get_running_loop
()
try
:
try
:
llm
=
await
loop
.
run_in_executor
(
await
self
.
_llm_publisher_context
.
__aexit__
(
None
,
None
,
None
)
None
,
except
Exception
as
e
:
lambda
:
LLM
(
logging
.
error
(
f
"Error during publisher cleanup:
{
e
}
"
)
model
=
self
.
_engine_config
.
model_name
,
**
self
.
_engine_config
.
to_dict
(),
),
)
yield
llm
finally
:
finally
:
if
"llm"
in
locals
():
self
.
_llm_publisher
=
None
# Run shutdown in a thread to avoid blocking
self
.
_llm_publisher_context
=
None
await
loop
.
run_in_executor
(
None
,
llm
.
shutdown
)
try
:
async
with
async_llm_wrapper
()
as
engine
:
# Capture the engine event loop and make it visible to other threads.
self
.
_event_loop
=
asyncio
.
get_running_loop
()
# Signal the engine is started and make it visible to other threads.
if
self
.
_llm_engine_context
:
with
self
.
_llm_engine_start_cv
:
try
:
self
.
_llm_engine
=
engine
await
self
.
_llm_engine_context
.
__aexit__
(
None
,
None
,
None
)
self
.
_llm_engine_start_cv
.
notify_all
()
except
Exception
as
e
:
logging
.
error
(
f
"Error during engine cleanup:
{
e
}
"
)
logger
.
info
(
"Engine loaded and ready to serve..."
)
finally
:
self
.
_llm_engine
=
None
# Wait for the engine shutdown signal.
self
.
_llm_engine_context
=
None
await
self
.
_llm_engine_shutdown_event
.
wait
()
# Stop the publishing threads
self
.
_prefill_client
=
None
if
self
.
publish_stats_thread
and
self
.
publish_stats_thread
.
is_alive
():
self
.
publish_stats_thread
.
stop
()
self
.
publish_stats_thread
.
join
()
if
(
self
.
publish_kv_cache_events_thread
and
self
.
publish_kv_cache_events_thread
.
is_alive
()
):
self
.
publish_kv_cache_events_thread
.
stop
()
self
.
publish_kv_cache_events_thread
.
join
()
# Wait for the ongoing requests to complete.
while
self
.
_ongoing_request_count
>
0
:
logger
.
info
(
"Awaiting remaining {} requests"
.
format
(
self
.
_ongoing_request_count
)
)
await
asyncio
.
sleep
(
1
)
# Cancel all tasks in the event loop.
async
def
remote_prefill
(
self
,
request
:
TRTLLMWorkerRequest
):
for
task
in
asyncio
.
all_tasks
(
loop
=
self
.
_event_loop
):
"""
if
task
is
not
asyncio
.
current_task
():
Send a prefill request to the remote prefill worker.
task
.
cancel
()
except
Exception
as
e
:
Args:
# Signal and pass the exception back via the engine variable if the engine
request: The original request to be sent for prefill
# failed to start. If the engine has started, re-raise the exception.
with
self
.
_llm_engine_start_cv
:
if
self
.
_llm_engine
is
None
:
self
.
_llm_engine
=
e
self
.
_llm_engine_start_cv
.
notify_all
()
return
raise
e
self
.
_llm_engine
=
None
Returns:
logger
.
info
(
"Shutdown complete"
)
The response from the remote prefill worker
async
def
_get_remote_prefill_response
(
self
,
request
):
Raises:
prefill_request
=
copy
.
deepcopy
(
request
)
ValueError: If prefill client is not initialized or multiple responses received
"""
prefill_request
=
request
.
model_copy
(
deep
=
True
)
# TRTLLM requires max_tokens to be set for prefill requests.
# TRTLLM requires max_tokens to be set for prefill requests.
prefill_request
.
stop_conditions
.
max_tokens
=
1
prefill_request
.
stop_conditions
.
max_tokens
=
1
prefill_request
.
disaggregated_params
=
DisaggregatedParams
(
prefill_request
.
disaggregated_params
=
OAI
DisaggregatedParams
(
request_type
=
DisaggRequestType
.
CONTEXT_ONLY
.
value
request_type
=
"context_only"
)
)
if
self
.
_prefill_client
is
None
:
if
self
.
_prefill_client
is
None
:
raise
ValueError
(
"Prefill client not initialized"
)
raise
ValueError
(
"Prefill client not initialized"
)
try
:
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
remote_prefill_responses
=
[
remote_prefill_response
async
for
remote_prefill_response
in
await
self
.
_prefill_client
.
round_robin
(
prefill_request
.
model_dump_json
()
)
]
except
Exception
as
e
:
raise
ValueError
(
f
"Error in remote prefill:
{
e
}
"
)
# TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
if
len
(
remote_prefill_responses
)
>
1
:
ctx_responses
=
[
ctx_response
async
for
ctx_response
in
await
self
.
_prefill_client
.
round_robin
(
prefill_request
.
model_dump_json
()
)
]
if
len
(
ctx_responses
)
>
1
:
raise
ValueError
(
raise
ValueError
(
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
"Prefill worker returned more than one response. This is currently not supported in remote prefill mode."
)
)
logger
.
debug
(
f
"Received response from prefill worker:
{
ctx_responses
[
0
].
data
()
}
"
if
len
(
remote_prefill_responses
)
==
0
:
)
raise
ValueError
(
"No response received from remote prefill worker"
)
remote_prefill_response
=
ctx_responses
[
0
]
remote_prefill_response
=
remote_prefill_responses
[
0
]
return
remote_prefill_response
return
remote_prefill_response
async
def
generate
(
self
,
request
):
async
def
generate
(
self
,
request
:
TRTLLMWorkerRequest
):
if
self
.
_llm_engine
is
None
:
if
self
.
_llm_engine
is
None
:
raise
RuntimeError
(
"Engine not initialized"
)
raise
RuntimeError
(
"Engine not initialized"
)
if
not
self
.
_error_queue
.
empty
():
if
self
.
_llm_publisher
:
raise
self
.
_error_queue
.
get
()
publishers_error
=
self
.
_llm_publisher
.
check_error_queue
()
if
publishers_error
:
raise
publishers_error
self
.
_ongoing_request_count
+=
1
inputs
=
request
.
token_ids
try
:
# Decode the disaggregated params from the request
worker_inputs
=
request
.
token_ids
disaggregated_params
=
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
request
.
disaggregated_params
)
num_output_tokens_so_far
=
0
if
self
.
_config
.
disaggregation_mode
==
"decode"
:
# Run prefill/context phase remotely if disaggregation mode is decode.
try
:
prefill_result
=
await
self
.
remote_prefill
(
request
)
except
Exception
as
e
:
raise
ValueError
(
f
"Error in remote prefill:
{
e
}
"
)
remote_prefill_response
=
prefill_result
.
data
()
if
(
remote_prefill_response
[
"finish_reason"
]
==
"stop"
or
remote_prefill_response
[
"finish_reason"
]
==
"error"
):
yield
remote_prefill_response
return
num_output_tokens_so_far
=
len
(
remote_prefill_response
[
"token_ids"
])
# Decode the disaggregated params from the remote prefill response
# Decode the disaggregated params from the remote prefill response
disaggregated_params
=
(
disaggregated_params
=
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
request
.
disaggregated_params
OAIDisaggregatedParams
(
)
**
remote_prefill_response
[
"disaggregated_params"
]
)
num_output_tokens_so_far
=
0
if
self
.
_remote_prefill
and
self
.
_server_type
==
ServerType
.
GEN
:
ctx_response
=
await
self
.
_get_remote_prefill_response
(
request
)
remote_prefill_response
=
ctx_response
.
data
()
if
(
remote_prefill_response
[
"finish_reason"
]
==
"stop"
or
remote_prefill_response
[
"finish_reason"
]
==
"error"
):
yield
remote_prefill_response
return
num_output_tokens_so_far
=
len
(
remote_prefill_response
[
"token_ids"
])
# Decode the disaggregated params from the remote prefill response
disaggregated_params
=
(
DisaggregatedTypeConverter
.
to_llm_disaggregated_params
(
DisaggregatedParams
(
**
remote_prefill_response
[
"disaggregated_params"
]
)
)
)
)
)
# Send the first token response to the client
first_token_response
=
remote_prefill_response
first_token_response
.
pop
(
"disaggregated_params"
)
yield
first_token_response
disaggregated_params
.
request_type
=
(
DisaggRequestType
.
GENERATION_ONLY
.
value
)
logger
.
debug
(
f
"Worker inputs:
{
worker_inputs
}
, disaggregated params:
{
disaggregated_params
}
"
)
sampling_params
=
get_sampling_params
(
request
.
sampling_options
.
dict
(),
self
.
_default_sampling_params
)
)
max_tokens
=
request
.
stop_conditions
.
max_tokens
if
max_tokens
:
sampling_params
.
max_tokens
=
max_tokens
async
for
response
in
self
.
_llm_engine
.
generate_async
(
inputs
=
worker_inputs
,
sampling_params
=
sampling_params
,
disaggregated_params
=
disaggregated_params
,
streaming
=
self
.
_server_type
!=
ServerType
.
CTX
,
):
if
response
.
finished
and
self
.
_server_type
!=
ServerType
.
CTX
:
yield
{
"finish_reason"
:
"stop"
,
"token_ids"
:
[]}
break
if
not
response
.
outputs
:
yield
{
"finish_reason"
:
"error"
,
"token_ids"
:
[]}
break
output
=
response
.
outputs
[
0
]
next_total_toks
=
len
(
output
.
token_ids
)
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
output
.
finish_reason
if
output
.
stop_reason
:
out
[
"stop_reason"
]
=
output
.
stop_reason
if
self
.
_server_type
==
ServerType
.
CTX
:
# Return the disaggregated params only when operating in prefill mode.
out
[
"disaggregated_params"
]
=
DisaggregatedTypeConverter
.
to_oai_disaggregated_params
(
output
.
disaggregated_params
).
dict
()
yield
out
num_output_tokens_so_far
=
next_total_toks
except
CppExecutorError
:
signal
.
raise_signal
(
signal
.
SIGINT
)
except
Exception
as
e
:
raise
RuntimeError
(
"Failed to generate: "
+
str
(
e
))
self
.
_start_threads
()
# Send the first token response to the client
self
.
_ongoing_request_count
-=
1
first_token_response
=
remote_prefill_response
first_token_response
.
pop
(
"disaggregated_params"
)
yield
first_token_response
# Set the disaggregated params to generation_only for the rest of the generation
disaggregated_params
.
request_type
=
"generation_only"
sampling_params
=
self
.
default_sampling_params
for
key
,
value
in
request
.
sampling_options
.
model_dump
().
items
():
if
not
value
:
continue
if
hasattr
(
sampling_params
,
key
):
setattr
(
sampling_params
,
key
,
value
)
max_tokens
=
request
.
stop_conditions
.
max_tokens
if
max_tokens
:
sampling_params
.
max_tokens
=
max_tokens
# TODO: Disable streaming for context only requests when adding disagg support
async
for
res
in
self
.
_llm_engine
.
llm
.
generate_async
(
inputs
=
inputs
,
sampling_params
=
sampling_params
,
disaggregated_params
=
disaggregated_params
,
streaming
=
(
self
.
_config
.
disaggregation_mode
!=
"prefill"
),
):
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
if
self
.
_first_generation
and
self
.
_llm_publisher
:
self
.
_llm_publisher
.
start
()
self
.
_first_generation
=
False
if
res
.
finished
and
self
.
_config
.
disaggregation_mode
!=
"prefill"
:
yield
{
"finish_reason"
:
"stop"
,
"token_ids"
:
[]}
break
if
not
res
.
outputs
:
yield
{
"finish_reason"
:
"error"
,
"token_ids"
:
[]}
break
output
=
res
.
outputs
[
0
]
next_total_toks
=
len
(
output
.
token_ids
)
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
output
.
finish_reason
if
output
.
stop_reason
:
out
[
"stop_reason"
]
=
output
.
stop_reason
if
self
.
_config
.
disaggregation_mode
==
"prefill"
:
# Return the disaggregated params only when operating in prefill mode.
out
[
"disaggregated_params"
]
=
DisaggregatedTypeConverter
.
to_oai_disaggregated_params
(
output
.
disaggregated_params
).
model_dump
()
yield
out
num_output_tokens_so_far
=
next_total_toks
examples/tensorrt_llm/common/parser.py
View file @
03d976c7
...
@@ -14,136 +14,28 @@
...
@@ -14,136 +14,28 @@
# limitations under the License.
# limitations under the License.
import
argparse
import
argparse
import
os
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Tuple
import
yaml
from
tensorrt_llm._torch.pyexecutor.config
import
PyTorchConfig
from
tensorrt_llm.llmapi
import
KvCacheConfig
from
tensorrt_llm.llmapi.llm_args
import
DecodingBaseConfig
@
dataclass
class
LLMAPIConfig
:
def
__init__
(
self
,
model_name
:
str
,
model_path
:
str
|
None
=
None
,
pytorch_backend_config
:
PyTorchConfig
|
None
=
None
,
kv_cache_config
:
KvCacheConfig
|
None
=
None
,
speculative_config
:
DecodingBaseConfig
|
None
=
None
,
**
kwargs
,
):
self
.
model_name
=
model_name
self
.
model_path
=
model_path
self
.
pytorch_backend_config
=
pytorch_backend_config
self
.
kv_cache_config
=
kv_cache_config
self
.
speculative_config
=
speculative_config
self
.
extra_args
=
kwargs
# Hardcoded to skip tokenizer init for now.
# We will handle the tokenization/detokenization
# in the base engine.
if
"skip_tokenizer_init"
in
self
.
extra_args
:
self
.
extra_args
.
pop
(
"skip_tokenizer_init"
)
self
.
skip_tokenizer_init
=
True
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
data
=
{
"kv_cache_config"
:
self
.
kv_cache_config
,
"speculative_config"
:
self
.
speculative_config
,
"skip_tokenizer_init"
:
self
.
skip_tokenizer_init
,
}
if
self
.
extra_args
:
data
.
update
(
self
.
extra_args
)
return
data
def
update_sub_configs
(
self
,
other_config
:
Dict
[
str
,
Any
]):
# TODO: Consider removing pytorch_backend_config parsing as this section
# was collapsed to top level config fields in recent TRTLLM versions.
if
"pytorch_backend_config"
in
other_config
:
self
.
pytorch_backend_config
=
PyTorchConfig
(
**
other_config
[
"pytorch_backend_config"
]
)
self
.
extra_args
.
pop
(
"pytorch_backend_config"
,
None
)
if
"kv_cache_config"
in
other_config
:
self
.
kv_cache_config
=
KvCacheConfig
(
**
other_config
[
"kv_cache_config"
])
self
.
extra_args
.
pop
(
"kv_cache_config"
,
None
)
if
"speculative_config"
in
other_config
:
self
.
speculative_config
=
DecodingBaseConfig
.
from_dict
(
other_config
[
"speculative_config"
]
)
self
.
extra_args
.
pop
(
"speculative_config"
,
None
)
def
_get_llm_args
(
engine_config
):
# Only do model validation checks and leave other checks to LLMAPI
if
"model_name"
not
in
engine_config
:
raise
ValueError
(
"Model name is required in the TRT-LLM engine config."
)
if
engine_config
.
get
(
"model_path"
,
""
):
if
os
.
path
.
exists
(
engine_config
.
get
(
"model_path"
,
""
)):
engine_config
[
"model_path"
]
=
Path
(
engine_config
[
"model_path"
])
else
:
raise
ValueError
(
f
"Model path
{
engine_config
[
'model_path'
]
}
does not exist"
)
model_name
=
engine_config
[
"model_name"
]
model_path
=
engine_config
.
get
(
"model_path"
,
None
)
engine_config
.
pop
(
"model_name"
)
engine_config
.
pop
(
"model_path"
,
None
)
# Store all other args as kwargs
llm_api_config
=
LLMAPIConfig
(
model_name
=
model_name
,
model_path
=
model_path
,
**
engine_config
,
)
# Parse supported sub configs and remove from kwargs
llm_api_config
.
update_sub_configs
(
engine_config
)
return
llm_api_config
def
_init_engine_args
(
engine_args_filepath
):
"""Initialize engine arguments from config file."""
if
not
os
.
path
.
isfile
(
engine_args_filepath
):
raise
ValueError
(
"'YAML file containing TRT-LLM engine args must be provided in when launching the worker."
)
try
:
with
open
(
engine_args_filepath
)
as
file
:
trtllm_engine_config
=
yaml
.
safe_load
(
file
)
except
yaml
.
YAMLError
as
e
:
raise
RuntimeError
(
f
"Failed to parse engine config:
{
e
}
"
)
return
_get_llm_args
(
trtllm_engine_config
)
def
parse_tensorrt_llm_args
(
def
parse_tensorrt_llm_args
(
config_args
,
config_args
,
)
->
Tuple
[
Any
,
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]]
:
)
->
argparse
.
Namespace
:
parser
=
argparse
.
ArgumentParser
(
description
=
"A TensorRT-LLM Worker parser"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"A TensorRT-LLM Worker parser"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--engine_args"
,
type
=
str
,
required
=
True
,
help
=
"Path to the engine args file"
"--extra-engine-args"
,
type
=
str
,
default
=
""
,
help
=
"Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--
served_model_name
"
,
"--
model-path
"
,
type
=
str
,
type
=
str
,
help
=
"Name of the model to serve"
,
default
=
None
,
default
=
None
,
help
=
"Path to disk model or HuggingFace model identifier to load."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--llmapi-disaggregated-config"
,
"--served_model_name"
,
"-c"
,
type
=
str
,
type
=
str
,
help
=
"Path to the llmapi disaggregated config file"
,
help
=
"Name to serve the model under."
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--router"
,
"--router"
,
...
@@ -152,46 +44,19 @@ def parse_tensorrt_llm_args(
...
@@ -152,46 +44,19 @@ def parse_tensorrt_llm_args(
default
=
"random"
,
default
=
"random"
,
help
=
"Router type to use for scheduling requests to workers"
,
help
=
"Router type to use for scheduling requests to workers"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--min-workers"
,
"--kv-block-size"
,
type
=
int
,
default
=
1
,
help
=
"Minimum number of workers for aggregated (monolith) server"
,
)
parser
.
add_argument
(
"--min-prefill-workers"
,
type
=
int
,
default
=
1
,
help
=
"Minimum number of prefill workers for disaggregated server"
,
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
type
=
int
,
default
=
32
,
default
=
32
,
help
=
"Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend."
,
help
=
"Number of tokens per KV block in TRTLLM worker. Default is 32 for pytorch backend."
,
)
)
parser
.
add_argument
(
"--remote-prefill"
,
action
=
"store_true"
,
help
=
"Use remote prefill workers for generation server in Disaggregated mode."
,
)
args
=
parser
.
parse_args
(
config_args
)
return
(
args
,
_init_engine_args
(
args
.
engine_args
))
def
parse_dynamo_run_args
()
->
Tuple
[
Any
,
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]]:
parser
=
argparse
.
ArgumentParser
(
description
=
"A TensorRT-LLM Dynamo-run engine parser"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--engine_args"
,
type
=
str
,
required
=
True
,
help
=
"Path to the engine args file"
"--enable-disagg"
,
)
parser
.
add_argument
(
"--publish-kv-cache-events"
,
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"
Publish KV cache events from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode.
"
,
help
=
"
Enable remote prefill for the worker
"
,
)
)
args
,
_
=
parser
.
parse_
known
_args
(
)
args
=
parser
.
parse_
args
(
config
_args
)
return
(
args
,
_init_engine_args
(
args
.
engine_args
))
return
args
examples/tensorrt_llm/common/utils.py
deleted
100644 → 0
View file @
8a2d6529
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
asyncio
import
logging
import
threading
import
traceback
import
weakref
from
enum
import
Enum
from
queue
import
Queue
from
typing
import
Any
,
Callable
,
Coroutine
,
Optional
,
TypedDict
,
Union
logger
=
logging
.
getLogger
(
__name__
)
AsyncTask
=
Union
[
Callable
[...,
Coroutine
[
Any
,
Any
,
bool
]],
weakref
.
WeakMethod
]
class
RoutingStrategy
(
Enum
):
ROUND_ROBIN
=
"round_robin"
RANDOM
=
"random"
PREFIX
=
"prefix"
class
RequestType
(
Enum
):
CHAT
=
"chat"
COMPLETION
=
"completion"
class
ServerType
(
Enum
):
# Generation server used for disaggregated and aggregated requests
GEN
=
"gen"
# Context server used for disaggregated requests
CTX
=
"ctx"
# Dynamo run server used for Dynamo run requests
DYN_RUN
=
"dyn_run"
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
class
ManagedThread
(
threading
.
Thread
):
def
__init__
(
self
,
task
:
Optional
[
AsyncTask
],
error_queue
:
Optional
[
Queue
]
=
None
,
name
:
Optional
[
str
]
=
None
,
loop
:
Optional
[
asyncio
.
AbstractEventLoop
]
=
None
,
**
kwargs
,
):
super
().
__init__
(
name
=
name
)
self
.
task
=
task
self
.
error_queue
=
error_queue
self
.
kwargs
=
kwargs
self
.
loop
=
loop
self
.
daemon
=
True
self
.
stop_event
=
threading
.
Event
()
def
set_loop
(
self
,
loop
:
asyncio
.
AbstractEventLoop
):
self
.
loop
=
loop
def
run
(
self
):
while
not
self
.
stop_event
.
is_set
():
task
:
Optional
[
AsyncTask
]
=
self
.
task
if
isinstance
(
task
,
weakref
.
WeakMethod
):
task
=
task
()
if
task
is
None
:
# Normally, this should not happen.
logger
.
warning
(
"WeakMethod is expired."
)
break
if
task
is
None
:
break
try
:
if
self
.
loop
is
None
:
logger
.
error
(
"[ManagedThread] Loop not initialized!"
)
break
future
=
asyncio
.
run_coroutine_threadsafe
(
task
(
**
self
.
kwargs
),
self
.
loop
)
_
=
future
.
result
()
except
Exception
as
e
:
logger
.
error
(
f
"Error in thread
{
self
.
name
}
:
{
e
}
\n
{
traceback
.
format_exc
()
}
"
)
if
self
.
error_queue
is
not
None
:
self
.
error_queue
.
put
(
e
)
logger
.
info
(
f
"Thread
{
self
.
name
}
stopped."
)
def
stop
(
self
):
self
.
stop_event
.
set
()
examples/tensorrt_llm/components/prefill_worker.py
View file @
03d976c7
...
@@ -12,15 +12,13 @@
...
@@ -12,15 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
asyncio
import
logging
import
logging
from
common.base_engine
import
BaseTensorrtLLMEngine
from
common.base_engine
import
BaseEngineConfig
,
BaseTensorrtLLMEngine
from
common.parser
import
parse_tensorrt_llm_args
from
common.parser
import
parse_tensorrt_llm_args
from
common.protocol
import
TRTLLMWorkerRequest
from
common.protocol
import
TRTLLMWorkerRequest
from
common.utils
import
ServerType
from
dynamo.sdk
import
async_on_start
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
dynamo_context
,
endpoint
,
on_shutdown
,
service
from
dynamo.sdk.lib.config
import
ServiceConfig
from
dynamo.sdk.lib.config
import
ServiceConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -39,34 +37,37 @@ class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
...
@@ -39,34 +37,37 @@ class TensorRTLLMPrefillWorker(BaseTensorrtLLMEngine):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
config
=
ServiceConfig
.
get_instance
()
config
=
ServiceConfig
.
get_instance
()
config_args
=
config
.
as_args
(
class_name
,
prefix
=
""
)
config_args
=
config
.
as_args
(
class_name
,
prefix
=
""
)
args
,
engine_config
=
parse_tensorrt_llm_args
(
config_args
)
args
=
parse_tensorrt_llm_args
(
config_args
)
worker_id
=
dynamo_context
[
"endpoints"
][
0
].
lease_id
()
lease_id
=
dynamo_context
[
"endpoints"
][
0
].
lease_id
()
super
().
__init__
(
namespace
,
_
=
TensorRTLLMPrefillWorker
.
dynamo_address
()
# type: ignore
namespace_str
=
"dynamo"
,
component_str
=
class_name
,
engine_config
=
BaseEngineConfig
(
worker_id
=
worker_id
,
namespace
=
namespace
,
engine_config
=
engine_config
,
component
=
class_name
,
remote_prefill
=
args
.
remote_prefill
,
endpoint
=
"generate"
,
min_workers
=
args
.
min_workers
,
model_path
=
args
.
model_path
,
disagg_config_file
=
args
.
llmapi_disaggregated_config
,
served_model_name
=
args
.
served_model_name
,
block_size
=
args
.
block_size
,
kv_block_size
=
args
.
kv_block_size
,
router
=
args
.
router
,
extra_engine_args
=
args
.
extra_engine_args
,
server_type
=
ServerType
.
CTX
,
publish_events_and_metrics
=
False
,
disaggregation_mode
=
"prefill"
,
remote_prefill_endpoint
=
None
,
lease_id
=
lease_id
,
)
)
super
().
__init__
(
config
=
engine_config
)
@
async_on_start
@
async_on_start
async
def
async_init
(
self
):
async
def
async_init
(
self
):
self
.
_init_engine
()
runtime
=
dynamo_context
[
"runtime"
]
if
self
.
_kv_metrics_publisher
is
not
None
:
await
self
.
initialize
(
runtime
)
task
=
asyncio
.
create_task
(
self
.
create_metrics_publisher_endpoint
())
task
.
add_done_callback
(
lambda
_
:
logger
.
info
(
"metrics publisher endpoint created"
)
)
logger
.
info
(
"TensorRT-LLM Prefill Worker initialized"
)
logger
.
info
(
"TensorRT-LLM Prefill Worker initialized"
)
async
def
create_metrics_publisher_endpoint
(
self
):
@
on_shutdown
component
=
dynamo_context
[
"component"
]
async
def
async_cleanup
(
self
):
await
self
.
kv_metrics_publisher
.
create_endpoint
(
component
)
logger
.
info
(
"Cleaning up TensorRT-LLM Prefill Worker"
)
await
self
.
cleanup
()
logger
.
info
(
"TensorRT-LLM Prefill Worker cleanup completed"
)
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
request
:
TRTLLMWorkerRequest
):
async
def
generate
(
self
,
request
:
TRTLLMWorkerRequest
):
...
...
examples/tensorrt_llm/components/worker.py
View file @
03d976c7
...
@@ -12,17 +12,22 @@
...
@@ -12,17 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
asyncio
import
logging
import
logging
from
common.base_engine
import
BaseTensorrtLLMEngine
from
common.base_engine
import
BaseEngineConfig
,
BaseTensorrtLLMEngine
from
common.parser
import
parse_tensorrt_llm_args
from
common.parser
import
parse_tensorrt_llm_args
from
common.protocol
import
TRTLLMWorkerRequest
from
common.protocol
import
TRTLLMWorkerRequest
from
common.utils
import
ServerType
from
components.prefill_worker
import
TensorRTLLMPrefillWorker
from
components.prefill_worker
import
TensorRTLLMPrefillWorker
from
dynamo.llm
import
ModelType
,
register_llm
from
dynamo.llm
import
ModelType
,
register_llm
from
dynamo.sdk
import
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
service
from
dynamo.sdk
import
(
async_on_start
,
depends
,
dynamo_context
,
endpoint
,
on_shutdown
,
service
,
)
from
dynamo.sdk.lib.config
import
ServiceConfig
from
dynamo.sdk.lib.config
import
ServiceConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -43,74 +48,66 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
...
@@ -43,74 +48,66 @@ class TensorRTLLMWorker(BaseTensorrtLLMEngine):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
config
=
ServiceConfig
.
get_instance
()
config
=
ServiceConfig
.
get_instance
()
config_args
=
config
.
as_args
(
class_name
,
prefix
=
""
)
config_args
=
config
.
as_args
(
class_name
,
prefix
=
""
)
args
,
engine_config
=
parse_tensorrt_llm_args
(
config_args
)
args
=
parse_tensorrt_llm_args
(
config_args
)
self
.
served_model_name
=
args
.
served_model_name
lease_id
=
dynamo_context
[
"endpoints"
][
0
].
lease_id
()
worker_id
=
dynamo_context
[
"endpoints"
][
0
].
lease_id
()
namespace
,
_
=
TensorRTLLMWorker
.
dynamo_address
()
# type: ignore
namespace
,
_
=
TensorRTLLMWorker
.
dynamo_address
()
# type: ignore
self
.
_min_prefill_workers
=
args
.
min_prefill_workers
endpoint_name
=
"generate"
super
().
__init__
(
publish_events_and_metrics
=
args
.
router
==
"kv"
namespace_str
=
namespace
,
prefill_class_name
=
"TensorRTLLMPrefillWorker"
component_str
=
class_name
,
worker_id
=
worker_id
,
if
args
.
enable_disagg
:
engine_config
=
engine_config
,
disaggregation_mode
=
"decode"
remote_prefill
=
args
.
remote_prefill
,
else
:
min_workers
=
args
.
min_workers
,
disaggregation_mode
=
"prefill_and_decode"
disagg_config_file
=
args
.
llmapi_disaggregated_config
,
block_size
=
args
.
block_size
,
engine_config
=
BaseEngineConfig
(
router
=
args
.
router
,
namespace
=
namespace
,
server_type
=
ServerType
.
GEN
,
component
=
class_name
,
endpoint
=
endpoint_name
,
model_path
=
args
.
model_path
,
served_model_name
=
args
.
served_model_name
,
kv_block_size
=
args
.
kv_block_size
,
extra_engine_args
=
args
.
extra_engine_args
,
publish_events_and_metrics
=
publish_events_and_metrics
,
disaggregation_mode
=
disaggregation_mode
,
remote_prefill_endpoint
=
f
"dyn://
{
namespace
}
.
{
prefill_class_name
}
.generate"
,
lease_id
=
lease_id
,
)
)
super
().
__init__
(
config
=
engine_config
)
@
async_on_start
@
async_on_start
async
def
async_init
(
self
):
async
def
async_init
(
self
):
self
.
_init_engine
()
runtime
=
dynamo_context
[
"runtime"
]
runtime
=
dynamo_context
[
"runtime"
]
await
self
.
initialize
(
runtime
)
logger
.
info
(
"Registering LLM for discovery"
)
logger
.
info
(
"Registering LLM for discovery"
)
comp_ns
,
comp_name
=
TensorRTLLMWorker
.
dynamo_address
()
# type: ignore
endpoint
=
(
endpoint
=
runtime
.
namespace
(
comp_ns
).
component
(
comp_name
).
endpoint
(
"generate"
)
runtime
.
namespace
(
self
.
_config
.
namespace
)
.
component
(
self
.
_config
.
component
)
.
endpoint
(
self
.
_config
.
endpoint
)
)
try
:
try
:
await
register_llm
(
await
register_llm
(
ModelType
.
Backend
,
ModelType
.
Backend
,
endpoint
,
endpoint
,
self
.
_
engine_
config
.
model_
name
,
self
.
_config
.
model_
path
,
self
.
served_model_name
,
self
.
_config
.
served_model_name
,
kv_cache_block_size
=
self
.
_kv_block_size
,
kv_cache_block_size
=
self
.
_
config
.
kv_block_size
,
)
)
logger
.
info
(
"Successfully registered LLM for discovery"
)
logger
.
info
(
"Successfully registered LLM for discovery"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Failed to register LLM for discovery:
{
e
}
"
)
logger
.
error
(
f
"Failed to register LLM for discovery:
{
e
}
"
)
raise
raise
if
self
.
_remote_prefill
:
runtime
=
dynamo_context
[
"runtime"
]
comp_ns
,
comp_name
=
TensorRTLLMPrefillWorker
.
dynamo_address
()
# type: ignore
self
.
_prefill_client
=
(
await
runtime
.
namespace
(
comp_ns
)
.
component
(
comp_name
)
.
endpoint
(
"generate"
)
.
client
()
)
while
len
(
self
.
_prefill_client
.
instance_ids
())
<
self
.
_min_prefill_workers
:
logger
.
info
(
f
"Waiting for prefill workers to be ready.
\n
"
f
" Current:
{
len
(
self
.
_prefill_client
.
instance_ids
())
}
,"
f
" Required:
{
self
.
_min_prefill_workers
}
"
)
await
asyncio
.
sleep
(
30
)
if
self
.
_kv_metrics_publisher
is
not
None
:
task
=
asyncio
.
create_task
(
self
.
create_metrics_publisher_endpoint
())
task
.
add_done_callback
(
lambda
_
:
logger
.
info
(
"metrics publisher endpoint created"
)
)
logger
.
info
(
"TensorRT-LLM Worker initialized"
)
logger
.
info
(
"TensorRT-LLM Worker initialized"
)
async
def
create_metrics_publisher_endpoint
(
self
):
@
on_shutdown
component
=
dynamo_context
[
"component"
]
async
def
async_cleanup
(
self
):
await
self
.
_kv_metrics_publisher
.
create_endpoint
(
component
)
logger
.
info
(
"Cleaning up TensorRT-LLM Worker"
)
await
self
.
cleanup
()
logger
.
info
(
"TensorRT-LLM Worker cleanup completed"
)
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
request
:
TRTLLMWorkerRequest
):
async
def
generate
(
self
,
request
:
TRTLLMWorkerRequest
):
...
...
examples/tensorrt_llm/configs/agg.yaml
View file @
03d976c7
...
@@ -20,8 +20,13 @@ Frontend:
...
@@ -20,8 +20,13 @@ Frontend:
router
:
round-robin
router
:
round-robin
TensorRTLLMWorker
:
TensorRTLLMWorker
:
# Path to disk model or HuggingFace model identifier to load
model-path
:
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name
:
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
served_model_name
:
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
engine_args
:
"
configs/llm_api_config.yaml"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args
:
"
configs/engine_configs/agg_config.yaml"
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
...
examples/tensorrt_llm/configs/agg_router.yaml
View file @
03d976c7
...
@@ -20,9 +20,15 @@ Frontend:
...
@@ -20,9 +20,15 @@ Frontend:
router
:
kv
router
:
kv
TensorRTLLMWorker
:
TensorRTLLMWorker
:
engine_args
:
"
configs/llm_api_config_router.yaml"
# Path to disk model or HuggingFace model identifier to load
model-path
:
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Name to serve the model under
served_model_name
:
deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args
:
"
configs/engine_configs/agg_config.yaml"
router
:
kv
router
:
kv
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
resources
:
resources
:
gpu
:
1
gpu
:
1
\ No newline at end of file
examples/tensorrt_llm/configs/deepseek_r1/agg.yaml
View file @
03d976c7
...
@@ -22,7 +22,12 @@ Frontend:
...
@@ -22,7 +22,12 @@ Frontend:
TensorRTLLMWorker
:
TensorRTLLMWorker
:
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
engine_args
:
"
configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path
:
"
nvidia/DeepSeek-R1-FP4"
extra-engine-args
:
"
configs/deepseek_r1/engine_configs/agg_config.yaml"
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
...
examples/tensorrt_llm/configs/deepseek_r1/disagg.yaml
View file @
03d976c7
...
@@ -22,14 +22,13 @@ Frontend:
...
@@ -22,14 +22,13 @@ Frontend:
TensorRTLLMWorker
:
TensorRTLLMWorker
:
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
engine_args
:
"
configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config
:
"
configs/deepseek_r1/disagg_llm_api_config.yaml"
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
remote-prefill
:
true
# You can also specify the full path to locally downloaded weights
# NOTE: When testing/benchmarking multiple prefill workers, you can set
# instead of a HuggingFace ID here.
# this number to the exact amount of prefill workers if you want Dynamo to
model-path
:
"
nvidia/DeepSeek-R1-FP4"
# wait until all the prefill workers are ready before marking the decode
extra-engine-args
:
"
configs/deepseek_r1/engine_configs/decode_config.yaml"
# worker ready.
enable-disagg
:
true
min-prefill-workers
:
1
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
@@ -37,8 +36,12 @@ TensorRTLLMWorker:
...
@@ -37,8 +36,12 @@ TensorRTLLMWorker:
gpu
:
4
gpu
:
4
TensorRTLLMPrefillWorker
:
TensorRTLLMPrefillWorker
:
engine_args
:
"
configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config
:
"
configs/deepseek_r1/disagg_llm_api_config.yaml"
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path
:
"
nvidia/DeepSeek-R1-FP4"
extra-engine-args
:
"
configs/deepseek_r1/engine_configs/prefill_config.yaml"
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
...
examples/tensorrt_llm/configs/deepseek_r1/disagg_llm_api_config.yaml
deleted
100644 → 0
View file @
8a2d6529
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example Configs for Context & Generation on GB200 nodes
# - Context on 1xGB200 (4xB00)
# - Generation on 1xGB200 (4xB200)
# NOTE: Fields like hostname, ports, urls, num_instances, etc. only used by trtllm-serve, not by dynamo
backend
:
pytorch
context_servers
:
# Context/prefill processes many tokens at once, so for a large ISL, a large
# batch size may not be needed to saturate GPU utilization.
max_batch_size
:
1
max_num_tokens
:
8192
max_seq_len
:
8192
# TP/EP/PP/DP
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
pipeline_parallel_size
:
1
enable_attention_dp
:
true
kv_cache_config
:
free_gpu_memory_fraction
:
0.75
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
# Overlap scheduler not currently supported in context-only
disable_overlap_scheduler
:
true
print_iter_log
:
true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype
:
fp8
generation_servers
:
# Generation/decode processes one token per request at a time, so a larger
# batch size helps to saturate GPU utilization.
max_batch_size
:
256
max_num_tokens
:
256
# 8448 = 8192 ISL + 256 OSL
max_seq_len
:
8448
# TP/EP/PP/DP
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
pipeline_parallel_size
:
1
enable_attention_dp
:
false
kv_cache_config
:
# With dp attention disabled: high free_gpu_memory_fraction is fine.
free_gpu_memory_fraction
:
0.85
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: This field is called 'enable_overlap_scheduler' in older TRTLLM versions
disable_overlap_scheduler
:
false
use_cuda_graph
:
true
cuda_graph_padding_enabled
:
true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes
:
-
1
-
2
-
4
-
8
-
16
-
32
-
64
-
128
-
256
print_iter_log
:
true
# NOTE: This dtype must match in both context/generation configs
kv_cache_dtype
:
fp8
examples/tensorrt_llm/configs/deepseek_r1/
agg_llm_api
_config.yaml
→
examples/tensorrt_llm/configs/deepseek_r1/
engine_configs/agg
_config.yaml
View file @
03d976c7
...
@@ -12,12 +12,6 @@
...
@@ -12,12 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model_name
:
"
nvidia/DeepSeek-R1-FP4"
backend
:
pytorch
backend
:
pytorch
# TP/EP/PP/DP
# TP/EP/PP/DP
...
...
examples/tensorrt_llm/configs/
llm_api_config_router
.yaml
→
examples/tensorrt_llm/configs/
deepseek_r1/engine_configs/decode_config
.yaml
View file @
03d976c7
...
@@ -12,32 +12,44 @@
...
@@ -12,32 +12,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
backend
:
pytorch
# TP/EP/PP/DP
# In the case of disaggregated deployment, this config will apply to each server
tensor_parallel_size
:
4
# and will be overwritten by the disaggregated config file
moe_expert_parallel_size
:
4
pipeline_parallel_size
:
1
# TODO: figure out how to generate this from the service config or vice versa
model_name
:
"
deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path
:
null
tensor_parallel_size
:
1
moe_expert_parallel_size
:
1
enable_attention_dp
:
false
enable_attention_dp
:
false
max_num_tokens
:
8192
max_batch_size
:
1
6
max_batch_size
:
25
6
trust_remote_code
:
true
max_num_tokens
:
256
backend
:
pytorch
# 8448 = 8192 ISL + 256 OSL
enable_chunked_prefill
:
true
max_seq_len
:
8448
kv_cache_config
:
kv_cache_config
:
free_gpu_memory_fraction
:
0.95
# With dp attention disabled: high free_gpu_memory_fraction is fine.
event_buffer_max_size
:
1024
free_gpu_memory_fraction
:
0.85
enable_block_reuse
:
true
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
disable_overlap_scheduler
:
false
use_cuda_graph
:
true
use_cuda_graph
:
true
enable_iter_perf_stats
:
true
cuda_graph_padding_enabled
:
true
# NOTE: For larger max batch size, you may want to add larger cuda graph
# batch sizes below to match.
cuda_graph_batch_sizes
:
-
1
-
2
-
4
-
8
-
16
-
32
-
64
-
128
-
256
print_iter_log
:
true
kv_cache_dtype
:
fp8
examples/tensorrt_llm/configs/
llm_api_config_disagg_router
.yaml
→
examples/tensorrt_llm/configs/
deepseek_r1/engine_configs/prefill_config
.yaml
View file @
03d976c7
...
@@ -12,32 +12,30 @@
...
@@ -12,32 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
backend
:
pytorch
# TP/EP/PP/DP
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
pipeline_parallel_size
:
1
enable_attention_dp
:
true
# In the case of disaggregated deployment, this config will apply to each server
max_batch_size
:
1
# and will be overwritten by the disaggregated config file
# TODO: figure out how to generate this from the service config or vice versa
model_name
:
"
deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
model_path
:
null
tensor_parallel_size
:
1
moe_expert_parallel_size
:
1
enable_attention_dp
:
false
max_num_tokens
:
8192
max_num_tokens
:
8192
max_batch_size
:
16
max_seq_len
:
8192
trust_remote_code
:
true
backend
:
pytorch
enable_chunked_prefill
:
true
kv_cache_config
:
kv_cache_config
:
free_gpu_memory_fraction
:
0.95
# With dp attention disabled: high free_gpu_memory_fraction is fine.
event_buffer_max_size
:
1024
free_gpu_memory_fraction
:
0.75
enable_block_reuse
:
true
# With dp attention enabled: large ISL at high concurrency may need
# free_gpu_memory_fraction low to have enough available memory.
# free_gpu_memory_fraction: 0.30
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: pytorch_backend_config section flattened since: https://github.com/NVIDIA/TensorRT-LLM/pull/4603
# NOTE: overlap_scheduler enabled by default since this commit and changed
# NOTE: overlap_scheduler enabled by default since this commit and changed
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# config field from 'enable_overlap_scheduler' to 'disable_overlap_scheduler':
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
# https://github.com/NVIDIA/TensorRT-LLM/commit/b4e5df0ee0024eda3eeb83a6ba822245a30ab428
use_cuda_graph
:
true
disable_overlap_scheduler
:
true
enable_iter_perf_stats
:
true
print_iter_log
:
true
# NOTE: This dtype must match in both prefill/decode configs
kv_cache_dtype
:
fp8
examples/tensorrt_llm/configs/deepseek_r1/mtp/
mtp_agg_llm_api
_config.yaml
→
examples/tensorrt_llm/configs/deepseek_r1/mtp/
engine_configs/agg
_config.yaml
View file @
03d976c7
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
# You can also specify the full path to locally downloaded weights
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
# instead of a HuggingFace ID here.
model_name
:
"
nvidia/DeepSeek-R1-FP4"
backend
:
pytorch
backend
:
pytorch
tensor_parallel_size
:
4
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
moe_expert_parallel_size
:
4
...
...
examples/tensorrt_llm/configs/deepseek_r1/mtp/engine_configs/decode_config.yaml
0 → 100644
View file @
03d976c7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend
:
pytorch
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
enable_attention_dp
:
false
max_batch_size
:
256
# Note: When MPT is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens
:
512
# 8704 = 8192 ISL + 512 OSL
max_seq_len
:
8704
kv_cache_config
:
free_gpu_memory_fraction
:
0.85
# Enable the MTP(Multi-Token Prediction) in decode model engine
speculative_config
:
decoding_type
:
MTP
num_nextn_predict_layers
:
1
use_cuda_graph
:
true
cuda_graph_padding_enabled
:
true
cuda_graph_batch_sizes
:
-
1
-
2
-
4
-
8
-
16
-
32
-
64
-
128
-
256
print_iter_log
:
true
kv_cache_dtype
:
fp8
examples/tensorrt_llm/configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml
0 → 100644
View file @
03d976c7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend
:
pytorch
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
enable_attention_dp
:
true
max_batch_size
:
1
max_num_tokens
:
8192
max_seq_len
:
8192
kv_cache_config
:
free_gpu_memory_fraction
:
0.75
print_iter_log
:
true
kv_cache_dtype
:
fp8
disable_overlap_scheduler
:
true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config
:
decoding_type
:
MTP
num_nextn_predict_layers
:
1
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_agg.yaml
View file @
03d976c7
...
@@ -21,7 +21,14 @@ Frontend:
...
@@ -21,7 +21,14 @@ Frontend:
TensorRTLLMWorker
:
TensorRTLLMWorker
:
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
engine_args
:
"
configs/deepseek_r1/mtp/mtp_agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path
:
"
nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args
:
"
configs/deepseek_r1/mtp/engine_configs/agg_config.yaml"
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
...
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_disagg.yaml
View file @
03d976c7
...
@@ -21,19 +21,30 @@ Frontend:
...
@@ -21,19 +21,30 @@ Frontend:
TensorRTLLMWorker
:
TensorRTLLMWorker
:
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
served_model_name
:
"
nvidia/DeepSeek-R1-FP4"
engine_args
:
"
configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config
:
"
configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml"
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path
:
"
nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args
:
"
configs/deepseek_r1/mtp/engine_configs/decode_config.yaml"
router
:
round-robin
router
:
round-robin
remote-prefill
:
true
enable-disagg
:
true
min-prefill-workers
:
1
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
resources
:
resources
:
gpu
:
4
gpu
:
4
TensorRTLLMPrefillWorker
:
TensorRTLLMPrefillWorker
:
engine_args
:
"
configs/deepseek_r1/agg_llm_api_config.yaml"
# NOTE: FP4 only supported starting with Blackwell GPUs.
llmapi-disaggregated-config
:
"
configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml"
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
model-path
:
"
nvidia/DeepSeek-R1-FP4"
# Path to a YAML file containing additional keyword arguments to pass to the TRTLLM engine.
# The fields in `extra-engine-args` holds higher priority than the above TRTLLM engine fields.
extra-engine-args
:
"
configs/deepseek_r1/mtp/engine_configs/prefill_config.yaml"
router
:
round-robin
router
:
round-robin
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
...
examples/tensorrt_llm/configs/deepseek_r1/mtp/mtp_disagg_llm_api_config.yaml
deleted
100644 → 0
View file @
8a2d6529
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: FP4 only supported starting with Blackwell GPUs.
# https://huggingface.co/nvidia/DeepSeek-R1-FP4
# You can also specify the full path to locally downloaded weights
# instead of a HuggingFace ID here.
backend
:
pytorch
context_servers
:
num_instances
:
1
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
enable_attention_dp
:
true
max_batch_size
:
1
max_num_tokens
:
8192
max_seq_len
:
8192
kv_cache_config
:
free_gpu_memory_fraction
:
0.75
print_iter_log
:
true
kv_cache_dtype
:
fp8
disable_overlap_scheduler
:
true
# Enable the MTP(Multi-Token Prediction) in the prefill model engine
speculative_config
:
decoding_type
:
MTP
num_nextn_predict_layers
:
1
generation_servers
:
num_instances
:
1
tensor_parallel_size
:
4
moe_expert_parallel_size
:
4
enable_attention_dp
:
false
max_batch_size
:
256
# Note: When MPT is enabled and `cuda_graph_batch_sizes` is specified, `max_num_tokens` must satisfy the following formula:
# max_num_tokens >= max(cuda_graph_batch_sizes) * (num_nextn_predict_layers + 1)
# This is a known issue in TensorRT-LLM and will be resolved in the next release.
max_num_tokens
:
512
# 8704 = 8192 ISL + 512 OSL
max_seq_len
:
8704
kv_cache_config
:
free_gpu_memory_fraction
:
0.85
# Enable the MTP(Multi-Token Prediction) in the decode model engine
speculative_config
:
decoding_type
:
MTP
num_nextn_predict_layers
:
1
use_cuda_graph
:
true
cuda_graph_padding_enabled
:
true
cuda_graph_batch_sizes
:
-
1
-
2
-
4
-
8
-
16
-
32
-
64
-
128
-
256
print_iter_log
:
true
kv_cache_dtype
:
fp8
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment