Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
bf19823d
Unverified
Commit
bf19823d
authored
Jan 20, 2026
by
jthomson04
Committed by
GitHub
Jan 20, 2026
Browse files
feat: Support Dynamo KVBM with TRTLLM Disagg (#3527)
Signed-off-by:
jthomson04
<
jwillthomson19@gmail.com
>
parent
0e0d6c16
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
166 additions
and
125 deletions
+166
-125
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+23
-0
components/src/dynamo/trtllm/utils/trtllm_utils.py
components/src/dynamo/trtllm/utils/trtllm_utils.py
+8
-0
docs/kvbm/trtllm-setup.md
docs/kvbm/trtllm-setup.md
+11
-1
lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py
...vbm/trtllm_integration/connector/kvbm_connector_leader.py
+15
-0
lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
...ings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
+0
-117
lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs
...gs/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs
+14
-4
tests/kvbm_integration/test_determinism_disagg.py
tests/kvbm_integration/test_determinism_disagg.py
+95
-3
No files found.
components/src/dynamo/trtllm/main.py
View file @
bf19823d
...
...
@@ -28,6 +28,7 @@ from tensorrt_llm.llmapi import (
SchedulerConfig
,
)
from
tensorrt_llm.llmapi.llm
import
SamplingParams
from
tensorrt_llm.llmapi.llm_args
import
KvCacheConnectorConfig
from
tensorrt_llm.llmapi.llm_utils
import
update_llm_args_with_extra_options
from
tensorrt_llm.llmapi.tokenizer
import
tokenizer_factory
from
tensorrt_llm.metrics
import
MetricsCollector
...
...
@@ -107,6 +108,22 @@ async def get_engine_runtime_config(
return
runtime_config
def
build_kv_connector_config
(
config
:
Config
):
if
config
.
connector
is
not
None
:
if
config
.
connector
==
"kvbm"
:
return
KvCacheConnectorConfig
(
connector_module
=
"kvbm.trtllm_integration.connector"
,
connector_scheduler_class
=
"DynamoKVBMConnectorLeader"
,
connector_worker_class
=
"DynamoKVBMConnectorWorker"
,
)
elif
config
.
connector
==
"none"
:
return
None
else
:
logging
.
error
(
f
"Invalid connector:
{
config
.
connector
}
"
)
sys
.
exit
(
1
)
return
None
async
def
worker
():
config
=
cmd_line_args
()
...
...
@@ -166,6 +183,9 @@ async def init(runtime: DistributedRuntime, config: Config):
free_gpu_memory_fraction
=
config
.
free_gpu_memory_fraction
)
if
config
.
connector
is
not
None
and
"kvbm"
in
config
.
connector
:
kv_cache_config
.
enable_partial_reuse
=
False
dynamic_batch_config
=
DynamicBatchConfig
(
enable_batch_size_tuning
=
True
,
enable_max_num_tokens_tuning
=
False
,
...
...
@@ -175,6 +195,8 @@ async def init(runtime: DistributedRuntime, config: Config):
capacity_scheduler_policy
=
CapacitySchedulerPolicy
.
GUARANTEED_NO_EVICT
,
dynamic_batch_config
=
dynamic_batch_config
,
)
kv_connector_config
=
build_kv_connector_config
(
config
)
modality
=
getattr
(
config
,
"modality"
,
None
)
or
"text"
arg_map
=
{
"model"
:
model_path
,
...
...
@@ -190,6 +212,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_beam_width"
:
config
.
max_beam_width
,
"max_batch_size"
:
config
.
max_batch_size
,
"return_perf_metrics"
:
config
.
publish_events_and_metrics
,
"kv_connector_config"
:
kv_connector_config
,
}
if
config
.
extra_engine_args
!=
""
:
...
...
components/src/dynamo/trtllm/utils/trtllm_utils.py
View file @
bf19823d
...
...
@@ -281,6 +281,13 @@ def cmd_line_args():
choices
=
get_reasoning_parser_names
(),
help
=
"Reasoning parser name for the model. If not specified, no reasoning parsing is performed."
,
)
parser
.
add_argument
(
"--connector"
,
type
=
str
,
default
=
"none"
,
choices
=
[
"none"
,
"kvbm"
],
help
=
"Connector to use for the model."
,
)
add_config_dump_args
(
parser
)
parser
.
add_argument
(
"--custom-jinja-template"
,
...
...
@@ -380,6 +387,7 @@ def cmd_line_args():
config
.
enable_local_indexer
=
str
(
args
.
enable_local_indexer
).
lower
()
==
"true"
# Derive use_kv_events from publish_events_and_metrics
config
.
use_kv_events
=
config
.
publish_events_and_metrics
config
.
connector
=
args
.
connector
# Handle custom jinja template path expansion (environment variables and home directory)
if
args
.
custom_jinja_template
:
...
...
docs/kvbm/trtllm-setup.md
View file @
bf19823d
...
...
@@ -25,7 +25,7 @@ To learn what KVBM is, please check [here](kvbm_architecture.md)
> - Ensure that `etcd` and `nats` are running before starting.
> - KVBM only supports TensorRT-LLM’s PyTorch backend.
> - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits.
> - KVBM requires TensorRT-LLM v1.
1
.0rc
5
or newer.
> - KVBM requires TensorRT-LLM v1.
2
.0rc
2
or newer.
> - Enabling KVBM metrics with TensorRT-LLM is still a work in progress.
## Quick Start
...
...
@@ -106,6 +106,16 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json"
```
KVBM is also supported on the prefill worker of disaggregated serving. To launch the prefill worker, run:
```
bash
# [DYNAMO] To serve an LLM model with dynamo
python3
-m
dynamo.trtllm
\
--model-path
Qwen/Qwen3-0.6B
\
--served-model-name
Qwen/Qwen3-0.6B
\
--extra-engine-args
/tmp/kvbm_llm_api_config.yaml
--disaggregation-mode
prefill &
```
Alternatively, can use "trtllm-serve" with KVBM by replacing the above two [DYNAMO] cmds with below:
```
bash
trtllm-serve Qwen/Qwen3-0.6B
--host
localhost
--port
8000
--backend
pytorch
--extra_llm_api_options
/tmp/kvbm_llm_api_config.yaml
...
...
lib/bindings/kvbm/python/kvbm/trtllm_integration/connector/kvbm_connector_leader.py
View file @
bf19823d
...
...
@@ -5,6 +5,7 @@ import logging
import
os
from
typing
import
List
,
Optional
import
tensorrt_llm
from
kvbm
import
KvbmLeader
from
kvbm.trtllm_integration.consolidator_config
import
is_truthy
from
kvbm.trtllm_integration.rust
import
KvbmRequest
...
...
@@ -118,6 +119,12 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
output
=
RustSchedulerOutput
()
for
req
in
scheduler_output
.
new_requests
:
if
not
hasattr
(
req
,
"num_scheduled_tokens"
):
raise
ValueError
(
f
"""num_scheduled_tokens is not found in the SchedulerOutput!
You're currently using TRTLLM
{
tensorrt_llm
.
__version__
}
The mimimum supported version is 1.2.0rc2"""
)
output
.
add_new_request
(
str
(
req
.
request_id
),
req
.
new_tokens
,
...
...
@@ -135,6 +142,14 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req
.
computed_position
,
)
output
.
add_num_scheduled_tokens
(
{
str
(
req
.
request_id
):
req
.
num_scheduled_tokens
for
req
in
scheduler_output
.
new_requests
+
scheduler_output
.
cached_requests
}
)
return
self
.
_connector
.
build_connector_metadata
(
output
)
def
get_num_new_matched_tokens
(
...
...
lib/bindings/kvbm/src/block_manager/vllm/connector/leader/slot.rs
View file @
bf19823d
...
...
@@ -110,18 +110,6 @@ pub trait Slot: std::fmt::Debug {
num_scheduled_tokens
:
usize
,
)
->
Result
<
(),
SlotError
>
;
// TRT-LLM does not include scheduled tokens in the scheduler output.
// Ideally, we should have a dedicated implementation for the TRT-LLM slot.
// However, since only this single function needs to be rewritten for now,
// we keep it as a separate function in Slot.
fn
apply_scheduler_output_with_computed_position
(
&
mut
self
,
tokens
:
&
[
u32
],
block_ids
:
&
[
usize
],
computed_position
:
usize
,
is_new_request
:
bool
,
)
->
Result
<
(),
SlotError
>
;
fn
record_start_iteration
(
&
mut
self
,
iteration
:
u64
)
->
Result
<
(),
SlotError
>
;
fn
mark_as_prefilling
(
&
mut
self
,
iteration
:
u64
)
->
Result
<
(),
SlotError
>
;
...
...
@@ -642,111 +630,6 @@ impl Slot for VllmConnectorSlot {
Ok
(())
}
#[tracing::instrument(level
=
"debug"
,
skip_all,
fields(request_id
=
self
.
request_id
.
as_str()))]
fn
apply_scheduler_output_with_computed_position
(
&
mut
self
,
tokens
:
&
[
u32
],
block_ids
:
&
[
usize
],
computed_position
:
usize
,
is_new_request
:
bool
,
)
->
Result
<
(),
SlotError
>
{
// TRTLLM's KV Connector Manager will have (computed_position - external matches)
// in onborading case
if
computed_position
<
self
.current_position
{
tracing
::
debug!
(
"computed_position={} < current_position={}, so we are onboarding during prefilling phase"
,
computed_position
,
self
.current_position
);
return
Ok
(());
}
// now we decide what we should do for the new computed tokens
tracing
::
debug!
(
"applying scheduler output, computed_position={}, sequence_total_tokens={}"
,
computed_position
,
self
.sequence
.total_tokens
()
);
if
computed_position
<
self
.sequence
.total_tokens
()
{
// no need to apply new tokens, since it's applied when created the slot during prefilling
self
.state
=
SlotState
::
Prefilling
;
}
else
{
tracing
::
debug!
(
"appending {} newly decoded tokens to sequence"
,
tokens
.len
()
);
self
.sequence
.extend
(
tokens
.into
())
.unwrap
();
self
.state
=
SlotState
::
Decoding
;
}
// apply new block_ids, this should be applied for both prefilling and decoding
// because this is unknown when creating the slot
if
!
block_ids
.is_empty
()
{
tracing
::
debug!
(
"assigning {} new device blocks slot"
,
block_ids
.len
());
self
.device_blocks
.extend
(
block_ids
);
}
// This approach is fragile, but it’s the only way currently to skip evaluating
// the device matched blocks and to avoid offloading them again.
// TODO: Consider adding an indicator in the scheduler output to distinguish between
// matched and unmatched device blocks/tokens from the scheduler.
let
maybe_have_device_matched_blocks
=
is_new_request
&&
computed_position
>
0
&&
self
.evaluated_blocks
==
0
;
if
maybe_have_device_matched_blocks
{
self
.evaluated_blocks
=
(
computed_position
+
1
)
/
self
.block_size
;
}
let
num_candidate_blocks
=
((
computed_position
+
1
)
/
self
.block_size
)
.saturating_sub
(
self
.evaluated_blocks
);
if
num_candidate_blocks
>
0
{
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet.
// for now, offload all the blocks to the host
let
offload_block_ids
:
Vec
<
usize
>
=
self
.device_blocks
.iter
()
.skip
(
self
.evaluated_blocks
)
.take
(
num_candidate_blocks
)
.copied
()
.collect
::
<
Vec
<
_
>>
();
assert_eq!
(
offload_block_ids
.len
(),
num_candidate_blocks
,
"device block overflow - candidate blocks exceed block count at offset {}"
,
self
.evaluated_blocks
);
let
offload_token_blocks
:
Vec
<
TokenBlock
>
=
self
.sequence
.blocks
()
.iter
()
.skip
(
self
.evaluated_blocks
)
.take
(
num_candidate_blocks
)
.cloned
()
.collect
::
<
Vec
<
_
>>
();
self
.offload_blocks
(
&
offload_block_ids
,
&
offload_token_blocks
)
.expect
(
"failed to offload blocks"
);
self
.evaluated_blocks
+=
num_candidate_blocks
;
}
// done applying policy
tracing
::
debug!
(
"done applying kv cache policy at current_position: {}; computed_position: {}"
,
self
.current_position
,
computed_position
,
);
// advance current position to computed position
self
.current_position
=
computed_position
;
Ok
(())
}
fn
record_start_iteration
(
&
mut
self
,
iteration
:
u64
)
->
Result
<
(),
SlotError
>
{
if
self
.iteration_first_scheduled
.is_none
()
{
self
.iteration_first_scheduled
=
Some
(
iteration
);
...
...
lib/bindings/kvbm/src/block_manager/vllm/connector/trtllm_leader.rs
View file @
bf19823d
...
...
@@ -351,11 +351,16 @@ impl Leader for KvConnectorLeader {
slot
.state
()
);
slot
.apply_scheduler_output_with_computed_position
(
let
scheduled_tokens
=
*
scheduler_output
.num_scheduled_tokens
.get
(
request_id
)
.unwrap_or
(
&
0
);
slot
.apply_scheduler_output
(
&
new_req
.prompt_token_ids
,
&
new_req
.block_ids
,
new_req
.num_computed_tokens
,
true
,
scheduled_tokens
,
)
?
;
if
let
Some
(
pending_ops
)
=
slot
.take_pending_operations
()
{
...
...
@@ -382,11 +387,16 @@ impl Leader for KvConnectorLeader {
.lock
()
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Failed to lock slot: {}"
,
e
))
?
;
slot
.apply_scheduler_output_with_computed_position
(
let
scheduled_tokens
=
*
scheduler_output
.num_scheduled_tokens
.get
(
request_id
)
.unwrap_or
(
&
0
);
slot
.apply_scheduler_output
(
&
cached_req
.new_token_ids
,
&
cached_req
.new_block_ids
,
cached_req
.num_computed_tokens
,
false
,
scheduled_tokens
,
)
?
;
if
let
Some
(
pending_ops
)
=
slot
.take_pending_operations
()
{
...
...
tests/kvbm_integration/test_determinism_disagg.py
View file @
bf19823d
...
...
@@ -21,12 +21,14 @@ import os
import
signal
import
subprocess
import
time
from
copy
import
deepcopy
from
datetime
import
datetime
from
pathlib
import
Path
from
typing
import
Optional
,
TextIO
from
typing
import
Any
,
Dict
,
Optional
,
TextIO
import
pytest
import
requests
import
yaml
from
.common
import
DeterminismTester
,
ServerType
from
.common
import
TestDeterminism
as
BaseTestDeterminism
...
...
@@ -105,12 +107,14 @@ class LLMServerManager:
if
self
.
server_type
==
ServerType
.
vllm
:
self
.
_set_up_vllm_config
(
gpu_cache_blocks
)
elif
self
.
server_type
==
ServerType
.
trtllm
:
self
.
_set_up_trtllm_config
(
gpu_cache_blocks
)
else
:
raise
ValueError
(
f
"
{
self
.
server_type
}
is not supported yet in the KVBM test suite"
)
def
_set_up_dynamo_config
(
self
,
router_mode
:
str
=
"
kv
"
):
def
_set_up_dynamo_config
(
self
,
router_mode
:
str
=
"
round-robin
"
):
self
.
dynamo_frontend_cmd
=
[
"python3"
,
"-m"
,
...
...
@@ -165,6 +169,86 @@ class LLMServerManager:
[
"--num-gpu-blocks-override"
,
str
(
gpu_cache_blocks
)]
)
def
_set_up_trtllm_config
(
self
,
gpu_cache_blocks
):
# Mostly the same parameters here as in the
prefill_config_path
=
os
.
environ
.
get
(
"KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH"
,
"/tmp/kvbm_llm_api_prefill_config.yaml"
,
)
decode_config_path
=
os
.
environ
.
get
(
"KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH"
,
"/tmp/kvbm_llm_api_decode_config.yaml"
,
)
KV_BLOCK_SIZE
=
16
llm_api_config
:
Dict
[
str
,
Any
]
=
{}
llm_api_config
[
"kv_cache_config"
]
=
{
"enable_partial_reuse"
:
False
,
"free_gpu_memory_fraction"
:
0.10
,
"tokens_per_block"
:
KV_BLOCK_SIZE
,
}
# GPU blocks override
if
gpu_cache_blocks
is
not
None
:
del
llm_api_config
[
"kv_cache_config"
][
"free_gpu_memory_fraction"
]
llm_api_config
[
"kv_cache_config"
][
"max_tokens"
]
=
(
int
(
gpu_cache_blocks
)
*
KV_BLOCK_SIZE
)
prefill_config
=
deepcopy
(
llm_api_config
)
prefill_config
[
"disable_overlap_scheduler"
]
=
True
prefill_config
[
"cache_transceiver_config"
]
=
{
"backend"
:
"DEFAULT"
,
"max_tokens_in_buffer"
:
16384
,
}
prefill_config
[
"cuda_graph_config"
]
=
None
decode_config
=
deepcopy
(
llm_api_config
)
decode_config
[
"disable_overlap_scheduler"
]
=
False
decode_config
[
"cache_transceiver_config"
]
=
{
"backend"
:
"DEFAULT"
,
"max_tokens_in_buffer"
:
65536
,
}
model
=
os
.
environ
.
get
(
"KVBM_MODEL_ID"
,
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
cmd_root
=
[
"python3"
,
"-m"
,
"dynamo.trtllm"
,
"--model"
,
model
,
"--kv-block-size"
,
"16"
,
"--max-num-tokens"
,
"8000"
,
]
self
.
prefiller_cmd
=
cmd_root
+
[
"--extra-engine-args"
,
prefill_config_path
,
"--disaggregation-mode"
,
"prefill"
,
"--connector"
,
"kvbm"
,
]
self
.
decoder_cmd
=
cmd_root
+
[
"--extra-engine-args"
,
decode_config_path
,
"--disaggregation-mode"
,
"decode"
,
]
with
open
(
prefill_config_path
,
"w"
)
as
f
:
yaml
.
dump
(
prefill_config
,
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
with
open
(
decode_config_path
,
"w"
)
as
f
:
yaml
.
dump
(
decode_config
,
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
def
start_server
(
self
,
timeout
:
int
=
300
)
->
bool
:
"""Start LLM server and wait for readiness."""
if
self
.
is_server_running
():
...
...
@@ -345,6 +429,7 @@ class LLMServerManager:
# First check basic health
response
=
requests
.
get
(
f
"
{
self
.
base_url
}
/health"
,
timeout
=
5
)
if
response
.
status_code
!=
200
:
print
(
f
"Health check failed with status code:
{
response
.
status_code
}
"
)
return
False
# Then check if the model endpoint is ready with a simple test request
...
...
@@ -363,9 +448,14 @@ class LLMServerManager:
json
=
test_payload
,
timeout
=
10
,
)
if
response
.
status_code
!=
200
:
print
(
f
"Model endpoint test failed with status code:
{
response
.
status_code
}
"
)
return
response
.
status_code
==
200
except
requests
.
exceptions
.
RequestException
:
except
requests
.
exceptions
.
RequestException
as
e
:
print
(
f
"Error checking server status:
{
e
}
"
)
return
False
...
...
@@ -419,6 +509,8 @@ def llm_server(request, runtime_services):
if
importlib
.
util
.
find_spec
(
"vllm"
)
is
not
None
:
server_type
=
ServerType
.
vllm
elif
importlib
.
util
.
find_spec
(
"tensorrt_llm"
)
is
not
None
:
server_type
=
ServerType
.
trtllm
else
:
pytest
.
skip
(
"vllm module is not available in the current environment."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment