"deploy/utils/vscode:/vscode.git/clone" did not exist on "d81a00efdf29d3e39b8a52d6b8151ac117400e81"
Unverified Commit bf19823d authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Support Dynamo KVBM with TRTLLM Disagg (#3527)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent 0e0d6c16
......@@ -28,6 +28,7 @@ from tensorrt_llm.llmapi import (
SchedulerConfig,
)
from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_options
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.metrics import MetricsCollector
......@@ -107,6 +108,22 @@ async def get_engine_runtime_config(
return runtime_config
def build_kv_connector_config(config: Config):
    """Translate the selected connector into a TRT-LLM KV connector config.

    Returns a ``KvCacheConnectorConfig`` pointing at the Dynamo KVBM leader
    and worker classes when ``config.connector`` is ``"kvbm"``, and ``None``
    when it is ``None`` or ``"none"``. Any other value is logged as an error
    and the process exits with status 1.
    """
    selected = config.connector

    # An unset connector and the explicit "none" choice both mean "no connector".
    if selected is None or selected == "none":
        return None

    if selected == "kvbm":
        return KvCacheConnectorConfig(
            connector_module="kvbm.trtllm_integration.connector",
            connector_scheduler_class="DynamoKVBMConnectorLeader",
            connector_worker_class="DynamoKVBMConnectorWorker",
        )

    # Unknown connector name: report it and abort.
    logging.error(f"Invalid connector: {config.connector}")
    sys.exit(1)
async def worker():
config = cmd_line_args()
......@@ -166,6 +183,9 @@ async def init(runtime: DistributedRuntime, config: Config):
free_gpu_memory_fraction=config.free_gpu_memory_fraction
)
if config.connector is not None and "kvbm" in config.connector:
kv_cache_config.enable_partial_reuse = False
dynamic_batch_config = DynamicBatchConfig(
enable_batch_size_tuning=True,
enable_max_num_tokens_tuning=False,
......@@ -175,6 +195,8 @@ async def init(runtime: DistributedRuntime, config: Config):
capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
dynamic_batch_config=dynamic_batch_config,
)
kv_connector_config = build_kv_connector_config(config)
modality = getattr(config, "modality", None) or "text"
arg_map = {
"model": model_path,
......@@ -190,6 +212,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_beam_width": config.max_beam_width,
"max_batch_size": config.max_batch_size,
"return_perf_metrics": config.publish_events_and_metrics,
"kv_connector_config": kv_connector_config,
}
if config.extra_engine_args != "":
......
......@@ -281,6 +281,13 @@ def cmd_line_args():
choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
)
parser.add_argument(
"--connector",
type=str,
default="none",
choices=["none", "kvbm"],
help="Connector to use for the model.",
)
add_config_dump_args(parser)
parser.add_argument(
"--custom-jinja-template",
......@@ -380,6 +387,7 @@ def cmd_line_args():
config.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
# Derive use_kv_events from publish_events_and_metrics
config.use_kv_events = config.publish_events_and_metrics
config.connector = args.connector
# Handle custom jinja template path expansion (environment variables and home directory)
if args.custom_jinja_template:
......
......@@ -25,7 +25,7 @@ To learn what KVBM is, please check [here](kvbm_architecture.md)
> - Ensure that `etcd` and `nats` are running before starting.
> - KVBM only supports TensorRT-LLM’s PyTorch backend.
> - Disable partial reuse by setting `enable_partial_reuse: false` in the LLM API config’s `kv_cache_config` to increase offloading cache hits.
> - KVBM requires TensorRT-LLM v1.1.0rc5 or newer.
> - KVBM requires TensorRT-LLM v1.2.0rc2 or newer.
> - Enabling KVBM metrics with TensorRT-LLM is still a work in progress.
## Quick Start
......@@ -106,6 +106,16 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json"
```
KVBM is also supported on the prefill worker of disaggregated serving. To launch the prefill worker, run:
```bash
# [DYNAMO] To serve an LLM model with dynamo
python3 -m dynamo.trtllm \
--model-path Qwen/Qwen3-0.6B \
--served-model-name Qwen/Qwen3-0.6B \
--extra-engine-args /tmp/kvbm_llm_api_config.yaml \
--disaggregation-mode prefill &
```
Alternatively, you can use `trtllm-serve` with KVBM by replacing the two [DYNAMO] commands above with the following:
```bash
trtllm-serve Qwen/Qwen3-0.6B --host localhost --port 8000 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml
......
......@@ -5,6 +5,7 @@ import logging
import os
from typing import List, Optional
import tensorrt_llm
from kvbm import KvbmLeader
from kvbm.trtllm_integration.consolidator_config import is_truthy
from kvbm.trtllm_integration.rust import KvbmRequest
......@@ -118,6 +119,12 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
output = RustSchedulerOutput()
for req in scheduler_output.new_requests:
if not hasattr(req, "num_scheduled_tokens"):
raise ValueError(
f"""num_scheduled_tokens is not found in the SchedulerOutput!
You're currently using TRTLLM {tensorrt_llm.__version__}
The mimimum supported version is 1.2.0rc2"""
)
output.add_new_request(
str(req.request_id),
req.new_tokens,
......@@ -135,6 +142,14 @@ class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
req.computed_position,
)
output.add_num_scheduled_tokens(
{
str(req.request_id): req.num_scheduled_tokens
for req in scheduler_output.new_requests
+ scheduler_output.cached_requests
}
)
return self._connector.build_connector_metadata(output)
def get_num_new_matched_tokens(
......
......@@ -110,18 +110,6 @@ pub trait Slot: std::fmt::Debug {
num_scheduled_tokens: usize,
) -> Result<(), SlotError>;
// TRT-LLM does not include scheduled tokens in the scheduler output.
// Ideally, we should have a dedicated implementation for the TRT-LLM slot.
// However, since only this single function needs to be rewritten for now,
// we keep it as a separate function in Slot.
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>;
......@@ -642,111 +630,6 @@ impl Slot for VllmConnectorSlot {
Ok(())
}
/// Applies one scheduler step to this slot, driven by the scheduler-reported
/// `computed_position` rather than a scheduled-token count.
///
/// In order, this:
/// 1. returns early without touching the slot when `computed_position` is
///    behind `self.current_position`;
/// 2. either marks the slot as `Prefilling`, or appends `tokens` to the
///    sequence and marks it as `Decoding`;
/// 3. appends `block_ids` to `self.device_blocks`;
/// 4. offloads the not-yet-evaluated blocks among the first
///    `(computed_position + 1) / self.block_size` blocks;
/// 5. sets `self.current_position` to `computed_position`.
///
/// Every visible return path yields `Ok(())`.
///
/// # Panics
/// Panics if extending the sequence fails, if fewer device blocks are
/// available than candidate blocks, or if `offload_blocks` returns an error.
#[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))]
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
is_new_request: bool,
) -> Result<(), SlotError> {
// TRTLLM's KV Connector Manager will have (computed_position - external matches)
// in the onboarding case, so a position behind ours is skipped without any update.
if computed_position < self.current_position {
tracing::debug!(
"computed_position={} < current_position={}, so we are onboarding during prefilling phase",
computed_position,
self.current_position
);
return Ok(());
}
// now we decide what we should do for the new computed tokens
tracing::debug!(
"applying scheduler output, computed_position={}, sequence_total_tokens={}",
computed_position,
self.sequence.total_tokens()
);
if computed_position < self.sequence.total_tokens() {
// no need to apply new tokens, since they were applied when the slot was created during prefilling
self.state = SlotState::Prefilling;
} else {
tracing::debug!(
"appending {} newly decoded tokens to sequence",
tokens.len()
);
// NOTE: `unwrap` panics if the sequence cannot be extended.
self.sequence.extend(tokens.into()).unwrap();
self.state = SlotState::Decoding;
}
// apply new block_ids, this should be applied for both prefilling and decoding
// because this is unknown when creating the slot
if !block_ids.is_empty() {
tracing::debug!("assigning {} new device blocks slot", block_ids.len());
self.device_blocks.extend(block_ids);
}
// This approach is fragile, but it’s the only way currently to skip evaluating
// the device matched blocks and to avoid offloading them again.
// TODO: Consider adding an indicator in the scheduler output to distinguish between
// matched and unmatched device blocks/tokens from the scheduler.
let maybe_have_device_matched_blocks =
is_new_request && computed_position > 0 && self.evaluated_blocks == 0;
if maybe_have_device_matched_blocks {
// Treat every block up to computed_position as already evaluated so none of them is offloaded below.
self.evaluated_blocks = (computed_position + 1) / self.block_size;
}
// Blocks up to computed_position that have not been evaluated yet;
// `saturating_sub` clamps the count at zero.
let num_candidate_blocks =
((computed_position + 1) / self.block_size).saturating_sub(self.evaluated_blocks);
if num_candidate_blocks > 0 {
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet.
// for now, offload all the blocks to the host
let offload_block_ids: Vec<usize> = self
.device_blocks
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.copied()
.collect::<Vec<_>>();
// Fewer collected ids than candidates means device_blocks is shorter than expected.
assert_eq!(
offload_block_ids.len(),
num_candidate_blocks,
"device block overflow - candidate blocks exceed block count at offset {}",
self.evaluated_blocks
);
// Token blocks are taken from the same [evaluated_blocks, evaluated_blocks + num_candidate_blocks) window.
let offload_token_blocks: Vec<TokenBlock> = self
.sequence
.blocks()
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.cloned()
.collect::<Vec<_>>();
// NOTE: `expect` panics if offloading fails; the error is not propagated as a SlotError.
self.offload_blocks(&offload_block_ids, &offload_token_blocks)
.expect("failed to offload blocks");
self.evaluated_blocks += num_candidate_blocks;
}
// done applying policy
tracing::debug!(
"done applying kv cache policy at current_position: {}; computed_position: {}",
self.current_position,
computed_position,
);
// advance current position to computed position
self.current_position = computed_position;
Ok(())
}
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> {
if self.iteration_first_scheduled.is_none() {
self.iteration_first_scheduled = Some(iteration);
......
......@@ -351,11 +351,16 @@ impl Leader for KvConnectorLeader {
slot.state()
);
slot.apply_scheduler_output_with_computed_position(
let scheduled_tokens = *scheduler_output
.num_scheduled_tokens
.get(request_id)
.unwrap_or(&0);
slot.apply_scheduler_output(
&new_req.prompt_token_ids,
&new_req.block_ids,
new_req.num_computed_tokens,
true,
scheduled_tokens,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......@@ -382,11 +387,16 @@ impl Leader for KvConnectorLeader {
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
slot.apply_scheduler_output_with_computed_position(
let scheduled_tokens = *scheduler_output
.num_scheduled_tokens
.get(request_id)
.unwrap_or(&0);
slot.apply_scheduler_output(
&cached_req.new_token_ids,
&cached_req.new_block_ids,
cached_req.num_computed_tokens,
false,
scheduled_tokens,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
......
......@@ -21,12 +21,14 @@ import os
import signal
import subprocess
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional, TextIO
from typing import Any, Dict, Optional, TextIO
import pytest
import requests
import yaml
from .common import DeterminismTester, ServerType
from .common import TestDeterminism as BaseTestDeterminism
......@@ -105,12 +107,14 @@ class LLMServerManager:
if self.server_type == ServerType.vllm:
self._set_up_vllm_config(gpu_cache_blocks)
elif self.server_type == ServerType.trtllm:
self._set_up_trtllm_config(gpu_cache_blocks)
else:
raise ValueError(
f"{self.server_type} is not supported yet in the KVBM test suite"
)
def _set_up_dynamo_config(self, router_mode: str = "kv"):
def _set_up_dynamo_config(self, router_mode: str = "round-robin"):
self.dynamo_frontend_cmd = [
"python3",
"-m",
......@@ -165,6 +169,86 @@ class LLMServerManager:
["--num-gpu-blocks-override", str(gpu_cache_blocks)]
)
def _set_up_trtllm_config(self, gpu_cache_blocks):
    """Prepare the TRT-LLM prefill/decode worker commands and LLM API configs.

    Builds a shared ``kv_cache_config`` with partial reuse disabled, derives
    separate prefill and decode LLM API configs from it, stores the worker
    launch commands on ``self.prefiller_cmd`` and ``self.decoder_cmd``, and
    writes both configs to disk as YAML.

    Args:
        gpu_cache_blocks: Optional number of GPU KV-cache blocks. When not
            ``None``, ``free_gpu_memory_fraction`` is dropped and
            ``max_tokens`` is set to ``int(gpu_cache_blocks) * KV_BLOCK_SIZE``.

    Environment variables read (each with a default):
        ``KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH``,
        ``KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH``, ``KVBM_MODEL_ID``.
    """
    prefill_config_path = os.environ.get(
        "KVBM_TRTLLM_LLMAPI_PREFILL_CONFIG_PATH",
        "/tmp/kvbm_llm_api_prefill_config.yaml",
    )
    decode_config_path = os.environ.get(
        "KVBM_TRTLLM_LLMAPI_DECODE_CONFIG_PATH",
        "/tmp/kvbm_llm_api_decode_config.yaml",
    )

    # Single source of truth for the block size: it feeds tokens_per_block,
    # the max_tokens computation and the --kv-block-size CLI flag, so the
    # three can never drift apart.
    KV_BLOCK_SIZE = 16

    kv_cache_config: Dict[str, Any] = {
        "enable_partial_reuse": False,
        "free_gpu_memory_fraction": 0.10,
        "tokens_per_block": KV_BLOCK_SIZE,
    }

    # GPU blocks override: size the cache by an explicit token budget
    # instead of a fraction of free GPU memory.
    if gpu_cache_blocks is not None:
        del kv_cache_config["free_gpu_memory_fraction"]
        kv_cache_config["max_tokens"] = int(gpu_cache_blocks) * KV_BLOCK_SIZE

    llm_api_config: Dict[str, Any] = {"kv_cache_config": kv_cache_config}

    # Each worker gets its own deep copy so the per-role settings below do
    # not leak into the other config.
    prefill_config = deepcopy(llm_api_config)
    prefill_config["disable_overlap_scheduler"] = True
    prefill_config["cache_transceiver_config"] = {
        "backend": "DEFAULT",
        "max_tokens_in_buffer": 16384,
    }
    prefill_config["cuda_graph_config"] = None

    decode_config = deepcopy(llm_api_config)
    decode_config["disable_overlap_scheduler"] = False
    decode_config["cache_transceiver_config"] = {
        "backend": "DEFAULT",
        "max_tokens_in_buffer": 65536,
    }

    model = os.environ.get(
        "KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    )

    cmd_root = [
        "python3",
        "-m",
        "dynamo.trtllm",
        "--model",
        model,
        "--kv-block-size",
        str(KV_BLOCK_SIZE),
        "--max-num-tokens",
        "8000",
    ]

    # Only the prefill worker is launched with the KVBM connector.
    self.prefiller_cmd = cmd_root + [
        "--extra-engine-args",
        prefill_config_path,
        "--disaggregation-mode",
        "prefill",
        "--connector",
        "kvbm",
    ]

    self.decoder_cmd = cmd_root + [
        "--extra-engine-args",
        decode_config_path,
        "--disaggregation-mode",
        "decode",
    ]

    # sort_keys=False keeps the insertion order of the dicts built above.
    with open(prefill_config_path, "w", encoding="utf-8") as f:
        yaml.dump(prefill_config, f, default_flow_style=False, sort_keys=False)

    with open(decode_config_path, "w", encoding="utf-8") as f:
        yaml.dump(decode_config, f, default_flow_style=False, sort_keys=False)
def start_server(self, timeout: int = 300) -> bool:
"""Start LLM server and wait for readiness."""
if self.is_server_running():
......@@ -345,6 +429,7 @@ class LLMServerManager:
# First check basic health
response = requests.get(f"{self.base_url}/health", timeout=5)
if response.status_code != 200:
print(f"Health check failed with status code: {response.status_code}")
return False
# Then check if the model endpoint is ready with a simple test request
......@@ -363,9 +448,14 @@ class LLMServerManager:
json=test_payload,
timeout=10,
)
if response.status_code != 200:
print(
f"Model endpoint test failed with status code: {response.status_code}"
)
return response.status_code == 200
except requests.exceptions.RequestException:
except requests.exceptions.RequestException as e:
print(f"Error checking server status: {e}")
return False
......@@ -419,6 +509,8 @@ def llm_server(request, runtime_services):
if importlib.util.find_spec("vllm") is not None:
server_type = ServerType.vllm
elif importlib.util.find_spec("tensorrt_llm") is not None:
server_type = ServerType.trtllm
else:
pytest.skip("vllm module is not available in the current environment.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment