Unverified Commit a68c2f8f authored by Richard Huo's avatar Richard Huo Committed by GitHub
Browse files

feat: DIS-373 dynamo KVBM connector API integration with TRTLLM (#2544)


Signed-off-by: default avatarrichardhuo-nv <rihuo@nvidia.com>
parent 43a26958
...@@ -66,7 +66,7 @@ jobs: ...@@ -66,7 +66,7 @@ jobs:
docker run -v ${{ github.workspace }}:/workspace -w /workspace \ docker run -v ${{ github.workspace }}:/workspace -w /workspace \
--name ${{ env.CONTAINER_ID }}_pytest \ --name ${{ env.CONTAINER_ID }}_pytest \
${{ steps.define_image_tag.outputs.image_tag }} \ ${{ steps.define_image_tag.outputs.image_tag }} \
bash -c "pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m \"${{ env.PYTEST_MARKS }}\"" bash -c "pytest --basetemp=/tmp --junitxml=${{ env.PYTEST_XML_FILE }} -m \"${{ env.PYTEST_MARKS }}\" --ignore /workspace/lib/bindings/python/src/dynamo/llm/trtllm_integration/connector "
- name: Copy test report from test Container - name: Copy test report from test Container
if: always() if: always()
run: | run: |
......
...@@ -15,6 +15,11 @@ jobs: ...@@ -15,6 +15,11 @@ jobs:
steps: steps:
- name: Check out repository - name: Check out repository
uses: actions/checkout@v4 uses: actions/checkout@v4
with:
# For pull_request events, use the PR head (commit from the contributor's branch/repo)
repository: ${{ github.event.pull_request.head.repo.full_name || github.repository }}
ref: ${{ github.event.pull_request.head.sha || github.sha }}
fetch-depth: 0
# Cache lychee results (e.g. to avoid hitting rate limits) # Cache lychee results (e.g. to avoid hitting rate limits)
# https://lychee.cli.rs/github_action_recipes/caching/ # https://lychee.cli.rs/github_action_recipes/caching/
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Running KVBM in TensorRT-LLM
This guide explains how to leverage KVBM (KV Block Manager) to mange KV cache and do KV offloading in TensorRT-LLM (trtllm).
To learn what KVBM is, please check [here](https://docs.nvidia.com/dynamo/latest/architecture/kvbm_intro.html)
> [!Note]
> - Ensure that `etcd` and `nats` are running before starting.
> - KVBM does not currently support CUDA graphs in TensorRT-LLM.
> - KVBM only supports TensorRT-LLM’s PyTorch backend.
> - To enable disk cache offloading, you must first enable a CPU memory cache offloading.
> - Disable partial reuse `enable_partial_reuse: false` in the LLM API config’s `kv_connector_config` to increase offloading cache hits.
> - KVBM requires TensorRT-LLM at commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer.
> - Enabling KVBM metrics with TensorRT-LLM is still a work in progress.
## Quick Start
To use KVBM in TensorRT-LLM, you can follow the steps below:
```bash
# start up etcd for KVBM leader/worker registration and discovery
docker compose -f deploy/docker-compose.yml up -d
# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer.
./container/build.sh --framework trtllm --tensorrtllm-commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 --enable-kvbm
# launch the container
./container/run.sh --framework trtllm -it --mount-workspace --use-nixl-gds
# enable kv offloading to CPU memory
# 60 means 60GB of pinned CPU memory would be used
export DYN_KVBM_CPU_CACHE_GB=60
# enable kv offloading to disk. Note: To enable disk cache offloading, you must first enable a CPU memory cache offloading.
# 20 means 20GB of disk would be used
export DYN_KVBM_DISK_CACHE_GB=20
# Allocating memory and disk storage can take some time.
# We recommend setting a higher timeout for leader–worker initialization.
# 1200 means 1200 seconds timeout
export DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS=1200
```
```bash
# write an example LLM API config
# Note: Disable partial reuse "enable_partial_reuse: false" in the LLM API config’s "kv_connector_config" to increase offloading cache hits.
cat > "/tmp/kvbm_llm_api_config.yaml" <<EOF
backend: pytorch
cuda_graph_config: null
kv_cache_config:
enable_partial_reuse: false
free_gpu_memory_fraction: 0.80
kv_connector_config:
connector_module: dynamo.llm.trtllm_integration.connector
connector_scheduler_class: DynamoKVBMConnectorLeader
connector_worker_class: DynamoKVBMConnectorWorker
EOF
# start dynamo frontend
python3 -m dynamo.frontend --http-port 8000 &
# To serve an LLM model with dynamo
python3 -m dynamo.trtllm \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--extra-engine-args /tmp/kvbm_llm_api_config.yaml &
# make a call to LLM
curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
],
"stream":false,
"max_tokens": 30
}'
# Optionally, we could also serve an LLM with trtllm-serve to utilize the KVBM feature.
trtllm-serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B --host localhost --port 8001 --backend pytorch --extra_llm_api_options /tmp/kvbm_llm_api_config.yaml
```
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
use super::*; use super::*;
use anyhow::Result;
use dynamo_llm::block_manager::block::{ use dynamo_llm::block_manager::block::{
data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical, data::logical::distributed_leader_worker::DistributedLeaderWorkerResources, locality::Logical,
}; };
...@@ -220,3 +221,113 @@ impl BlockManager { ...@@ -220,3 +221,113 @@ impl BlockManager {
&self.inner &self.inner
} }
} }
#[derive(Default)]
pub struct BlockManagerBuilder {
worker_id: u64,
leader: Option<distributed::KvbmLeader>,
page_size: usize,
disable_device_pool: bool,
}
impl BlockManagerBuilder {
pub fn new() -> Self {
Self {
page_size: 32, // default consistent with BlockManager::new
..Default::default()
}
}
pub fn worker_id(mut self, id: u64) -> Self {
self.worker_id = id;
self
}
pub fn page_size(mut self, ps: usize) -> Self {
self.page_size = ps;
self
}
pub fn leader(mut self, l: distributed::KvbmLeader) -> Self {
self.leader = Some(l);
self
}
pub fn disable_device_pool(mut self, yes: bool) -> Self {
self.disable_device_pool = yes;
self
}
/// Async build (call from an async context).
pub async fn build(self) -> Result<BlockManager> {
let worker_id = self.worker_id;
let leader = self.leader.ok_or_else(|| {
anyhow::anyhow!("leader is required (runtime is always taken from leader)")
})?;
// Get (inner leader handle, runtime) from the provided leader.
let (leader_inner, drt) = leader.dissolve();
let cancel_token = CancellationToken::new();
// Runtime & model config
let runtime_config = dynamo_llm::block_manager::KvManagerRuntimeConfig::builder()
.worker_id(worker_id)
.cancellation_token(cancel_token.clone())
.build()?;
let mut config =
dynamo_llm::block_manager::KvBlockManagerConfig::builder().runtime(runtime_config);
let model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
.num_layers(1)
.outer_dim(1)
.page_size(self.page_size)
.inner_dim(1)
.build()?;
config = config.model(model_config);
// Layouts derived from leader’s counts
if !self.disable_device_pool {
config = config.device_layout(
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(leader_inner.num_device_blocks())
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
.build()?,
);
}
if leader_inner.num_host_blocks() > 0 {
config = config.host_layout(
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(leader_inner.num_host_blocks())
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
.build()?,
);
}
if leader_inner.num_disk_blocks() > 0 {
config = config.disk_layout(
dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(leader_inner.num_disk_blocks())
.logical(Some(BlockParallelismStrategy::LeaderWorkerSharded))
.build()?,
);
}
let config = config.build()?;
let resources =
DistributedLeaderWorkerResources::new(Some(leader_inner), cancel_token.child_token())?;
let inner = dynamo_llm::block_manager::KvBlockManager::<
Logical<DistributedLeaderWorkerResources>,
BasicMetadata,
>::new(config, resources)
.await?;
Ok(BlockManager {
inner,
drt,
_controller: None,
})
}
}
...@@ -8,5 +8,5 @@ mod utils; ...@@ -8,5 +8,5 @@ mod utils;
mod worker; mod worker;
pub use leader::KvbmLeader; pub use leader::KvbmLeader;
pub use utils::get_barrier_id; pub use utils::get_barrier_id_prefix;
pub use worker::{KvbmWorker, VllmTensor}; pub use worker::{KvbmWorker, VllmTensor};
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::*; use super::*;
use utils::get_barrier_id; use utils::get_barrier_id_prefix;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use llm_rs::block_manager::distributed::{KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig}; use llm_rs::block_manager::distributed::{
KvbmLeader as KvbmLeaderImpl, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig,
};
const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB"; const CPU_CACHE: &str = "DYN_KVBM_CPU_CACHE_GB";
const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"; const CPU_CACHE_OVERRIDE: &str = "DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS";
...@@ -16,15 +18,32 @@ const DISK_CACHE_OVERRIDE: &str = "DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS"; ...@@ -16,15 +18,32 @@ const DISK_CACHE_OVERRIDE: &str = "DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS";
const LEADER_WORKER_INIT_TIMEOUT_SECS: &str = "DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS"; const LEADER_WORKER_INIT_TIMEOUT_SECS: &str = "DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS";
const DEFAULT_INIT_TIMEOUT_SECS: u64 = 120; const DEFAULT_INIT_TIMEOUT_SECS: u64 = 120;
fn compute_num_blocks(cache_size_key: &str, override_key: &str, bytes_per_block: usize) -> usize { fn read_env_usize(key: &str) -> Option<usize> {
if let Ok(override_num_blocks) = std::env::var(override_key) { std::env::var(key).ok()?.trim().parse::<usize>().ok()
override_num_blocks.parse::<usize>().unwrap_or(0) }
} else {
let cache_size_gb = std::env::var(cache_size_key) fn read_cache_size_float(key: &str) -> f64 {
std::env::var(key)
.unwrap_or_default() .unwrap_or_default()
.parse::<f64>() .parse::<f64>()
.unwrap_or(0.0); .unwrap_or(0.0)
((cache_size_gb * 1_000_000_000.0) / bytes_per_block as f64) as usize }
fn get_blocks_config(cache_size_key: &str, override_key: &str) -> KvbmLeaderNumBlocksConfig {
if let Some(nblocks) = read_env_usize(override_key) {
// Optional: still read cache size for observability, but override takes precedence.
let cache_gb: f64 = read_cache_size_float(cache_size_key);
return KvbmLeaderNumBlocksConfig {
cache_size_in_gb: cache_gb,
num_blocks_overriden: nblocks,
};
}
// No override -> compute from cache size (in GB)
let cache_gb: f64 = read_cache_size_float(cache_size_key);
KvbmLeaderNumBlocksConfig {
cache_size_in_gb: cache_gb,
num_blocks_overriden: 0,
} }
} }
...@@ -51,22 +70,19 @@ impl KvbmLeader { ...@@ -51,22 +70,19 @@ impl KvbmLeader {
#[pymethods] #[pymethods]
impl KvbmLeader { impl KvbmLeader {
#[new] #[new]
#[pyo3(signature = (bytes_per_block, world_size, drt))] #[pyo3(signature = (world_size, drt))]
fn new(bytes_per_block: usize, world_size: usize, drt: DistributedRuntime) -> PyResult<Self> { fn new(world_size: usize, drt: DistributedRuntime) -> PyResult<Self> {
let num_host_blocks = compute_num_blocks(CPU_CACHE, CPU_CACHE_OVERRIDE, bytes_per_block); let barrier_id_prefix = get_barrier_id_prefix();
let num_disk_blocks = compute_num_blocks(DISK_CACHE, DISK_CACHE_OVERRIDE, bytes_per_block);
let barrier_id = get_barrier_id();
let leader_init_timeout_sec: u64 = let leader_init_timeout_sec: u64 =
get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS); get_leader_init_timeout_secs(LEADER_WORKER_INIT_TIMEOUT_SECS);
let config = KvbmLeaderConfig::builder() let config = KvbmLeaderConfig::builder()
.barrier_id(barrier_id) .barrier_id_prefix(barrier_id_prefix)
.num_host_blocks(num_host_blocks)
.num_disk_blocks(num_disk_blocks)
.world_size(world_size) .world_size(world_size)
.leader_init_timeout_secs(leader_init_timeout_sec) .leader_init_timeout_secs(leader_init_timeout_sec)
.drt(drt.inner().clone()) .drt(drt.inner().clone())
.host_blocks_config(get_blocks_config(CPU_CACHE, CPU_CACHE_OVERRIDE))
.disk_blocks_config(get_blocks_config(DISK_CACHE, DISK_CACHE_OVERRIDE))
.build() .build()
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub fn get_barrier_id() -> String { pub fn get_barrier_id_prefix() -> String {
std::env::var("DYN_KVBM_BARRIER_ID").unwrap_or("kvbm".to_string()) std::env::var("DYN_KVBM_BARRIER_ID_PREFIX")
.ok()
.filter(|s| !s.trim().is_empty())
.unwrap_or_else(|| "kvbm".to_string())
} }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use super::*; use super::*;
use std::sync::Arc; use std::sync::Arc;
use utils::get_barrier_id; use utils::get_barrier_id_prefix;
use llm_rs::block_manager::distributed::{ use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl, BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
...@@ -107,7 +107,7 @@ impl KvbmWorker { ...@@ -107,7 +107,7 @@ impl KvbmWorker {
#[pymethods] #[pymethods]
impl KvbmWorker { impl KvbmWorker {
#[new] #[new]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None))] #[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false))]
fn new( fn new(
num_device_blocks: usize, num_device_blocks: usize,
page_size: usize, page_size: usize,
...@@ -115,6 +115,7 @@ impl KvbmWorker { ...@@ -115,6 +115,7 @@ impl KvbmWorker {
device_id: usize, device_id: usize,
dtype_width_bytes: usize, dtype_width_bytes: usize,
drt: Option<DistributedRuntime>, drt: Option<DistributedRuntime>,
layout_blocking: bool,
) -> PyResult<Self> { ) -> PyResult<Self> {
let py_drt = drt.ok_or_else(|| { let py_drt = drt.ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided") pyo3::exceptions::PyValueError::new_err("DistributedRuntime (drt) must be provided")
...@@ -131,7 +132,7 @@ impl KvbmWorker { ...@@ -131,7 +132,7 @@ impl KvbmWorker {
vllm_tensors.push(Arc::new(vllm_tensor)); vllm_tensors.push(Arc::new(vllm_tensor));
} }
let barrier_id = get_barrier_id(); let barrier_id_prefix = get_barrier_id_prefix();
let config = KvbmWorkerConfig::builder() let config = KvbmWorkerConfig::builder()
.drt(drt) .drt(drt)
...@@ -140,13 +141,13 @@ impl KvbmWorker { ...@@ -140,13 +141,13 @@ impl KvbmWorker {
.tensors(vllm_tensors) .tensors(vllm_tensors)
.device_id(device_id) .device_id(device_id)
.dtype_width_bytes(dtype_width_bytes) .dtype_width_bytes(dtype_width_bytes)
.barrier_id(barrier_id) .barrier_id_prefix(barrier_id_prefix)
.build() .build()
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
let worker = rt let worker = rt
.block_on(async move { .block_on(async move {
let kvbm_worker = KvbmWorkerImpl::new(config).await?; let kvbm_worker = KvbmWorkerImpl::new(config, layout_blocking).await?;
anyhow::Ok(kvbm_worker) anyhow::Ok(kvbm_worker)
}) })
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -50,6 +50,9 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -50,6 +50,9 @@ fn _vllm_integration(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<connector::worker::PyKvConnectorWorker>()?; m.add_class::<connector::worker::PyKvConnectorWorker>()?;
m.add_class::<connector::leader::PyKvConnectorLeader>()?; m.add_class::<connector::leader::PyKvConnectorLeader>()?;
m.add_class::<connector::SchedulerOutput>()?; m.add_class::<connector::SchedulerOutput>()?;
// TODO: use TRTLLM own integration module
m.add_class::<connector::trtllm_worker::PyTrtllmKvConnectorWorker>()?;
m.add_class::<connector::trtllm_leader::PyTrtllmKvConnectorLeader>()?;
Ok(()) Ok(())
} }
......
...@@ -7,6 +7,8 @@ use dynamo_llm::block_manager::{ ...@@ -7,6 +7,8 @@ use dynamo_llm::block_manager::{
}; };
pub mod leader; pub mod leader;
pub mod trtllm_leader;
pub mod trtllm_worker;
pub mod worker; pub mod worker;
use pyo3::prelude::*; use pyo3::prelude::*;
......
...@@ -9,7 +9,7 @@ use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics; ...@@ -9,7 +9,7 @@ use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use slot::{ConnectorSlotManager, SlotError, SlotManager, SlotState}; use slot::{ConnectorSlotManager, SlotError, SlotManager, SlotState};
use crate::llm::block_manager::BlockManager as PyBlockManager; use crate::llm::block_manager::BlockManagerBuilder;
use crate::llm::block_manager::{ use crate::llm::block_manager::{
distributed::KvbmLeader as PyKvbmLeader, vllm::connector::leader::slot::VllmConnectorSlot, distributed::KvbmLeader as PyKvbmLeader, vllm::connector::leader::slot::VllmConnectorSlot,
vllm::KvbmRequest, VllmBlockManager, vllm::KvbmRequest, VllmBlockManager,
...@@ -26,10 +26,12 @@ use dynamo_llm::block_manager::{ ...@@ -26,10 +26,12 @@ use dynamo_llm::block_manager::{
BasicMetadata, DiskStorage, ImmutableBlock, PinnedStorage, BasicMetadata, DiskStorage, ImmutableBlock, PinnedStorage,
}; };
use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens}; use dynamo_llm::tokens::{SaltHash, TokenBlockSequence, Tokens};
use std::sync::{Arc, OnceLock};
use std::{collections::HashSet, sync::Mutex}; use std::{collections::HashSet, sync::Mutex};
use tokio; use tokio;
use tokio::runtime::Handle;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot;
type VllmLocality = Logical<DistributedLeaderWorkerResources>; type VllmLocality = Logical<DistributedLeaderWorkerResources>;
...@@ -71,11 +73,13 @@ pub trait Leader: Send + Sync + std::fmt::Debug { ...@@ -71,11 +73,13 @@ pub trait Leader: Send + Sync + std::fmt::Debug {
fn has_slot(&self, request_id: String) -> bool; fn has_slot(&self, request_id: String) -> bool;
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()>; fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()>;
fn slot_manager(&self) -> &ConnectorSlotManager<String>;
} }
#[derive(Debug)] #[derive(Debug)]
pub struct KvConnectorLeader { pub struct KvConnectorLeader {
slot_manager: ConnectorSlotManager<String>, slot_manager: Arc<OnceLock<ConnectorSlotManager<String>>>,
block_size: usize, block_size: usize,
inflight_requests: HashSet<String>, inflight_requests: HashSet<String>,
onboarding_slots: HashSet<String>, onboarding_slots: HashSet<String>,
...@@ -87,37 +91,86 @@ impl KvConnectorLeader { ...@@ -87,37 +91,86 @@ impl KvConnectorLeader {
fn new( fn new(
worker_id: String, worker_id: String,
drt: PyDistributedRuntime, drt: PyDistributedRuntime,
block_manager: PyBlockManager, page_size: usize,
leader: PyKvbmLeader, leader_py: PyKvbmLeader,
) -> Self { ) -> Self {
tracing::info!( tracing::info!(
"KvConnectorLeader initialized with worker_id: {}", "KvConnectorLeader initialized with worker_id: {}",
worker_id worker_id
); );
// if drt is none, then we must construct a runtime and distributed runtime let leader = leader_py.get_inner().clone();
let block_manager = block_manager.get_block_manager().clone();
let block_size = block_manager.block_size();
let leader = leader.get_inner();
// if we need a drt, get it from here
let drt = drt.inner().clone(); let drt = drt.inner().clone();
let handle: Handle = drt.runtime().primary();
let ns = drt let ns = drt
.namespace(kvbm_connector::KVBM_CONNECTOR_LEADER) .namespace(kvbm_connector::KVBM_CONNECTOR_LEADER)
.unwrap(); .unwrap();
let kvbm_metrics = KvbmMetrics::new(&ns); let kvbm_metrics = KvbmMetrics::new(&ns);
let kvbm_metrics_clone = kvbm_metrics.clone();
Self { let slot_manager_cell = Arc::new(OnceLock::new());
slot_manager: ConnectorSlotManager::new( let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<String>();
block_manager.clone(),
leader, {
let slot_manager_cell = slot_manager_cell.clone();
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
if !ready {
tracing::error!(
"KvConnectorLeader init aborted: leader worker barrier not ready!",
);
return;
}
let block_manager = match BlockManagerBuilder::new()
.worker_id(0)
.leader(leader_py)
.page_size(page_size)
.disable_device_pool(false)
.build()
.await
{
Ok(bm) => bm,
Err(e) => {
tracing::error!("Failed to build BlockManager: {}", e);
return;
}
};
// Create the slot manager now that everything is ready
let sm = ConnectorSlotManager::new(
block_manager.get_block_manager().clone(),
leader.clone(),
drt.clone(), drt.clone(),
kvbm_metrics.clone(), kvbm_metrics_clone.clone(),
), );
block_size,
let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
let _ = leader.run_leader_readiness_barrier_blocking(drt);
if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent");
}
});
}
tokio::task::block_in_place(|| {
handle.block_on(async {
match leader_ready_rx.await {
Ok(_) => tracing::info!("KvConnectorLeader init complete."),
Err(_) => tracing::warn!("KvConnectorLeader init channel dropped"),
}
});
});
Self {
slot_manager: slot_manager_cell,
block_size: page_size,
inflight_requests: HashSet::new(), inflight_requests: HashSet::new(),
onboarding_slots: HashSet::new(), onboarding_slots: HashSet::new(),
iteration_counter: 0, iteration_counter: 0,
...@@ -127,6 +180,13 @@ impl KvConnectorLeader { ...@@ -127,6 +180,13 @@ impl KvConnectorLeader {
} }
impl Leader for KvConnectorLeader { impl Leader for KvConnectorLeader {
#[inline]
fn slot_manager(&self) -> &ConnectorSlotManager<String> {
self.slot_manager
.get()
.expect("slot_manager not initialized")
}
/// Match the tokens in the request with the available block pools. /// Match the tokens in the request with the available block pools.
/// Note: the necessary details of the request are captured prior to this call. For vllm, /// Note: the necessary details of the request are captured prior to this call. For vllm,
/// we make a create slot call prior to this call, so a slot is guaranteed to exist. /// we make a create slot call prior to this call, so a slot is guaranteed to exist.
...@@ -147,7 +207,7 @@ impl Leader for KvConnectorLeader { ...@@ -147,7 +207,7 @@ impl Leader for KvConnectorLeader {
// the number of device matched tokens should be less than or equal to the number of tokens in the request // the number of device matched tokens should be less than or equal to the number of tokens in the request
debug_assert!(num_computed_tokens % self.block_size == 0); debug_assert!(num_computed_tokens % self.block_size == 0);
let shared_slot = self.slot_manager.get_slot(&request_id)?; let shared_slot = self.slot_manager().get_slot(&request_id)?;
let mut slot = shared_slot let mut slot = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -215,7 +275,7 @@ impl Leader for KvConnectorLeader { ...@@ -215,7 +275,7 @@ impl Leader for KvConnectorLeader {
num_external_tokens num_external_tokens
); );
let shared_slot = self.slot_manager.get_slot(&request_id)?; let shared_slot = self.slot_manager().get_slot(&request_id)?;
let mut slot = shared_slot let mut slot = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -271,7 +331,7 @@ impl Leader for KvConnectorLeader { ...@@ -271,7 +331,7 @@ impl Leader for KvConnectorLeader {
// This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot // This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot
// once for onboarding (this loop), then again for prefill/decode (new_requests loop). // once for onboarding (this loop), then again for prefill/decode (new_requests loop).
for request_id in onboarding_slots.iter() { for request_id in onboarding_slots.iter() {
let shared_slot = self.slot_manager.get_slot(request_id)?; let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot let mut slot = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -300,7 +360,7 @@ impl Leader for KvConnectorLeader { ...@@ -300,7 +360,7 @@ impl Leader for KvConnectorLeader {
"request_id {request_id} not found in inflight_requests: " "request_id {request_id} not found in inflight_requests: "
); );
let shared_slot = self.slot_manager.get_slot(request_id)?; let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot let mut slot = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -343,7 +403,7 @@ impl Leader for KvConnectorLeader { ...@@ -343,7 +403,7 @@ impl Leader for KvConnectorLeader {
// we really do not know what to expect here: // we really do not know what to expect here:
// first let's try to get the slot, it might fail because maybe preemption put us thru // first let's try to get the slot, it might fail because maybe preemption put us thru
// a finished cycle -- who knows // a finished cycle -- who knows
let shared_slot = self.slot_manager.get_slot(request_id); let shared_slot = self.slot_manager().get_slot(request_id);
match &shared_slot { match &shared_slot {
Ok(_) => { Ok(_) => {
tracing::info!("after preemption, slot is still alive"); tracing::info!("after preemption, slot is still alive");
...@@ -371,7 +431,7 @@ impl Leader for KvConnectorLeader { ...@@ -371,7 +431,7 @@ impl Leader for KvConnectorLeader {
"request_id {request_id} not found in inflight_requests: " "request_id {request_id} not found in inflight_requests: "
); );
let shared_slot = self.slot_manager.get_slot(request_id)?; let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot let mut slot = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -399,7 +459,7 @@ impl Leader for KvConnectorLeader { ...@@ -399,7 +459,7 @@ impl Leader for KvConnectorLeader {
} }
for unscheduled_req in inflight_requests.iter() { for unscheduled_req in inflight_requests.iter() {
let shared_slot = self.slot_manager.get_slot(unscheduled_req)?; let shared_slot = self.slot_manager().get_slot(unscheduled_req)?;
let mut slot_guard = shared_slot let mut slot_guard = shared_slot
.lock() .lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
...@@ -424,7 +484,7 @@ impl Leader for KvConnectorLeader { ...@@ -424,7 +484,7 @@ impl Leader for KvConnectorLeader {
) -> anyhow::Result<bool> { ) -> anyhow::Result<bool> {
tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}"); tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}");
if !self.slot_manager.has_slot(&request_id) { if !self.slot_manager().has_slot(&request_id) {
tracing::warn!( tracing::warn!(
"request_finished called for request_id: {request_id} but slot is not found" "request_finished called for request_id: {request_id} but slot is not found"
); );
...@@ -433,7 +493,7 @@ impl Leader for KvConnectorLeader { ...@@ -433,7 +493,7 @@ impl Leader for KvConnectorLeader {
} }
// grab the slot // grab the slot
let shared_slot = self.slot_manager.get_slot(&request_id)?; let shared_slot = self.slot_manager().get_slot(&request_id)?;
// mark the slot as finished // mark the slot as finished
let mut slot = shared_slot let mut slot = shared_slot
...@@ -450,7 +510,7 @@ impl Leader for KvConnectorLeader { ...@@ -450,7 +510,7 @@ impl Leader for KvConnectorLeader {
self.inflight_requests.remove(&request_id); self.inflight_requests.remove(&request_id);
// remove it from the manager as we will never use it again // remove it from the manager as we will never use it again
self.slot_manager.remove_slot(&request_id)?; self.slot_manager().remove_slot(&request_id)?;
// if the slot has finished, we can return false to vllm, indicating all gpu blocks are free to be reused // if the slot has finished, we can return false to vllm, indicating all gpu blocks are free to be reused
// otherwise, we return true, which means there are still outstanding operations on gpu blocks which // otherwise, we return true, which means there are still outstanding operations on gpu blocks which
...@@ -465,13 +525,13 @@ impl Leader for KvConnectorLeader { ...@@ -465,13 +525,13 @@ impl Leader for KvConnectorLeader {
} }
fn has_slot(&self, request_id: String) -> bool { fn has_slot(&self, request_id: String) -> bool {
self.slot_manager.has_slot(&request_id) self.slot_manager().has_slot(&request_id)
} }
/// Create a new slot for the given request ID. /// Create a new slot for the given request ID.
/// This is used to create a new slot for the request. /// This is used to create a new slot for the request.
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()> { fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()> {
self.slot_manager self.slot_manager()
.create_slot(&request.request_id, tokens, request.salt_hash)?; .create_slot(&request.request_id, tokens, request.salt_hash)?;
self.inflight_requests.insert(request.request_id); self.inflight_requests.insert(request.request_id);
...@@ -488,11 +548,11 @@ pub struct PyKvConnectorLeader { ...@@ -488,11 +548,11 @@ pub struct PyKvConnectorLeader {
#[pymethods] #[pymethods]
impl PyKvConnectorLeader { impl PyKvConnectorLeader {
#[new] #[new]
#[pyo3(signature = (worker_id, drt, block_manager, leader))] #[pyo3(signature = (worker_id, drt, page_size, leader))]
pub fn new( pub fn new(
worker_id: String, worker_id: String,
drt: PyDistributedRuntime, drt: PyDistributedRuntime,
block_manager: PyBlockManager, page_size: usize,
leader: PyKvbmLeader, leader: PyKvbmLeader,
) -> Self { ) -> Self {
let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD") let enable_kvbm_record = std::env::var("ENABLE_KVBM_RECORD")
...@@ -501,18 +561,10 @@ impl PyKvConnectorLeader { ...@@ -501,18 +561,10 @@ impl PyKvConnectorLeader {
let connector_leader: Box<dyn Leader> = if enable_kvbm_record { let connector_leader: Box<dyn Leader> = if enable_kvbm_record {
Box::new(recorder::KvConnectorLeaderRecorder::new( Box::new(recorder::KvConnectorLeaderRecorder::new(
worker_id, worker_id, drt, page_size, leader,
drt,
block_manager,
leader,
)) ))
} else { } else {
Box::new(KvConnectorLeader::new( Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader))
worker_id,
drt,
block_manager,
leader,
))
}; };
Self { connector_leader } Self { connector_leader }
} }
......
...@@ -88,51 +88,35 @@ impl KvConnectorLeaderRecorder { ...@@ -88,51 +88,35 @@ impl KvConnectorLeaderRecorder {
pub fn new( pub fn new(
worker_id: String, worker_id: String,
drt: PyDistributedRuntime, drt: PyDistributedRuntime,
block_manager: PyBlockManager, page_size: usize,
leader: PyKvbmLeader, leader_py: PyKvbmLeader,
) -> Self { ) -> Self {
tracing::info!( tracing::info!(
"KvConnectorLeaderRecorder initialized with worker_id: {}", "KvConnectorLeaderRecorder initialized with worker_id: {}",
worker_id worker_id
); );
// if drt is none, then we must construct a runtime and distributed runtime let leader = leader_py.get_inner().clone();
let block_manager = block_manager.get_block_manager().clone(); let drt = drt.inner().clone();
let block_size = block_manager.block_size(); let handle: Handle = drt.runtime().primary();
let leader = leader.get_inner(); let ns = drt
.namespace(kvbm_connector::KVBM_CONNECTOR_LEADER)
.unwrap();
// if we need a drt, get it from here let kvbm_metrics = KvbmMetrics::new(&ns);
let drt = drt.inner().clone(); let kvbm_metrics_clone = kvbm_metrics.clone();
let token = CancellationToken::new(); let token = CancellationToken::new();
let output_path = "/tmp/records.jsonl"; let output_path = "/tmp/records.jsonl";
tracing::info!("recording events to {}", output_path); tracing::info!("recording events to {}", output_path);
let ns = drt.namespace("kvbm_connector_leader").unwrap();
let kvbm_metrics = KvbmMetrics::new(&ns);
let recorder = drt let recorder = drt
.runtime() .runtime()
.primary() .primary()
.block_on(async { Recorder::new(token, &output_path, None, None, None).await }) .block_on(async { Recorder::new(token, &output_path, None, None, None).await })
.unwrap(); .unwrap();
let connector_leader = KvConnectorLeader {
slot_manager: ConnectorSlotManager::new(
block_manager.clone(),
leader,
drt.clone(),
kvbm_metrics.clone(),
),
block_size,
inflight_requests: HashSet::new(),
onboarding_slots: HashSet::new(),
iteration_counter: 0,
kvbm_metrics,
};
let (unbounded_tx, unbounded_rx) = mpsc::unbounded_channel(); let (unbounded_tx, unbounded_rx) = mpsc::unbounded_channel();
let recorder_tx = recorder.event_sender(); let recorder_tx = recorder.event_sender();
...@@ -141,6 +125,73 @@ impl KvConnectorLeaderRecorder { ...@@ -141,6 +125,73 @@ impl KvConnectorLeaderRecorder {
.primary() .primary()
.spawn(Self::forward_unbounded_to_sender(unbounded_rx, recorder_tx)); .spawn(Self::forward_unbounded_to_sender(unbounded_rx, recorder_tx));
let slot_manager_cell = Arc::new(OnceLock::new());
let (leader_ready_tx, leader_ready_rx) = oneshot::channel::<String>();
{
let slot_manager_cell = slot_manager_cell.clone();
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
if !ready {
tracing::error!(
"KvConnectorLeader init aborted: leader worker barrier not ready!",
);
return;
}
let block_manager = match BlockManagerBuilder::new()
.worker_id(0)
.leader(leader_py)
.page_size(page_size)
.disable_device_pool(false)
.build()
.await
{
Ok(bm) => bm,
Err(e) => {
tracing::error!("Failed to build BlockManager: {}", e);
return;
}
};
// Create the slot manager now that everything is ready
let sm = ConnectorSlotManager::new(
block_manager.get_block_manager().clone(),
leader.clone(),
drt.clone(),
kvbm_metrics_clone.clone(),
);
let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);
if leader_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("main routine receiver dropped before result was sent");
}
});
}
tokio::task::block_in_place(|| {
handle.block_on(async {
match leader_ready_rx.await {
Ok(_) => tracing::info!("KvConnectorLeader init complete."),
Err(_) => tracing::warn!("KvConnectorLeader init channel dropped"),
}
});
});
let connector_leader = KvConnectorLeader {
slot_manager: slot_manager_cell,
block_size: page_size,
inflight_requests: HashSet::new(),
onboarding_slots: HashSet::new(),
iteration_counter: 0,
kvbm_metrics,
};
Self { Self {
_recorder: recorder, _recorder: recorder,
unbounded_tx, unbounded_tx,
...@@ -161,6 +212,10 @@ impl KvConnectorLeaderRecorder { ...@@ -161,6 +212,10 @@ impl KvConnectorLeaderRecorder {
} }
impl Leader for KvConnectorLeaderRecorder { impl Leader for KvConnectorLeaderRecorder {
#[inline]
fn slot_manager(&self) -> &ConnectorSlotManager<String> {
self.connector_leader.slot_manager()
}
/// Match the tokens in the request with the available block pools. /// Match the tokens in the request with the available block pools.
/// Note: the necessary details of the request are captured prior to this call. For vllm, /// Note: the necessary details of the request are captured prior to this call. For vllm,
/// we make a create slot call prior to this call, so a slot is guaranteed to exist. /// we make a create slot call prior to this call, so a slot is guaranteed to exist.
......
...@@ -106,6 +106,17 @@ pub trait Slot: std::fmt::Debug { ...@@ -106,6 +106,17 @@ pub trait Slot: std::fmt::Debug {
num_scheduled_tokens: usize, num_scheduled_tokens: usize,
) -> Result<(), SlotError>; ) -> Result<(), SlotError>;
// TRT-LLM does not include scheduled tokens in the scheduler output.
// Ideally, we should have a dedicated implementation for the TRT-LLM slot.
// However, since only this single function needs to be rewritten for now,
// we keep it as a separate function in Slot.
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
) -> Result<(), SlotError>;
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>; fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError>;
fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>; fn mark_as_prefilling(&mut self, iteration: u64) -> Result<(), SlotError>;
...@@ -228,6 +239,11 @@ impl<R: RequestKey> SlotManager<R> for ConnectorSlotManager<R> { ...@@ -228,6 +239,11 @@ impl<R: RequestKey> SlotManager<R> for ConnectorSlotManager<R> {
tokens: Vec<u32>, tokens: Vec<u32>,
salt_hash: SaltHash, salt_hash: SaltHash,
) -> Result<(), SlotError> { ) -> Result<(), SlotError> {
tracing::debug!(
"creating slot with request_id: {}, num_tokens: {}",
request_id,
tokens.len()
);
let slot = VllmConnectorSlot::new( let slot = VllmConnectorSlot::new(
request_id.to_string(), request_id.to_string(),
tokens.into(), tokens.into(),
...@@ -566,6 +582,98 @@ impl Slot for VllmConnectorSlot { ...@@ -566,6 +582,98 @@ impl Slot for VllmConnectorSlot {
Ok(()) Ok(())
} }
#[tracing::instrument(level = "debug", skip_all, fields(request_id = self.request_id.as_str()))]
fn apply_scheduler_output_with_computed_position(
&mut self,
tokens: &[u32],
block_ids: &[usize],
computed_position: usize,
) -> Result<(), SlotError> {
// TRTLLM's KV Connector Manager will have (computed_position - external matches)
// in onborading case
if computed_position < self.current_position {
tracing::debug!(
"computed_position={} < current_position={}, so we are onboarding during prefilling phase",
computed_position, self.current_position
);
return Ok(());
}
// now we decide what we should do for the new computed tokens
tracing::debug!(
"applying scheduler output, computed_position={}, sequence_total_tokens={}",
computed_position,
self.sequence.total_tokens()
);
if computed_position < self.sequence.total_tokens() {
// no need to apply new tokens, since it's applied when created the slot during prefilling
self.state = SlotState::Prefilling;
} else {
tracing::debug!(
"appending {} newly decoded tokens to sequence",
tokens.len()
);
self.sequence.extend(tokens.into()).unwrap();
self.state = SlotState::Decoding;
}
// apply new block_ids, this should be applied for both prefilling and decoding
// because this is unknown when creating the slot
if !block_ids.is_empty() {
tracing::debug!("assigning {} new device blocks slot", block_ids.len());
self.device_blocks.extend(block_ids);
}
let num_candidate_blocks =
((computed_position + 1) / self.block_size) - self.evaluated_blocks;
if num_candidate_blocks != 0 {
// do we have a mechanism for skipping gpu cache hit blocks? not sure yet.
// for now, offload all the blocks to the host
let offload_block_ids: Vec<usize> = self
.device_blocks
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.copied()
.collect::<Vec<_>>();
assert_eq!(
offload_block_ids.len(),
num_candidate_blocks,
"device block overflow - candidate blocks exceed block count at offset {}",
self.evaluated_blocks
);
let offload_token_blocks: Vec<TokenBlock> = self
.sequence
.blocks()
.iter()
.skip(self.evaluated_blocks)
.take(num_candidate_blocks)
.cloned()
.collect::<Vec<_>>();
self.offload_blocks(&offload_block_ids, &offload_token_blocks)
.expect("failed to offload blocks");
self.evaluated_blocks += num_candidate_blocks;
}
// done applying policy
tracing::debug!(
"done applying kv cache policy at current_position: {}; computed_position: {}",
self.current_position,
computed_position,
);
// advance current position to computed position
self.current_position = computed_position;
Ok(())
}
fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> { fn record_start_iteration(&mut self, iteration: u64) -> Result<(), SlotError> {
if self.iteration_first_scheduled.is_none() { if self.iteration_first_scheduled.is_none() {
self.iteration_first_scheduled = Some(iteration); self.iteration_first_scheduled = Some(iteration);
...@@ -676,7 +784,7 @@ impl Slot for VllmConnectorSlot { ...@@ -676,7 +784,7 @@ impl Slot for VllmConnectorSlot {
let num_matched_blocks = num_matched_host_blocks + num_matched_disk_blocks; let num_matched_blocks = num_matched_host_blocks + num_matched_disk_blocks;
tracing::debug!( tracing::debug!(
"matched {} host blocks and {} disk blocks; {} total blocks", "successfully matched {} host blocks and {} disk blocks; {} total blocks",
num_matched_host_blocks, num_matched_host_blocks,
num_matched_disk_blocks, num_matched_disk_blocks,
num_matched_blocks num_matched_blocks
...@@ -925,7 +1033,7 @@ impl VllmConnectorSlot { ...@@ -925,7 +1033,7 @@ impl VllmConnectorSlot {
tracing::debug!( tracing::debug!(
request_id = self.request_id, request_id = self.request_id,
operation_id = %operation_id, operation_id = %operation_id,
"onboarding {} blocks from {:?} to device", "start onboarding {} blocks from {:?} to device",
num_blocks, num_blocks,
src_storage_pool, src_storage_pool,
); );
...@@ -1227,10 +1335,12 @@ async fn process_offload_request( ...@@ -1227,10 +1335,12 @@ async fn process_offload_request(
// 4. Wait for the offload request to complete // 4. Wait for the offload request to complete
match notify_receiver.await { match notify_receiver.await {
Ok(_) => { Ok(_) => {
tracing::debug!("Transfer completed successfully"); tracing::debug!("Offloading transfer completed successfully");
} }
Err(_) => { Err(_) => {
return Err(anyhow::anyhow!("Transfer completion notification failed")); return Err(anyhow::anyhow!(
"Offloading transfer completion notification failed"
));
} }
} }
tracing::debug!( tracing::debug!(
...@@ -1301,10 +1411,12 @@ async fn process_onboard_request( ...@@ -1301,10 +1411,12 @@ async fn process_onboard_request(
match notify_receiver.await { match notify_receiver.await {
Ok(_) => { Ok(_) => {
tracing::debug!("Transfer completed successfully"); tracing::debug!("Onboarding transfer completed successfully");
} }
Err(_) => { Err(_) => {
return Err(anyhow::anyhow!("Transfer completion notification failed")); return Err(anyhow::anyhow!(
"Onboarding transfer completion notification failed"
));
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use crate::llm::block_manager::vllm::connector::leader::slot::{
ConnectorSlotManager, SlotManager, SlotState,
};
use crate::llm::block_manager::BlockManagerBuilder;
use crate::llm::block_manager::{distributed::KvbmLeader as PyKvbmLeader, vllm::KvbmRequest};
use crate::DistributedRuntime as PyDistributedRuntime;
use anyhow;
use dynamo_llm::block_manager::metrics_kvbm::KvbmMetrics;
use dynamo_runtime::metrics::prometheus_names::kvbm_connector;
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};
use tokio::runtime::Handle;
pub trait Leader: Send + Sync + std::fmt::Debug {
fn get_num_new_matched_tokens(
&mut self,
request_id: String,
request_num_tokens: usize,
num_computed_tokens: usize,
) -> anyhow::Result<(usize, bool)>;
fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
context_current_position: usize,
) -> anyhow::Result<()>;
fn build_connector_metadata(
&mut self,
scheduler_output: SchedulerOutput,
) -> anyhow::Result<Vec<u8>>;
fn request_finished(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
) -> anyhow::Result<bool>;
fn has_slot(&self, request_id: String) -> bool;
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()>;
fn slot_manager(&self) -> &ConnectorSlotManager<String>;
}
#[derive(Debug)]
pub struct KvConnectorLeader {
slot_manager: Arc<OnceLock<ConnectorSlotManager<String>>>,
block_size: usize,
inflight_requests: HashSet<String>,
onboarding_slots: HashSet<String>,
iteration_counter: u64,
inflight_request_to_num_external_tokens: HashMap<String, usize>,
kvbm_metrics: KvbmMetrics,
}
impl KvConnectorLeader {
fn new(
worker_id: u64,
drt: PyDistributedRuntime,
page_size: usize,
leader_py: PyKvbmLeader,
) -> Self {
tracing::info!(
"KvConnectorLeader initialized with worker_id: {}",
worker_id
);
let leader = leader_py.get_inner().clone();
let drt = drt.inner().clone();
let handle: Handle = drt.runtime().primary();
let ns = drt
.namespace(kvbm_connector::KVBM_CONNECTOR_LEADER)
.unwrap();
let kvbm_metrics = KvbmMetrics::new(&ns);
let kvbm_metrics_clone = kvbm_metrics.clone();
let slot_manager_cell = Arc::new(OnceLock::new());
{
let slot_manager_cell = slot_manager_cell.clone();
handle.spawn(async move {
let ready = leader.wait_worker_sync_ready().await;
if !ready {
tracing::error!(
"KvConnectorLeader init aborted: leader worker barrier not ready!",
);
return;
}
let block_manager = match BlockManagerBuilder::new()
.worker_id(0)
.leader(leader_py)
.page_size(page_size)
.disable_device_pool(false)
.build()
.await
{
Ok(bm) => bm,
Err(e) => {
tracing::error!("Failed to build BlockManager: {}", e);
return;
}
};
// Create the slot manager now that everything is ready
let sm = ConnectorSlotManager::new(
block_manager.get_block_manager().clone(),
leader.clone(),
drt.clone(),
kvbm_metrics_clone.clone(),
);
let _ = slot_manager_cell.set(sm);
// another barrier sync to make sure worker init won't return before leader is ready
leader.spawn_leader_readiness_barrier(drt);
tracing::info!("KvConnectorLeader init complete.");
});
}
Self {
slot_manager: slot_manager_cell,
block_size: page_size,
inflight_requests: HashSet::new(),
onboarding_slots: HashSet::new(),
iteration_counter: 0,
inflight_request_to_num_external_tokens: HashMap::new(),
kvbm_metrics,
}
}
}
impl Leader for KvConnectorLeader {
#[inline]
fn slot_manager(&self) -> &ConnectorSlotManager<String> {
self.slot_manager
.get()
.expect("slot_manager not initialized")
}
/// Match the tokens in the request with the available block pools.
/// Note: the necessary details of the request are captured prior to this call. For trtllm,
/// we make a create slot call prior to this call, so a slot is guaranteed to exist.
///
/// To align with the connector interface, we must ensure that if no blocks are matched, we return (0, false).
/// In our implementation, if we match any block, we return (num_matched_tokens, true).
#[tracing::instrument(level = "debug", skip(self, request_num_tokens, num_computed_tokens))]
fn get_num_new_matched_tokens(
&mut self,
request_id: String,
request_num_tokens: usize,
num_computed_tokens: usize,
) -> anyhow::Result<(usize, bool)> {
tracing::debug!(
"request_num_tokens: {request_num_tokens}; num_computed_tokens: {num_computed_tokens}"
);
// TRTLLM could match partial blocks if enable_partial_reuse = True,
// immediately return 0 to simplify things.
if num_computed_tokens % self.block_size != 0 {
return Ok((0, false));
}
let shared_slot = self.slot_manager().get_slot(&request_id)?;
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
// early exit if we cannot match full block
if (slot.sequence().total_tokens() - num_computed_tokens) < self.block_size {
let total_tokens = slot.sequence().total_tokens();
tracing::debug!(
"total_tokens in sequence: {total_tokens}; num_computed_tokens: {num_computed_tokens}; can not match full block."
);
return Ok((0, false));
}
// find matches for any remaining tokens
// this will advance the computed position and hold any newly matched blocks in the slot
slot.acquire_local_matches(num_computed_tokens)?;
// return the number of external tokens that are ready for onboarding
// we always return true here as we always asynchronously onboard matched blocks
if let SlotState::OnboardStaged(num_external_tokens) = slot.state() {
debug_assert!((num_computed_tokens + num_external_tokens) % self.block_size == 0);
tracing::debug!(
request_id = request_id,
"scheduling onboarding for {} external tokens",
num_external_tokens
);
// Add to the map so that onboarding can be triggered in update_state_after_alloc.
self.inflight_request_to_num_external_tokens
.insert(request_id, num_external_tokens);
self.kvbm_metrics
.matched_tokens
.inc_by(num_external_tokens as u64);
Ok((num_external_tokens, true))
} else {
Ok((0, false))
}
}
/// Note: TRTLLM will not provide any scheduler output data for requests that are onboarding. it is entirely
/// on the connector's implementation to handle this case.
#[tracing::instrument(level = "debug", skip_all, fields(request_id))]
fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
context_current_position: usize,
) -> anyhow::Result<()> {
tracing::debug!(
request_id,
"num_device_blocks: {}, context_current_position: {}",
block_ids.len(),
context_current_position
);
let shared_slot = self.slot_manager().get_slot(&request_id)?;
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
// we have not yet advanced the computed position, but now we can, since we have an indication that we have
// necessary gpu blocks into which we will load the external tokens.
slot.append_mutable_device_blocks(&block_ids)?;
if let Some(&num_external_tokens) = self
.inflight_request_to_num_external_tokens
.get(&request_id)
{
if num_external_tokens > 0 {
let num_computed_tokens = context_current_position - num_external_tokens;
slot.record_cached_device_tokens(num_computed_tokens);
slot.advance_computed_position(num_computed_tokens)?;
tracing::debug!(
request_id = request_id,
"triggering onboarding for {} external tokens",
num_external_tokens
);
slot.trigger_onboarding(num_external_tokens)?;
self.onboarding_slots.insert(request_id.clone());
}
self.inflight_request_to_num_external_tokens
.remove(&request_id);
}
Ok(())
}
#[tracing::instrument(level = "debug", skip_all, fields(iteration = self.iteration_counter + 1))]
fn build_connector_metadata(
&mut self,
scheduler_output: SchedulerOutput,
) -> anyhow::Result<Vec<u8>> {
// the iteration counter is used to track the number of times we have built the connector metadata
// all connetor operations have the iteration counter at which they were issued.
// this allows operations to be lazily enqueued to the transfer engine
// the worker side of the connector will track all operations for completion before the request is
// allowed to be marked as finished.
self.iteration_counter += 1;
let iteration = self.iteration_counter;
tracing::debug!("Building connector metadata");
tracing::debug!("SchedulerOutput: {scheduler_output:#?}");
let mut inflight_requests = self.inflight_requests.clone();
let mut md = ConnectorMetadata::new(iteration);
let onboarding_slots = std::mem::take(&mut self.onboarding_slots);
// Worker-side - we create a request slot for onboarding, then delete it when onboarding is finished, then
// recreate it again when we start the prefill/decode phase.
//
// This is kind of a nice abstraction as it keeps the events simplier; however, we now create the request-slot
// once for onboarding (this loop), then again for prefill/decode (new_requests loop).
for request_id in onboarding_slots.iter() {
let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
md.create_slot(request_id.clone());
if let Some(pending_ops) = slot.take_pending_operations() {
tracing::debug!("adding {} pending onboarding operations", pending_ops.len());
md.add_operations(pending_ops);
}
}
// todo: update the code and abstraction to account for this two-phase lifecycle.
for new_req in &scheduler_output.new_requests {
let request_id = &new_req.request_id;
assert!(
inflight_requests.remove(request_id),
"request_id {request_id} not found in inflight_requests: "
);
let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
// inform the worker that a new request-slot should be created
md.create_slot(new_req.request_id.clone());
slot.record_start_iteration(iteration)?;
debug_assert!(
matches!(
slot.state(),
SlotState::Initialized | SlotState::Onboarding(_)
),
"current slot state: {:?}",
slot.state()
);
slot.apply_scheduler_output_with_computed_position(
&new_req.prompt_token_ids,
&new_req.block_ids,
new_req.num_computed_tokens,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
tracing::debug!(
"adding {} pending operations for slot {}",
pending_ops.len(),
new_req.request_id
);
md.add_operations(pending_ops);
}
}
for cached_req in &scheduler_output.cached_requests {
let request_id = &cached_req.request_id;
// note: evicition might trigger this assert
assert!(
inflight_requests.remove(request_id),
"request_id {request_id} not found in inflight_requests: "
);
let shared_slot = self.slot_manager().get_slot(request_id)?;
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
slot.apply_scheduler_output_with_computed_position(
&cached_req.new_token_ids,
&cached_req.new_block_ids,
cached_req.num_computed_tokens,
)?;
if let Some(pending_ops) = slot.take_pending_operations() {
tracing::debug!(
"adding {} pending operations for slot {}",
pending_ops.len(),
request_id
);
md.add_operations(pending_ops);
}
}
tracing::debug!("metadata: {md:#?}");
serde_json::to_vec(&md)
.map_err(|e| anyhow::anyhow!("Failed to serialize connector metadata: {}", e))
}
fn request_finished(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
) -> anyhow::Result<bool> {
tracing::debug!("Request finished: {request_id}; block_ids: {block_ids:?}");
// grab the slot
let shared_slot = self.slot_manager().get_slot(&request_id)?;
// mark the slot as finished
let mut slot = shared_slot
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock slot: {}", e))?;
slot.mark_as_finished(self.iteration_counter)?;
// todo: allow the request to resolve when it should exit
// the request may have some outstanding operations
// we would like to inform it to shutdown, then have it signal to the work that is officially gone,
// then we can remove the slot and trigger the worker to clean up as well.
// remove it from the manager as we will never use it again
self.slot_manager().remove_slot(&request_id)?;
self.inflight_request_to_num_external_tokens
.remove(&request_id);
// if the slot has finished, we can return false to trtllm, indicating all gpu blocks are free to be reused
// otherwise, we return true, which means there are still outstanding operations on gpu blocks which
// must be awaited before the gpu blocks can be reused. if we return true, then it is the worker side
// of the connector api which will be used to inform trtllm that the request is finished.
if let SlotState::Finished = slot.state() {
Ok(false)
} else {
debug_assert!(matches!(slot.state(), SlotState::Finishing));
Ok(true)
}
}
fn has_slot(&self, request_id: String) -> bool {
self.slot_manager().has_slot(&request_id)
}
/// Create a new slot for the given request ID.
/// This is used to create a new slot for the request.
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> anyhow::Result<()> {
self.slot_manager()
.create_slot(&request.request_id, tokens, request.salt_hash)?;
self.inflight_requests.insert(request.request_id);
Ok(())
}
}
#[pyclass]
pub struct PyTrtllmKvConnectorLeader {
connector_leader: Box<dyn Leader>,
}
#[pymethods]
impl PyTrtllmKvConnectorLeader {
#[new]
#[pyo3(signature = (worker_id, drt, page_size, leader))]
pub fn new(
worker_id: u64,
drt: PyDistributedRuntime,
page_size: usize,
leader: PyKvbmLeader,
) -> Self {
let connector_leader: Box<dyn Leader> =
Box::new(KvConnectorLeader::new(worker_id, drt, page_size, leader));
Self { connector_leader }
}
fn get_num_new_matched_tokens(
&mut self,
request_id: String,
request_num_tokens: usize,
num_computed_tokens: usize,
) -> PyResult<(usize, bool)> {
self.connector_leader
.get_num_new_matched_tokens(request_id, request_num_tokens, num_computed_tokens)
.map_err(to_pyerr)
}
fn update_state_after_alloc(
&mut self,
request_id: String,
block_ids: Vec<BlockId>,
context_current_position: usize,
) -> PyResult<()> {
self.connector_leader
.update_state_after_alloc(request_id, block_ids, context_current_position)
.map_err(to_pyerr)
}
fn build_connector_metadata(&mut self, scheduler_output: SchedulerOutput) -> PyResult<Vec<u8>> {
self.connector_leader
.build_connector_metadata(scheduler_output)
.map_err(to_pyerr)
}
fn request_finished(&mut self, request_id: &str, block_ids: Vec<BlockId>) -> PyResult<bool> {
self.connector_leader
.request_finished(request_id.to_string(), block_ids)
.map_err(to_pyerr)
}
fn has_slot(&self, request_id: &str) -> bool {
self.connector_leader.has_slot(request_id.to_string())
}
fn create_slot(&mut self, request: KvbmRequest, tokens: Vec<u32>) -> PyResult<()> {
self.connector_leader
.create_slot(request, tokens)
.map_err(to_pyerr)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::block_manager::connector::protocol::TransferType;
use dynamo_llm::block_manager::connector::scheduler::{
Scheduler, TransferSchedulerClient, WorkerSchedulerClient,
};
use std::collections::HashSet;
use std::sync::{Arc, OnceLock};
use super::*;
use crate::llm::block_manager::distributed::get_barrier_id_prefix;
use crate::llm::block_manager::vllm::connector::worker::event_sync_blocking;
use crate::{
llm::block_manager::distributed::VllmTensor, to_pyerr,
DistributedRuntime as PyDistributedRuntime,
};
use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
use dynamo_runtime::DistributedRuntime;
pub trait Worker: Send + Sync {
fn register_kv_caches(
&mut self,
num_device_blocks: usize,
page_size: usize,
device_id: usize,
dtype_width_bytes: usize,
kv_cache_tensor: Arc<VllmTensor>,
raw_event_handles: Vec<u64>,
) -> anyhow::Result<()>;
fn bind_connector_meta(&mut self, metadata: Vec<u8>) -> anyhow::Result<()>;
fn start_load_kv(&mut self) -> anyhow::Result<()>;
fn save_kv_layer(&mut self, layer_idx: usize) -> anyhow::Result<()>;
fn get_finished(
&mut self,
finished_gen_req_ids: Vec<u64>,
started_loading_req_ids: Vec<u64>,
) -> (Vec<u64>, Vec<u64>);
}
pub struct KvConnectorWorker {
drt: DistributedRuntime,
kvbm_worker: OnceLock<KvbmWorker>,
connector: WorkerSchedulerClient,
transfer_client: TransferSchedulerClient,
/// Map of request id to inflight load requests
maybe_finished_onboarding: HashSet<String>,
/// Map of request id to inflight finished requests
maybe_finished_offloading: HashSet<String>,
onboarding_operations: Vec<WorkerTransferRequest>,
offloading_operations: Vec<WorkerTransferRequest>,
bound: bool,
iteration: u64,
layers_complete: usize,
/// cuda events created by the python side
layer_events: Vec<u64>,
}
impl KvConnectorWorker {
fn new(py_drt: PyDistributedRuntime, trtllm_rank: String) -> anyhow::Result<Self> {
let drt = py_drt.inner.clone();
let runtime = drt.runtime().primary();
let (scheduler, worker_client, transfer_client) = Scheduler::new(drt.primary_token());
CriticalTaskExecutionHandle::new_with_runtime(
move |_| {
let mut scheduler = scheduler;
async move { scheduler.run().await }
},
drt.primary_token(),
"kv-connector-scheduler-task",
&runtime,
)?
.detach();
tracing::info!(
"KvConnectorWorker initialized with worker_rank: {}",
trtllm_rank
);
Ok(Self {
drt,
kvbm_worker: OnceLock::new(),
connector: worker_client,
transfer_client,
maybe_finished_onboarding: HashSet::new(),
maybe_finished_offloading: HashSet::new(),
onboarding_operations: Vec::new(),
offloading_operations: Vec::new(),
bound: false,
iteration: 0,
layers_complete: 0,
layer_events: Vec::new(),
})
}
}
impl Worker for KvConnectorWorker {
fn register_kv_caches(
&mut self,
num_device_blocks: usize,
page_size: usize,
device_id: usize,
dtype_width_bytes: usize,
kv_cache_tensor: Arc<VllmTensor>,
raw_event_handles: Vec<u64>,
) -> anyhow::Result<()> {
if self.kvbm_worker.get().is_some() {
tracing::warn!("kvbm worker already registered");
return Err(anyhow::anyhow!("kvbm worker already registered"));
}
let kv_cache_tensors = vec![kv_cache_tensor as Arc<dyn TorchTensor>];
let config = KvbmWorkerConfig::builder()
.drt(self.drt.clone())
.num_device_blocks(num_device_blocks)
.page_size(page_size)
.tensors(kv_cache_tensors)
.device_id(device_id)
.dtype_width_bytes(dtype_width_bytes)
.is_fully_contiguous_layout(true)
.barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone()))
.build()?;
self.layer_events = raw_event_handles;
let worker = self.drt.runtime().primary().block_on(async move {
let worker = KvbmWorker::new(config, true).await?;
anyhow::Ok(worker)
})?;
self.kvbm_worker
.set(worker)
.map_err(|_| anyhow::anyhow!("failed to set kvbm worker"))?;
Ok(())
}
fn bind_connector_meta(&mut self, metadata: Vec<u8>) -> anyhow::Result<()> {
let metadata: ConnectorMetadata = serde_json::from_slice(&metadata)?;
self.bound = true;
self.iteration = metadata.iteration;
self.layers_complete = 0;
tracing::debug!(
iteration = self.iteration,
"bound new metadata: {metadata:#?}"
);
self.connector.start_next_iteration()?;
debug_assert_eq!(
self.connector.iteration(),
metadata.iteration,
"iteration mismatch"
);
// local actions
// - create a request slot for each new request
// - for each action in the metadata, add the action to the request slot
// - send the list of actions to the engine to track completion
for slot in metadata.new_slots {
debug_assert!(!self.connector.has_slot(&slot), "slot already exists");
self.connector.create_slot(slot)?;
}
let mut onboarding_operations = Vec::new();
let mut offloading_operations = Vec::new();
for operation in metadata.operations {
tracing::debug!(
request_id = operation.request_id, operation_id = %operation.uuid,
"adding operation to slot: {operation:#?}"
);
match operation.transfer_type {
TransferType::Load => onboarding_operations.push(operation),
TransferType::Store => offloading_operations.push(operation),
}
}
debug_assert!(
self.onboarding_operations.is_empty(),
"onboarding operations should be empty"
);
self.onboarding_operations = onboarding_operations;
debug_assert!(
self.offloading_operations.is_empty(),
"offloading operations should be empty"
);
self.offloading_operations = offloading_operations;
Ok(())
}
fn save_kv_layer(&mut self, _layer_idx: usize) -> anyhow::Result<()> {
self.layers_complete += 1;
if self.layers_complete == self.layer_events.len() {
let offloading_operations = std::mem::take(&mut self.offloading_operations);
// block on the the completion of the last layer
// todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
// or put the event on a stream and use stream waits to keep it all on device.
event_sync_blocking(self.layer_events[self.layers_complete - 1]);
for operation in offloading_operations {
self.connector.enqueue_request(operation);
}
}
Ok(())
}
fn start_load_kv(&mut self) -> anyhow::Result<()> {
let onboarding_operations = self.onboarding_operations.clone();
for operation in onboarding_operations {
let request_id = operation.request_id.clone();
self.connector.enqueue_request(operation);
self.maybe_finished_onboarding.insert(request_id);
}
Ok(())
}
fn get_finished(
&mut self,
finished_gen_req_ids: Vec<u64>,
started_loading_req_ids: Vec<u64>,
) -> (Vec<u64>, Vec<u64>) {
// we do not have to visit every slot on every pass, just slots we are waiting on
//
// there are two conditions where we would be waiting:
// 1. if we have requested a load, we need to wait for it to complete
// - the load request would come in via the metadata this is processsed in the bind
// 2. if we have requested a finished event, then we need to await for all outstanding
// operations to complete -- either by finishing or being cancelled
// - the finish request is triggered by this function, it is not seen in the metadata
//
// under each scenario, we mark the `maybe_finished_onboarding` and `maybe_finished_offloading` hashsets with
// the request id
//
// on each forward pass we visit the maybe slots to see if they are finished
let mut is_finished_offloading = HashSet::new();
let mut is_finished_onboarding = HashSet::new();
// before we process the maybes, add any newly annotated finished requests
// to the maybe finished set
for request_id in finished_gen_req_ids {
tracing::debug!(request_id, "marking request as finished");
if !self.connector.has_slot(&request_id.to_string()) {
tracing::warn!(
request_id,
"finished request received for unknown request_id; assuming never started"
);
continue;
}
if self
.maybe_finished_offloading
.contains(&request_id.to_string())
{
tracing::warn!(request_id, "possibly got a duplicate finished request; request_id already in the maybe_finished_offloading set");
} else {
tracing::debug!(
request_id,
"received finished request; adding to maybe_finished_offloading set"
);
self.maybe_finished_offloading
.insert(request_id.to_string());
}
}
for request_id in started_loading_req_ids {
tracing::debug!(request_id, "marking request as finished");
if !self.connector.has_slot(&request_id.to_string()) {
tracing::warn!(
request_id,
"finished request received for unknown request_id; assuming never started"
);
continue;
}
if self
.maybe_finished_onboarding
.contains(&request_id.to_string())
{
tracing::warn!(request_id, "possibly got a duplicate finished request; request_id already in the maybe_finished_onboarding set");
}
}
// visit each request slot in the maybe finished set
for request_id in self.maybe_finished_offloading.iter() {
if self.connector.has_slot(request_id) {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished offloading");
is_finished_offloading.insert(request_id.to_string());
} else {
tracing::debug!(request_id, "request slot is not finished offloading");
}
} else {
// made this condition more strict slot existence checks were added as a prerequesite
// to be added to the maybe_finished_offloading set.
panic!("request slot missing for {request_id}; however, it was present when added to the maybe finished offloading set");
}
}
// remove the finished requests from the maybe finished set
// note: when storing is finished we also remove the request from the engine state
for request_id in &is_finished_offloading {
self.maybe_finished_offloading.remove(request_id);
// currently chomping the error as the engine is closed and we are shutting down
if self.connector.has_slot(request_id) {
self.connector.remove_slot(request_id);
} else {
tracing::debug!(request_id, "is_finished_offloading: request slot is not found - likely aborted, removing from is finished offloading set");
}
}
// visit each request slot in the maybe finished set to see if it is finished
for request_id in self.maybe_finished_onboarding.iter() {
if self.connector.has_slot(request_id) {
if self.connector.is_complete(request_id) {
tracing::debug!(request_id, "request slot is finished onboarding");
is_finished_onboarding.insert(request_id.clone());
} else {
tracing::debug!(request_id, "request slot is not finished onboarding");
}
} else {
panic!("request slot missing for {request_id}; however, it was present when added to the maybe finished onboarding set");
}
}
// remove the finished requests from the maybe finished set
for request_id in &is_finished_onboarding {
self.maybe_finished_onboarding.remove(request_id);
if self.connector.has_slot(request_id) {
self.connector.remove_slot(request_id);
}
}
let finished_offloading: Vec<u64> = is_finished_offloading
.iter()
.filter_map(|s| s.parse::<u64>().ok()) // parse String -> u64
.collect();
let finished_onboarding: Vec<u64> = is_finished_onboarding
.iter()
.filter_map(|s| s.parse::<u64>().ok()) // parse String -> u64
.collect();
(finished_offloading, finished_onboarding)
}
}
#[pyclass]
pub struct PyTrtllmKvConnectorWorker {
connector_worker: Box<dyn Worker>,
}
#[pymethods]
impl PyTrtllmKvConnectorWorker {
#[new]
#[pyo3(signature = (py_drt, trtllm_rank))]
pub fn new(py_drt: PyDistributedRuntime, trtllm_rank: String) -> PyResult<Self> {
let connector_worker: Box<dyn Worker> =
Box::new(KvConnectorWorker::new(py_drt, trtllm_rank).map_err(to_pyerr)?);
Ok(Self { connector_worker })
}
pub fn register_kv_caches(
&mut self,
num_device_blocks: usize,
page_size: usize,
device_id: usize,
dtype_width_bytes: usize,
kv_cache_tensor: Py<PyAny>,
raw_event_handles: Vec<u64>,
) -> PyResult<()> {
// Convert Python tensor to Rust VllmTensor objects
let rust_kv_cache_tensor = Arc::new(VllmTensor::new(kv_cache_tensor).map_err(to_pyerr)?);
self.connector_worker
.register_kv_caches(
num_device_blocks,
page_size,
device_id,
dtype_width_bytes,
rust_kv_cache_tensor,
raw_event_handles,
)
.map_err(to_pyerr)
}
pub fn bind_connector_meta(&mut self, metadata: Vec<u8>) -> PyResult<()> {
self.connector_worker
.bind_connector_meta(metadata)
.map_err(to_pyerr)
}
pub fn save_kv_layer(&mut self, layer_idx: usize) -> PyResult<()> {
self.connector_worker
.save_kv_layer(layer_idx)
.map_err(to_pyerr)
}
pub fn start_load_kv(&mut self) -> PyResult<()> {
self.connector_worker.start_load_kv().map_err(to_pyerr)
}
pub fn get_finished(
&mut self,
finished_gen_req_ids: Vec<u64>,
started_loading_req_ids: Vec<u64>,
) -> (Vec<u64>, Vec<u64>) {
self.connector_worker
.get_finished(finished_gen_req_ids, started_loading_req_ids)
}
}
...@@ -11,7 +11,7 @@ use std::collections::HashSet; ...@@ -11,7 +11,7 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use super::*; use super::*;
use crate::llm::block_manager::distributed::get_barrier_id; use crate::llm::block_manager::distributed::get_barrier_id_prefix;
use crate::{ use crate::{
llm::block_manager::distributed::VllmTensor, to_pyerr, llm::block_manager::distributed::VllmTensor, to_pyerr,
DistributedRuntime as PyDistributedRuntime, DistributedRuntime as PyDistributedRuntime,
...@@ -166,12 +166,12 @@ impl Worker for KvConnectorWorker { ...@@ -166,12 +166,12 @@ impl Worker for KvConnectorWorker {
.tensors(vllm_tensors) .tensors(vllm_tensors)
.device_id(device_id) .device_id(device_id)
.dtype_width_bytes(dtype_width_bytes) .dtype_width_bytes(dtype_width_bytes)
.barrier_id(get_barrier_id()) .barrier_id_prefix(get_barrier_id_prefix())
.scheduler_client(Some(self.transfer_client.clone())) .scheduler_client(Some(self.transfer_client.clone()))
.build()?; .build()?;
let worker = self.drt.runtime().primary().block_on(async move { let worker = self.drt.runtime().primary().block_on(async move {
let worker = KvbmWorker::new(config).await?; let worker = KvbmWorker::new(config, false).await?;
anyhow::Ok(worker) anyhow::Ok(worker)
})?; })?;
...@@ -477,7 +477,7 @@ fn _get_current_context() -> CUcontext { ...@@ -477,7 +477,7 @@ fn _get_current_context() -> CUcontext {
ctx ctx
} }
fn event_sync_blocking(event: u64) { pub fn event_sync_blocking(event: u64) {
let status = unsafe { cuEventSynchronize(event as CUevent) }; let status = unsafe { cuEventSynchronize(event as CUevent) };
assert_eq!( assert_eq!(
status, status,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .kvbm_connector_leader import DynamoKVBMConnectorLeader
from .kvbm_connector_worker import DynamoKVBMConnectorWorker
__all__ = ["DynamoKVBMConnectorLeader", "DynamoKVBMConnectorWorker"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
KvCacheConnectorScheduler,
SchedulerOutput,
)
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from dynamo.llm import KvbmLeader
from dynamo.llm.trtllm_integration.rust import KvbmRequest
from dynamo.llm.trtllm_integration.rust import (
KvConnectorLeader as RustKvConnectorLeader,
)
from dynamo.llm.trtllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from dynamo.runtime import DistributedRuntime
class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
def __init__(self, executor_config: ExecutorConfig):
super().__init__(executor_config)
self.drt = DistributedRuntime.detached()
world_size = self._config.mapping.world_size
self.block_size = self._config.tokens_per_block
# Set bytes_per_block to 0, because we will retrieve the actual value from the worker side.
leader = KvbmLeader(world_size, drt=self.drt)
print(
f"KvConnectorLeader initialized with rank: {executor_config.mapping.rank}"
)
self._connector = RustKvConnectorLeader(
executor_config.mapping.rank, self.drt, self.block_size, leader
)
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes:
"""
Build the metadata for the worker.
This is called by the KV Cache Manager when adding a sequence.
Args:
scheduler_output: The data for all inflight requests.
Returns:
The metadata for the workers.
"""
output = RustSchedulerOutput()
for req in scheduler_output.new_requests:
output.add_new_request(
str(req.request_id),
req.new_tokens,
req.new_block_ids,
req.computed_position,
)
resumed_from_preemption = False
for req in scheduler_output.cached_requests:
output.add_cached_request(
str(req.request_id),
resumed_from_preemption,
req.new_tokens,
req.new_block_ids,
req.computed_position,
)
return self._connector.build_connector_metadata(output)
def get_num_new_matched_tokens(
self, request: LlmRequest, num_computed_tokens: int
) -> tuple[int, bool]:
"""
Get the number of tokens that can be loaded from remote KV cache.
This does not include the tokens already matched on device (indicated by `num_computed_tokens`).
Args:
request: The request to get the number of tokens for.
num_computed_tokens: The number of tokens already matched on device.
Returns:
The number of tokens that can be loaded from remote KV cache.
Whether the tokens will be loaded asynchronously.
"""
self._create_slot(request)
return self._connector.get_num_new_matched_tokens(
str(request.request_id),
len(request.get_tokens(0)),
num_computed_tokens,
)
def update_state_after_alloc(self, request: LlmRequest, block_ids: List[int]):
"""
Called after get_num_new_matched_tokens is called to provide the block ids to the scheduler.
Args:
request: The request that was allocated resources.
block_ids: The KV cacheblock IDs that were allocated.
"""
self._connector.update_state_after_alloc(
str(request.request_id), block_ids, request.context_current_position
)
def request_finished(self, request: LlmRequest, cache_block_ids: list[int]) -> bool:
"""
Called when a request is finished generating tokens.
Args:
request: The request that finished generating tokens.
Returns:
Whether the request is performing asynchronous saving operations.
If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers).
"""
is_async_saving = self._connector.request_finished(
str(request.request_id), cache_block_ids
)
return is_async_saving
def _create_slot(self, request: LlmRequest) -> None:
"""Create a slot for the request"""
if self._connector.has_slot(str(request.request_id)):
return None
if bool(request.multimodal_positions):
raise ValueError("Unsupported request - requires mm extra keys")
all_token_ids = request.get_tokens(0)
# extract the critial aspects of the request that effect how the tokens are hashed
request = KvbmRequest(
request_id=str(request.request_id), lora_name=None, salt_hash=None
)
self._connector.create_slot(request, all_token_ids)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import KvCacheConnectorWorker
from tensorrt_llm.bindings.executor import ExecutorConfig
from dynamo.llm.trtllm_integration.rust import (
KvConnectorWorker as RustKvConnectorWorker,
)
from dynamo.runtime import DistributedRuntime
class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
def __init__(self, executor_config: ExecutorConfig):
super().__init__(executor_config)
self.drt = DistributedRuntime.detached()
self.rank = executor_config.mapping.rank
self._connector = RustKvConnectorWorker(
self.drt, str(executor_config.mapping.rank)
)
def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
"""
Register the KV cache tensors to the worker.
This can be used for something like NIXL registration.
Args:
kv_cache_tensor: The contiguous KV cache tensor.
"""
print(f"Register KV Caches on rank {self.rank}")
logger.info(
f"KvConnectorWorker started registering the kv caches on rank {self._config.mapping.rank}"
)
num_device_blocks = kv_cache_tensor.shape[0]
page_size = self._config.tokens_per_block
device_id = kv_cache_tensor.device.index
kv_cache_dtype = kv_cache_tensor.dtype
num_cache_layers = kv_cache_tensor.shape[1]
self.events = [
torch.cuda.Event(enable_timing=False, interprocess=False)
for _ in range(num_cache_layers)
]
for event in self.events:
event.record(torch.cuda.current_stream(device_id))
raw_event_handles = [event.cuda_event for event in self.events]
self._connector.register_kv_caches(
num_device_blocks,
page_size,
device_id,
kv_cache_dtype.itemsize,
kv_cache_tensor,
raw_event_handles,
)
def bind_connector_meta(self, metadata: object):
"""Set the connector metadata from the scheduler.
This function should be called by the model runner every time
before the model execution. The metadata will be used for runtime
KV cache loading and saving.
Args:
metadata (bytes): the connector metadata.
"""
super().bind_connector_meta(metadata)
self._connector.bind_connector_meta(metadata)
def start_load_kv(self, stream: torch.cuda.Stream):
"""
Begin loading the KV cache in preparation for the next forward pass.
Specific blocks to transfer are indicated by the scheduler's metadata.
"""
self._connector.start_load_kv()
def wait_for_save(self, stream: torch.cuda.Stream):
"""
Block until all synchronous saving operations are complete. Called at the end of the forward pass.
"""
pass
def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
"""
Wait for a layer to finish being loaded before proceeding with the forward pass on the layer.
Note: This function is called immediately before the layer's work is enqueued into the stream.
Args:
layer_idx: The index of the layer to wait for.
stream: The stream the forward pass is being executed on.
"""
pass
def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
"""
Begin saving the KV cache for a layer.
Note: This function is called immediately after the layer's work is enqueued into the stream.
Args:
layer_idx: The index of the layer to save.
stream: The stream the forward pass is being executed on.
"""
self.events[layer_idx].record(stream)
self._connector.save_kv_layer(layer_idx)
def get_finished(
self, finished_gen_req_ids: list[int], started_loading_req_ids: list[int]
) -> tuple[list[int], list[int]]:
"""
Get the requests that have finished loading and saving.
Args:
finished_gen_req_ids: The IDs of the requests that have finished generating tokens, and are now asynchronously saving.
started_loading_req_ids: The IDs of the requests that have started asynchronously loading.
Returns:
The IDs of the requests that have finished saving.
The IDs of the requests that have finished loading.
Note: IDs may only be returned from this call after they've been provided in the `finished_gen_req_ids` and `started_loading_req_ids` arguments.
Additionally, the runtime will only take action based on these returned IDs once they've been returned by ALL workers. This allows some workers to take longer than others to complete the operations.
"""
return self._connector.get_finished(
finished_gen_req_ids, started_loading_req_ids
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment