Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
8435b993
Commit
8435b993
authored
Mar 12, 2025
by
Tanmay Verma
Committed by
GitHub
Mar 13, 2025
Browse files
fix: Fix TRTLLM chat to work with latest ToT (#127)
Co-authored-by:
Ryan McCormick
<
rmccormick@nvidia.com
>
parent
089f8e1b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
10 deletions
+67
-10
examples/python_rs/llm/tensorrt_llm/common/base_engine.py
examples/python_rs/llm/tensorrt_llm/common/base_engine.py
+16
-0
examples/python_rs/llm/tensorrt_llm/common/parser.py
examples/python_rs/llm/tensorrt_llm/common/parser.py
+6
-0
examples/python_rs/llm/tensorrt_llm/disaggregated/kv_router.py
...les/python_rs/llm/tensorrt_llm/disaggregated/kv_router.py
+37
-7
examples/python_rs/llm/tensorrt_llm/disaggregated/router.py
examples/python_rs/llm/tensorrt_llm/disaggregated/router.py
+8
-2
examples/python_rs/llm/tensorrt_llm/disaggregated/worker.py
examples/python_rs/llm/tensorrt_llm/disaggregated/worker.py
+0
-1
No files found.
examples/python_rs/llm/tensorrt_llm/common/base_engine.py
View file @
8435b993
...
@@ -138,6 +138,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
...
@@ -138,6 +138,10 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
kv_active_block
=
0
kv_active_block
=
0
kv_total_blocks
=
4
kv_total_blocks
=
4
num_requests_waiting
=
0
gpu_cache_usage_perc
=
0.0
gpu_prefix_cache_hit_rate
=
0.0
if
self
.
_kv_metrics_publisher
is
None
:
if
self
.
_kv_metrics_publisher
is
None
:
logger
.
error
(
"KV metrics publisher not initialized!"
)
logger
.
error
(
"KV metrics publisher not initialized!"
)
return
return
...
@@ -147,6 +151,9 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
...
@@ -147,6 +151,9 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
request_total_slots
,
request_total_slots
,
kv_active_block
,
kv_active_block
,
kv_total_blocks
,
kv_total_blocks
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
,
)
)
# Prepare threads for publishing stats but don't start them yet.
# Prepare threads for publishing stats but don't start them yet.
...
@@ -197,11 +204,20 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
...
@@ -197,11 +204,20 @@ class BaseTensorrtLLMEngine(ChatProcessorMixin):
logger
.
error
(
"KV metrics publisher not initialized!"
)
logger
.
error
(
"KV metrics publisher not initialized!"
)
return
False
return
False
# TODO: Remove this once we have the actual values.
# Adding dummy values for now so it doesn't break the metrics.
num_requests_waiting
=
0
gpu_cache_usage_perc
=
0.0
gpu_prefix_cache_hit_rate
=
0.0
self
.
_kv_metrics_publisher
.
publish
(
self
.
_kv_metrics_publisher
.
publish
(
request_active_slots
,
request_active_slots
,
request_total_slots
,
request_total_slots
,
kv_active_block
,
kv_active_block
,
kv_total_blocks
,
kv_total_blocks
,
num_requests_waiting
,
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
,
)
)
logger
.
debug
(
logger
.
debug
(
f
"Published stats: request_active_slots:
{
request_active_slots
}
, request_total_slots:
{
request_total_slots
}
, kv_active_block:
{
kv_active_block
}
, kv_total_blocks:
{
kv_total_blocks
}
"
f
"Published stats: request_active_slots:
{
request_active_slots
}
, request_total_slots:
{
request_total_slots
}
, kv_active_block:
{
kv_active_block
}
, kv_total_blocks:
{
kv_total_blocks
}
"
...
...
examples/python_rs/llm/tensorrt_llm/common/parser.py
View file @
8435b993
...
@@ -128,5 +128,11 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]
...
@@ -128,5 +128,11 @@ def parse_tensorrt_llm_args() -> Tuple[Any, Tuple[Dict[str, Any], Dict[str, Any]
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode."
,
help
=
"Publish stats from TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode."
,
)
)
parser
.
add_argument
(
"--kv-block-size"
,
type
=
int
,
help
=
"KV block size for TensorRT-LLM. Currently, only supported for context worker in Disaggregated mode."
,
default
=
64
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
(
args
,
_init_engine_args
(
args
.
engine_args
))
return
(
args
,
_init_engine_args
(
args
.
engine_args
))
examples/python_rs/llm/tensorrt_llm/disaggregated/kv_router.py
View file @
8435b993
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
asyncio
import
asyncio
import
copy
import
copy
import
enum
import
json
import
json
import
traceback
import
traceback
from
typing
import
AsyncIterator
from
typing
import
AsyncIterator
...
@@ -22,6 +23,7 @@ from typing import AsyncIterator
...
@@ -22,6 +23,7 @@ from typing import AsyncIterator
import
uvloop
import
uvloop
from
common.base_engine
import
ChatProcessorMixin
from
common.base_engine
import
ChatProcessorMixin
from
common.parser
import
LLMAPIConfig
,
parse_tensorrt_llm_args
from
common.parser
import
LLMAPIConfig
,
parse_tensorrt_llm_args
from
common.processor
import
parse_chat_message_content
from
common.protocol
import
(
from
common.protocol
import
(
DisaggChatCompletionRequest
,
DisaggChatCompletionRequest
,
DisaggChatCompletionStreamResponse
,
DisaggChatCompletionStreamResponse
,
...
@@ -37,6 +39,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
...
@@ -37,6 +39,11 @@ from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger
.
set_level
(
"debug"
)
logger
.
set_level
(
"debug"
)
class
EndpointType
(
enum
.
Enum
):
chat
=
"chat"
completion
=
"completion"
class
Scheduler
:
class
Scheduler
:
def
__init__
(
self
,
kv_router
:
KvRouter
):
def
__init__
(
self
,
kv_router
:
KvRouter
):
self
.
kv_router
=
kv_router
self
.
kv_router
=
kv_router
...
@@ -77,13 +84,32 @@ class Router(ChatProcessorMixin):
...
@@ -77,13 +84,32 @@ class Router(ChatProcessorMixin):
logger
.
info
(
"INITIALIZED ROUTER"
)
logger
.
info
(
"INITIALIZED ROUTER"
)
async
def
_get_ctx_resp
(
self
,
request
,
ctx_client
):
async
def
_get_ctx_resp
(
self
,
request
,
ctx_client
,
endpoint_type
:
EndpointType
):
logger
.
debug
(
f
"Received request
{
request
}
"
)
logger
.
debug
(
f
"Received request
{
request
}
"
)
# NOTE: this will increase TTFT since we are encoding the prompt here
# NOTE: this will increase TTFT since we are encoding the prompt here
# prompt is also encoded in the worker.
# prompt is also encoded in the worker.
# TODO: we need to implement our own request processing and protocols to send only token ids to llmapi worker.
# TODO: we need to implement our own request processing and protocols to send only token ids to llmapi worker.
token_ids
=
self
.
_tokenizer
.
encode
(
request
.
prompt
)
if
endpoint_type
==
EndpointType
.
completion
:
token_ids
=
self
.
_tokenizer
.
encode
(
request
.
prompt
)
else
:
conversation
=
[]
for
message
in
request
.
messages
:
conversation
.
extend
(
parse_chat_message_content
(
message
))
tool_dicts
=
(
None
if
request
.
tools
is
None
else
[
tool
.
model_dump
()
for
tool
in
request
.
tools
]
)
token_ids
=
self
.
_tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
tokenize
=
True
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
chat_template
=
request
.
chat_template
,
**
(
request
.
chat_template_kwargs
or
{}),
)
worker_id_generator
:
AsyncIterator
=
self
.
scheduler
.
generate
(
worker_id_generator
:
AsyncIterator
=
self
.
scheduler
.
generate
(
Tokens
(
tokens
=
token_ids
).
model_dump_json
()
Tokens
(
tokens
=
token_ids
).
model_dump_json
()
)
)
...
@@ -92,7 +118,7 @@ class Router(ChatProcessorMixin):
...
@@ -92,7 +118,7 @@ class Router(ChatProcessorMixin):
await
worker_id_generator
.
__anext__
()
await
worker_id_generator
.
__anext__
()
)
# only one worker id is returned
)
# only one worker id is returned
request
.
max_tokens
=
1
request
.
max_
completion_
tokens
=
1
request
.
disaggregated_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
request
.
disaggregated_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
logger
.
debug
(
f
"[router] Sending request to context server:
{
request
}
"
)
logger
.
debug
(
f
"[router] Sending request to context server:
{
request
}
"
)
...
@@ -132,7 +158,9 @@ class Router(ChatProcessorMixin):
...
@@ -132,7 +158,9 @@ class Router(ChatProcessorMixin):
gen_req
=
copy
.
deepcopy
(
request
)
gen_req
=
copy
.
deepcopy
(
request
)
ctx_resp
=
await
self
.
_get_ctx_resp
(
request
,
self
.
ctx_completion_client
)
ctx_resp
=
await
self
.
_get_ctx_resp
(
request
,
self
.
ctx_completion_client
,
EndpointType
.
completion
)
ctx_resp_obj
=
DisaggCompletionStreamResponse
.
model_validate
(
ctx_resp
)
ctx_resp_obj
=
DisaggCompletionStreamResponse
.
model_validate
(
ctx_resp
)
gen_req
.
disaggregated_params
=
DisaggregatedParams
.
model_validate
(
gen_req
.
disaggregated_params
=
DisaggregatedParams
.
model_validate
(
...
@@ -165,7 +193,9 @@ class Router(ChatProcessorMixin):
...
@@ -165,7 +193,9 @@ class Router(ChatProcessorMixin):
gen_req
=
copy
.
deepcopy
(
request
)
gen_req
=
copy
.
deepcopy
(
request
)
ctx_resp
=
await
self
.
_get_ctx_resp
(
request
,
self
.
ctx_chat_client
)
ctx_resp
=
await
self
.
_get_ctx_resp
(
request
,
self
.
ctx_chat_client
,
EndpointType
.
chat
)
ctx_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate_json
(
ctx_resp
)
ctx_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate_json
(
ctx_resp
)
gen_req
.
disaggregated_params
=
DisaggregatedParams
.
model_validate
(
gen_req
.
disaggregated_params
=
DisaggregatedParams
.
model_validate
(
...
@@ -184,7 +214,7 @@ class Router(ChatProcessorMixin):
...
@@ -184,7 +214,7 @@ class Router(ChatProcessorMixin):
async
for
response
in
await
self
.
gen_chat_client
.
round_robin
(
async
for
response
in
await
self
.
gen_chat_client
.
round_robin
(
gen_req
.
model_dump_json
()
gen_req
.
model_dump_json
()
):
):
gen_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate
(
gen_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate
_json
(
response
.
data
()
response
.
data
()
)
)
yield
json
.
loads
(
gen_resp_obj
.
model_dump_json
(
exclude_unset
=
True
))
yield
json
.
loads
(
gen_resp_obj
.
model_dump_json
(
exclude_unset
=
True
))
...
@@ -228,7 +258,7 @@ async def worker(runtime: DistributedRuntime, args, engine_config):
...
@@ -228,7 +258,7 @@ async def worker(runtime: DistributedRuntime, args, engine_config):
kv_listener
=
runtime
.
namespace
(
"dynamo"
).
component
(
"tensorrt-llm-ctx"
)
kv_listener
=
runtime
.
namespace
(
"dynamo"
).
component
(
"tensorrt-llm-ctx"
)
await
kv_listener
.
create_service
()
await
kv_listener
.
create_service
()
kv_router
=
KvRouter
(
runtime
,
kv_listener
)
kv_router
=
KvRouter
(
runtime
,
kv_listener
,
args
.
kv_block_size
)
completions_endpoint
=
component
.
endpoint
(
"completions"
)
completions_endpoint
=
component
.
endpoint
(
"completions"
)
chat_endpoint
=
component
.
endpoint
(
"chat/completions"
)
chat_endpoint
=
component
.
endpoint
(
"chat/completions"
)
...
...
examples/python_rs/llm/tensorrt_llm/disaggregated/router.py
View file @
8435b993
...
@@ -48,7 +48,7 @@ class Router:
...
@@ -48,7 +48,7 @@ class Router:
async
def
_get_ctx_resp
(
self
,
request
,
ctx_client
):
async
def
_get_ctx_resp
(
self
,
request
,
ctx_client
):
logger
.
debug
(
f
"Received request
{
request
}
"
)
logger
.
debug
(
f
"Received request
{
request
}
"
)
request
.
max_tokens
=
1
request
.
max_
completion_
tokens
=
1
request
.
disaggregated_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
request
.
disaggregated_params
=
DisaggregatedParams
(
request_type
=
"context_only"
)
logger
.
debug
(
f
"[router] Sending request to context server:
{
request
}
"
)
logger
.
debug
(
f
"[router] Sending request to context server:
{
request
}
"
)
ctx_resp
=
[
ctx_resp
=
[
...
@@ -97,6 +97,9 @@ class Router:
...
@@ -97,6 +97,9 @@ class Router:
async
for
response
in
await
self
.
gen_completion_client
.
round_robin
(
async
for
response
in
await
self
.
gen_completion_client
.
round_robin
(
gen_req
.
model_dump_json
()
gen_req
.
model_dump_json
()
):
):
logger
.
debug
(
f
"[router] Received response from generation server:
{
response
.
data
()
}
"
)
gen_resp_obj
=
DisaggCompletionStreamResponse
.
model_validate
(
gen_resp_obj
=
DisaggCompletionStreamResponse
.
model_validate
(
response
.
data
()
response
.
data
()
)
)
...
@@ -130,7 +133,10 @@ class Router:
...
@@ -130,7 +133,10 @@ class Router:
async
for
response
in
await
self
.
gen_chat_client
.
round_robin
(
async
for
response
in
await
self
.
gen_chat_client
.
round_robin
(
gen_req
.
model_dump_json
()
gen_req
.
model_dump_json
()
):
):
gen_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate
(
logger
.
debug
(
f
"[router] Received response from generation server:
{
response
.
data
()
}
"
)
gen_resp_obj
=
DisaggChatCompletionStreamResponse
.
model_validate_json
(
response
.
data
()
response
.
data
()
)
)
yield
json
.
loads
(
gen_resp_obj
.
model_dump_json
(
exclude_unset
=
True
))
yield
json
.
loads
(
gen_resp_obj
.
model_dump_json
(
exclude_unset
=
True
))
...
...
examples/python_rs/llm/tensorrt_llm/disaggregated/worker.py
View file @
8435b993
...
@@ -136,7 +136,6 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
...
@@ -136,7 +136,6 @@ class TensorrtLLMEngine(BaseTensorrtLLMEngine):
streaming
=
request
.
stream
,
streaming
=
request
.
stream
,
disaggregated_params
=
disaggregated_params
,
disaggregated_params
=
disaggregated_params
,
):
):
self
.
generate_event
.
set
()
final_result
=
result
final_result
=
result
logger
.
debug
(
f
"Generated result:
{
result
}
"
)
logger
.
debug
(
f
"Generated result:
{
result
}
"
)
if
self
.
server_config
.
type
==
"ctx"
:
if
self
.
server_config
.
type
==
"ctx"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment