Unverified commit a68c2f8f, authored by Richard Huo and committed by GitHub
Browse files

feat: DIS-373 dynamo KVBM connector API integration with TRTLLM (#2544)


Signed-off-by: richardhuo-nv <rihuo@nvidia.com>
parent 43a26958
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Loader for the Rust-based TensorRT-LLM integration objects.

The classes are currently re-exported from the ``_vllm_integration`` extension
module. If the extension cannot be imported, or it does not expose one of the
expected classes, every exported name is bound to ``None`` so that importing
this module never raises.
"""

try:
    # TODO: use TRTLLM own integration module
    from dynamo._core import _vllm_integration

    # Runtime - dynamically loaded classes from Rust extension
    KvbmRequest = _vllm_integration.KvbmRequest
    KvbmBlockList = _vllm_integration.KvbmBlockList
    BlockState = _vllm_integration.BlockState
    BlockStates = _vllm_integration.BlockStates
    SlotUpdate = _vllm_integration.SlotUpdate
    # The TRT-LLM connector classes are exposed under Py-prefixed names.
    KvConnectorWorker = _vllm_integration.PyTrtllmKvConnectorWorker
    KvConnectorLeader = _vllm_integration.PyTrtllmKvConnectorLeader
    SchedulerOutput = _vllm_integration.SchedulerOutput
except (ImportError, AttributeError):
    # AttributeError covers an extension that imports but lacks one of the
    # classes above; without it the fallback below would be bypassed and the
    # import of this module would fail outright.
    print(
        "Failed to import Dynamo KVBM. TensorRT-LLM integration will not be available."
    )
    # Reset every name so a partially completed binding never leaks out.
    KvbmRequest = None
    KvbmBlockList = None
    BlockState = None
    BlockStates = None
    SlotUpdate = None
    KvConnectorWorker = None
    KvConnectorLeader = None
    SchedulerOutput = None

__all__ = [
    "KvbmRequest",
    "KvbmBlockList",
    "BlockState",
    "BlockStates",
    "SlotUpdate",
    "KvConnectorWorker",
    "KvConnectorLeader",
    "SchedulerOutput",
]
...@@ -14,7 +14,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadat ...@@ -14,7 +14,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadat
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request from vllm.v1.request import Request
from vllm.worker.cache_engine import CacheEngine
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -29,7 +28,7 @@ if TYPE_CHECKING: ...@@ -29,7 +28,7 @@ if TYPE_CHECKING:
# ) # )
# from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput # from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from dynamo.llm import BlockManager, KvbmLeader from dynamo.llm import KvbmLeader
from dynamo.llm.vllm_integration.kv_cache_utils import ( from dynamo.llm.vllm_integration.kv_cache_utils import (
find_and_set_available_port_from_env, find_and_set_available_port_from_env,
) )
...@@ -64,25 +63,12 @@ class KvConnectorLeader: ...@@ -64,25 +63,12 @@ class KvConnectorLeader:
self.vllm_config = vllm_config self.vllm_config = vllm_config
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
bytes_per_block = CacheEngine.get_cache_block_size(
vllm_config.cache_config,
vllm_config.model_config,
vllm_config.parallel_config,
)
total_bytes = bytes_per_block * world_size
leader = KvbmLeader(total_bytes, world_size, drt=self.drt)
block_manager = BlockManager( leader = KvbmLeader(world_size, drt=self.drt)
0,
leader,
vllm_config.cache_config.block_size,
disable_device_pool=True,
)
print(f"KvConnectorLeader initialized with engine_id: {engine_id}") print(f"KvConnectorLeader initialized with engine_id: {engine_id}")
self._connector = RustKvConnectorLeader( self._connector = RustKvConnectorLeader(
engine_id, self.drt, block_manager, leader engine_id, self.drt, vllm_config.cache_config.block_size, leader
) )
# KV Connector # KV Connector
......
...@@ -8,7 +8,7 @@ mod zmq; ...@@ -8,7 +8,7 @@ mod zmq;
mod leader; mod leader;
mod worker; mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig}; pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::BlockTransferHandler; pub use transfer::BlockTransferHandler;
pub use utils::{ pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType, BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
...@@ -130,21 +130,31 @@ mod tests { ...@@ -130,21 +130,31 @@ mod tests {
vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))]; vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))];
let config = KvbmWorkerConfig::builder() let config = KvbmWorkerConfig::builder()
.barrier_id(barrier_id.clone()) .barrier_id_prefix(barrier_id.clone())
.num_device_blocks(NUM_BLOCKS) .num_device_blocks(NUM_BLOCKS)
.tensors(tensors) .tensors(tensors)
.worker_id(i) .device_id(i)
.build()?; .build()?;
let worker = KvbmWorker::new(config).await?; let worker = KvbmWorker::new(config, false).await?;
workers.push(worker); workers.push(worker);
} }
let host_blocks = KvbmLeaderNumBlocksConfig {
cache_size_in_gb: 1.0,
num_blocks_overriden: NUM_BLOCKS,
};
let disk_blocks = KvbmLeaderNumBlocksConfig {
cache_size_in_gb: 1.0,
num_blocks_overriden: NUM_BLOCKS,
};
let leader_config = KvbmLeaderConfig::builder() let leader_config = KvbmLeaderConfig::builder()
.barrier_id(barrier_id) .barrier_id_prefix(barrier_id)
.world_size(num_workers) .world_size(num_workers)
.num_host_blocks(NUM_BLOCKS) .host_blocks_config(host_blocks)
.num_disk_blocks(NUM_BLOCKS) .disk_blocks_config(disk_blocks)
.build()?; .build()?;
// When/if this returns, we know that all the workers were also successful. // When/if this returns, we know that all the workers were also successful.
......
...@@ -9,12 +9,16 @@ use zmq::*; ...@@ -9,12 +9,16 @@ use zmq::*;
use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier; use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier;
use anyhow::{Context, anyhow};
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration; use std::time::Duration;
use tokio::sync::Notify;
use tokio::sync::OnceCell;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken; use tokio::time::sleep;
/// Data that is sent to workers over ETCD to establish a ZMQ connection. /// Data that is sent to workers over ETCD to establish a ZMQ connection.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -25,17 +29,31 @@ pub struct KvbmLeaderData { ...@@ -25,17 +29,31 @@ pub struct KvbmLeaderData {
pub num_disk_blocks: usize, pub num_disk_blocks: usize,
} }
#[derive(Builder, Clone, Debug)] #[derive(Builder, Clone, Debug, Default)]
pub struct KvbmLeaderConfig { pub struct KvbmLeaderNumBlocksConfig {
#[builder(default = "0")] #[builder(default = "0.0")]
num_host_blocks: usize, pub cache_size_in_gb: f64,
#[builder(default = "0")] #[builder(default = "0")]
num_disk_blocks: usize, pub num_blocks_overriden: usize,
}
/// Resolve how many blocks a cache tier should hold.
///
/// A non-zero `num_blocks_overriden` is returned as-is. Otherwise the count is
/// `cache_size_in_gb` (decimal gigabytes, i.e. multiplied by 1_000_000_000)
/// divided by `bytes_per_block`, with the fractional part discarded by the
/// float-to-`usize` cast.
fn compute_num_blocks(
    num_blocks_config: &KvbmLeaderNumBlocksConfig,
    bytes_per_block: usize,
) -> usize {
    match num_blocks_config.num_blocks_overriden {
        // No explicit override: derive the count from the configured cache size.
        0 => {
            let cache_bytes = num_blocks_config.cache_size_in_gb * 1_000_000_000.0;
            (cache_bytes / bytes_per_block as f64) as usize
        }
        explicit_count => explicit_count,
    }
}
#[derive(Builder, Clone, Debug)]
pub struct KvbmLeaderConfig {
/// The barrier id to use for syncing with workers. /// The barrier id to use for syncing with workers.
#[builder(default = "String::from(\"kvbm\")")] #[builder(default = "String::from(\"kvbm\")")]
barrier_id: String, barrier_id_prefix: String,
/// The world size. /// The world size.
#[builder(default = "1")] #[builder(default = "1")]
...@@ -47,6 +65,12 @@ pub struct KvbmLeaderConfig { ...@@ -47,6 +65,12 @@ pub struct KvbmLeaderConfig {
#[builder(setter(strip_option))] #[builder(setter(strip_option))]
drt: Option<DistributedRuntime>, drt: Option<DistributedRuntime>,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
host_blocks_config: KvbmLeaderNumBlocksConfig,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
disk_blocks_config: KvbmLeaderNumBlocksConfig,
} }
impl KvbmLeaderConfig { impl KvbmLeaderConfig {
...@@ -55,6 +79,15 @@ impl KvbmLeaderConfig { ...@@ -55,6 +79,15 @@ impl KvbmLeaderConfig {
} }
} }
/// State shared between a `KvbmLeader` and the background tasks it spawns.
///
/// Every field is behind an `Arc` so the tasks can hold their own handles.
/// With the derived `Default`, the counters start at 0 and the flag at `false`.
#[derive(Debug, Default)]
pub struct KvbmLeaderState {
/// Device block count; stored by the barrier task after a successful sync.
pub num_device_blocks: Arc<AtomicUsize>,
/// Host block count; stored by the barrier task after a successful sync.
pub num_host_blocks: Arc<AtomicUsize>,
/// Disk block count; stored by the barrier task after a successful sync.
pub num_disk_blocks: Arc<AtomicUsize>,
/// Set to `true` by the ZMQ task once the ZMQ leader has been created.
pub workers_allocation_ready: Arc<AtomicBool>,
/// Notified (via `notify_waiters`) right after `workers_allocation_ready` is set.
pub workers_ready_notify: Arc<Notify>,
}
/// The leader of the KVBM. /// The leader of the KVBM.
/// ///
/// This is responsible for: /// This is responsible for:
...@@ -62,9 +95,13 @@ impl KvbmLeaderConfig { ...@@ -62,9 +95,13 @@ impl KvbmLeaderConfig {
/// - Syncing the leader barrier with workers. /// - Syncing the leader barrier with workers.
/// - Sending messages to workers. /// - Sending messages to workers.
pub struct KvbmLeader { pub struct KvbmLeader {
num_device_blocks: usize, state: Arc<KvbmLeaderState>,
zmq_leader: ZmqActiveMessageLeader, zmq_leader: Arc<OnceCell<ZmqActiveMessageLeader>>,
config: KvbmLeaderConfig, config: KvbmLeaderConfig,
//readiness flags
workers_sync_ready: Arc<AtomicBool>,
workers_sync_ready_notify: Arc<Notify>,
workers_sync_done: Arc<AtomicBool>,
} }
impl KvbmLeader { impl KvbmLeader {
...@@ -76,34 +113,90 @@ impl KvbmLeader { ...@@ -76,34 +113,90 @@ impl KvbmLeader {
} }
}; };
tracing::info!( let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?;
"Syncing leader barrier with {} workers on barrier id {}",
config.world_size, let leader = Self {
config.barrier_id state: Arc::new(KvbmLeaderState::default()),
zmq_leader: Arc::new(tokio::sync::OnceCell::new()),
config,
workers_sync_ready: Arc::new(AtomicBool::new(false)),
workers_sync_ready_notify: Arc::new(Notify::new()),
workers_sync_done: Arc::new(AtomicBool::new(false)),
};
let cancel_token = tokio_util::sync::CancellationToken::new();
// The leader_sockets struct cannot be cloned,
// so we use a tuple to "struct" the two urls
let leader_urls = (
leader_sockets.pub_url.clone(),
leader_sockets.ack_url.clone(),
); );
leader.spawn_barrier_task(drt, leader_urls);
leader.spawn_zmq_task(leader_sockets, cancel_token);
let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?; Ok(leader)
}
fn spawn_barrier_task(&self, drt: DistributedRuntime, leader_urls: (String, String)) {
let state = self.state.clone();
let leader_config = self.config.clone();
let ready = Arc::clone(&self.workers_sync_ready);
let notify = Arc::clone(&self.workers_sync_ready_notify);
let done = Arc::clone(&self.workers_sync_done);
let zmq_data = Arc::new(KvbmLeaderData { tokio::spawn(async move {
pub_url: leader_sockets.pub_url.clone(), match KvbmLeader::run_barrier_sync(drt, leader_urls, leader_config).await {
ack_url: leader_sockets.ack_url.clone(), Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) => {
num_host_blocks: config.num_host_blocks, // write back results
num_disk_blocks: config.num_disk_blocks, state
.num_device_blocks
.store(num_device_blocks, Ordering::Release);
state
.num_host_blocks
.store(num_host_blocks, Ordering::Release);
state
.num_disk_blocks
.store(num_disk_blocks, Ordering::Release);
ready.store(true, Ordering::Release);
done.store(true, Ordering::Release);
notify.notify_waiters();
}
Err(e) => {
tracing::error!("Barrier sync failed: {e:?}");
done.store(true, Ordering::Release);
notify.notify_waiters();
}
}
}); });
}
async fn run_barrier_sync(
drt: DistributedRuntime,
leader_urls: (String, String),
leader_config: KvbmLeaderConfig,
) -> anyhow::Result<(usize, usize, usize)> {
let barrier_id_worker_to_leader =
format!("{}{}", leader_config.barrier_id_prefix, "-worker-to-leader");
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_worker_to_leader
);
// Build our leader barrier and publish the data. // Build our leader barrier and publish the data.
// TODO: Use a separate timeout parameter from the ZMQ connection timeout // TODO: Use a separate timeout parameter from the ZMQ connection timeout
let leader_barrier: LeaderBarrier<KvbmLeaderData, worker::KvbmWorkerData> = let worker_to_leader_barrier: LeaderBarrier<(), worker::KvbmWorkerData> =
LeaderBarrier::new( LeaderBarrier::new(
config.barrier_id.clone(), barrier_id_worker_to_leader.clone(),
config.world_size, leader_config.world_size,
Some(Duration::from_secs(config.leader_init_timeout_secs)), Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
); );
let worker_data = leader_barrier let worker_data = worker_to_leader_barrier
.sync(&drt, zmq_data.as_ref()) .sync(&drt, &())
.await .await
.map_err(|e| anyhow::anyhow!("Failed to sync leader barrier: {:?}", e))?; .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
let num_device_blocks = worker_data let num_device_blocks = worker_data
.values() .values()
...@@ -111,46 +204,250 @@ impl KvbmLeader { ...@@ -111,46 +204,250 @@ impl KvbmLeader {
.min() .min()
.unwrap(); .unwrap();
tracing::info!("Leader barrier synced with {} workers", config.world_size); // TODO: this works for TP, need to redefine bytes_per_block when we enable the DP/PP
let bytes_per_block: usize = worker_data.values().map(|d| d.bytes_per_block).sum();
assert!(
bytes_per_block > 0,
"bytes_per_block must be greater than 0"
);
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
tracing::debug!("Worker data: {:?}", worker_data); tracing::debug!("Worker data: {:?}", worker_data);
// Now, create our active message leader. let num_host_blocks =
// This also blocks until a ZMQ connection has been established. compute_num_blocks(&leader_config.host_blocks_config, bytes_per_block);
let cancel_token = CancellationToken::new(); let num_disk_blocks =
let zmq_leader = ZmqActiveMessageLeader::new( compute_num_blocks(&leader_config.disk_blocks_config, bytes_per_block);
leader_sockets,
config.world_size, // Start the second sync to transfer num_host_blocks and num_disk_blocks to worker
Duration::from_secs(config.leader_init_timeout_secs), let barrier_id_leader_to_worker =
cancel_token.clone(), format!("{}{}", leader_config.barrier_id_prefix, "-leader-to-worker");
) tracing::info!(
.await?; "Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
Ok(Self { barrier_id_leader_to_worker
num_device_blocks, );
zmq_leader,
config, let (leader_pub_url, leader_ack_url) = leader_urls;
let zmq_data_leader_to_worker = Arc::new(KvbmLeaderData {
pub_url: leader_pub_url,
ack_url: leader_ack_url,
num_host_blocks,
num_disk_blocks,
});
let leader_to_worker_barrier: LeaderBarrier<KvbmLeaderData, ()> = LeaderBarrier::new(
barrier_id_leader_to_worker.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let _worker_data = leader_to_worker_barrier
.sync(&drt, zmq_data_leader_to_worker.as_ref())
.await
.map_err(|e| anyhow::anyhow!("Failed to sync leader to worker barrier: {:?}", e))?;
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
Ok((num_device_blocks, num_host_blocks, num_disk_blocks))
}
/// Start a background task that creates the ZMQ active-message leader.
///
/// On success the leader is stored in `self.zmq_leader`, the
/// `workers_allocation_ready` flag is set and waiters on
/// `workers_ready_notify` are woken. On failure the error is only logged;
/// the flag stays `false` and nobody is notified.
fn spawn_zmq_task(
    &self,
    leader_sockets: LeaderSockets,
    cancel: tokio_util::sync::CancellationToken,
) {
    let zmq_cell = self.zmq_leader.clone();
    let shared_state = self.state.clone();
    let num_workers = self.config.world_size;
    let init_timeout_secs = self.config.leader_init_timeout_secs;

    tokio::spawn(async move {
        let zmq = match ZmqActiveMessageLeader::new(
            leader_sockets,
            num_workers,
            std::time::Duration::from_secs(init_timeout_secs),
            cancel,
        )
        .await
        {
            Ok(zmq) => zmq,
            Err(e) => {
                tracing::error!("ZMQ init failed: {e:?}");
                return;
            }
        };

        // A failed `set` means the cell was already filled; that is ignored.
        let _ = zmq_cell.set(zmq);

        // Publish readiness: flag first, then wake the waiters.
        shared_state
            .workers_allocation_ready
            .store(true, Ordering::Release);
        shared_state.workers_ready_notify.notify_waiters();
    });
}
/// Spawn the leader-readiness barrier on the runtime's primary handle.
///
/// Intended for non-blocking leader initialization: this returns immediately
/// and the spawned task does the waiting. The task first waits, for at most
/// `leader_init_timeout_secs`, until the ZMQ task has set
/// `workers_allocation_ready`; it then runs the "-leader-ready" barrier with
/// the workers. Every failure (timeout, spurious wake-up, barrier error) is
/// logged and ends the task; nothing is reported back to the caller.
pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) {
    let timeout_secs = self.config.leader_init_timeout_secs;
    let state = self.state.clone();
    let leader_config = self.config.clone();
    let handle = drt.runtime().primary();

    handle.spawn(async move {
        {
            // Create the future *before* checking the flag. The ZMQ task sets the
            // flag and then calls `notify_waiters()`, which only reaches futures
            // that already exist; creating it after the check could miss the
            // wake-up and wait for the full timeout.
            let notified = state.workers_ready_notify.notified();

            if !state.workers_allocation_ready.load(Ordering::Acquire) {
                // Wait until ZMQ marks ready or we time out.
                let waited =
                    tokio::time::timeout(Duration::from_secs(timeout_secs), notified).await;
                if waited.is_err() {
                    tracing::error!(
                        "leader readiness barrier wait timed out after {timeout_secs} seconds"
                    );
                    return;
                }
                // Double-check the flag (Acquire) after wakeup.
                if !state.workers_allocation_ready.load(Ordering::Acquire) {
                    tracing::error!("leader readiness notify fired but flag not set; aborting");
                    return;
                }
            }
        }

        match KvbmLeader::run_leader_readiness(drt, leader_config).await {
            Ok(()) => {
                tracing::info!("leader readiness barrier synced!");
            }
            Err(e) => {
                tracing::error!("leader readiness barrier failed: {e:?}");
            }
        }
    });
}
/// Run the leader-readiness barrier and block the calling thread until it finishes.
///
/// Intended for blocking leader initialization. The work is driven with
/// `tokio::task::block_in_place` plus `Handle::current().block_on`, so this
/// must be called from inside a Tokio runtime.
/// NOTE(review): per the Tokio docs `block_in_place` needs the multi-threaded
/// runtime flavor; confirm that every caller runs on one.
///
/// # Errors
/// Returns an error (with context "leader readiness barrier failed") when
/// `workers_allocation_ready` is not set within `leader_init_timeout_secs`,
/// when a wake-up arrives but the flag is still `false`, or when the
/// readiness barrier itself fails.
pub fn run_leader_readiness_barrier_blocking(
&self,
drt: DistributedRuntime,
) -> anyhow::Result<()> {
let state = self.state.clone();
let timeout_secs = self.config.leader_init_timeout_secs;
let leader_config = self.config.clone();
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(async move {
// Create the future *before* checking the flag to avoid a lost-notify race.
let notified = state.workers_ready_notify.notified();
if !state.workers_allocation_ready.load(Ordering::Acquire) {
// Wait (with timeout) until ZMQ task marks ready.
tokio::time::timeout(Duration::from_secs(timeout_secs), notified)
.await
.map_err(|_| anyhow!("timed out waiting for workers_allocation_ready after {timeout_secs} seconds"))?;
// Double-check after wake to ensure the flag is actually set.
if !state.workers_allocation_ready.load(Ordering::Acquire) {
return Err(anyhow!(
"notified but workers_allocation_ready is still false"
));
}
}
// The ZMQ side is ready; now synchronize with the workers.
KvbmLeader::run_leader_readiness(drt, leader_config).await
})
.context("leader readiness barrier failed")
}) })
} }
/// Synchronize the "-leader-ready" barrier with all workers.
///
/// The barrier id is `barrier_id_prefix` followed by `-leader-ready`; it is
/// created for `world_size` workers with a timeout of
/// `leader_init_timeout_secs` and carries no payload in either direction.
///
/// # Errors
/// Returns an error when the barrier sync fails.
async fn run_leader_readiness(
    drt: DistributedRuntime,
    leader_config: KvbmLeaderConfig,
) -> anyhow::Result<()> {
    let barrier_id = format!("{}{}", leader_config.barrier_id_prefix, "-leader-ready");
    tracing::info!(
        "Syncing leader readiness barrier with {} workers on barrier id {}",
        leader_config.world_size,
        barrier_id
    );

    let sync_timeout = Duration::from_secs(leader_config.leader_init_timeout_secs);
    let readiness_barrier: LeaderBarrier<(), ()> =
        LeaderBarrier::new(barrier_id, leader_config.world_size, Some(sync_timeout));

    // The workers send no data here, so the returned value is discarded.
    match readiness_barrier.sync(&drt, &()).await {
        Ok(_) => Ok(()),
        Err(e) => Err(anyhow::anyhow!(
            "Failed to sync leader readiness barrier on leader: {:?}",
            e
        )),
    }
}
pub async fn transfer_blocks_request( pub async fn transfer_blocks_request(
&self, &self,
request: BlockTransferRequest, request: BlockTransferRequest,
) -> anyhow::Result<oneshot::Receiver<()>> { ) -> anyhow::Result<oneshot::Receiver<()>> {
let zmq = self
.zmq_leader
.get()
.ok_or_else(|| anyhow::anyhow!("ZMQ leader not ready"))?;
let data = vec![serde_json::to_vec(&request)?]; let data = vec![serde_json::to_vec(&request)?];
self.zmq_leader zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await
.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data) }
.await
/// Returns `true` once the barrier task has completed the worker sync
/// successfully and stored the block counts.
pub fn is_worker_sync_ready(&self) -> bool {
self.workers_sync_ready.load(Ordering::Acquire)
}
/// Returns `true` once the barrier task has finished, whether it succeeded or
/// failed; combine with `is_worker_sync_ready` to tell the two apart.
pub fn is_worker_sync_done(&self) -> bool {
self.workers_sync_done.load(Ordering::Acquire)
} }
pub fn num_device_blocks(&self) -> usize { pub fn num_device_blocks(&self) -> usize {
self.num_device_blocks self.state.num_device_blocks.load(Ordering::Acquire)
} }
pub fn num_host_blocks(&self) -> usize { pub fn num_host_blocks(&self) -> usize {
self.config.num_host_blocks self.state.num_host_blocks.load(Ordering::Acquire)
} }
pub fn num_disk_blocks(&self) -> usize { pub fn num_disk_blocks(&self) -> usize {
self.config.num_disk_blocks self.state.num_disk_blocks.load(Ordering::Acquire)
}
/// Wait until the worker barrier sync has finished and report whether it succeeded.
///
/// Returns `true` when the sync is ready. Returns `false` when the barrier
/// task is done without being ready (it failed), or when nothing happens
/// within `leader_init_timeout_secs`.
pub async fn wait_worker_sync_ready(&self) -> bool {
// Fast path: the outcome is already known.
if self.is_worker_sync_ready() {
return true;
}
if self.is_worker_sync_done() {
return false;
}
// Create the notification future, then re-check both flags so an outcome
// published between the checks above and this point is not missed.
let notified = self.workers_sync_ready_notify.notified();
if self.is_worker_sync_ready() {
return true;
}
if self.is_worker_sync_done() {
return false;
}
// bounded wait
tokio::select! {
_ = notified => {
// Woken on success and on failure alike; the flag tells which.
self.is_worker_sync_ready()
}
_ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false,
}
} }
} }
...@@ -35,6 +35,7 @@ use dynamo_runtime::{ ...@@ -35,6 +35,7 @@ use dynamo_runtime::{
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvbmWorkerData { pub struct KvbmWorkerData {
pub num_device_blocks: usize, pub num_device_blocks: usize,
pub bytes_per_block: usize,
} }
pub fn load_and_validate_tensors( pub fn load_and_validate_tensors(
...@@ -82,7 +83,7 @@ pub fn load_and_validate_tensors( ...@@ -82,7 +83,7 @@ pub fn load_and_validate_tensors(
Ok((device_tensors, shape.unwrap())) Ok((device_tensors, shape.unwrap()))
} }
#[derive(Builder)] #[derive(Builder, Clone)]
#[builder(pattern = "owned")] #[builder(pattern = "owned")]
pub struct KvbmWorkerConfig { pub struct KvbmWorkerConfig {
drt: DistributedRuntime, drt: DistributedRuntime,
...@@ -101,8 +102,11 @@ pub struct KvbmWorkerConfig { ...@@ -101,8 +102,11 @@ pub struct KvbmWorkerConfig {
#[builder(default = "2")] #[builder(default = "2")]
dtype_width_bytes: usize, dtype_width_bytes: usize,
#[builder(default = false)]
is_fully_contiguous_layout: bool,
#[builder(default = "String::from(\"kvbm\")")] #[builder(default = "String::from(\"kvbm\")")]
barrier_id: String, barrier_id_prefix: String,
#[builder(default = "None")] #[builder(default = "None")]
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
...@@ -132,7 +136,7 @@ pub struct KvbmWorker { ...@@ -132,7 +136,7 @@ pub struct KvbmWorker {
} }
impl KvbmWorker { impl KvbmWorker {
pub async fn new(config: KvbmWorkerConfig) -> anyhow::Result<Self> { pub async fn new(config: KvbmWorkerConfig, layout_blocking: bool) -> anyhow::Result<Self> {
tracing::info!( tracing::info!(
"Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}", "Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}",
config.num_device_blocks, config.num_device_blocks,
...@@ -153,62 +157,147 @@ impl KvbmWorker { ...@@ -153,62 +157,147 @@ impl KvbmWorker {
))); )));
} }
let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks { let (layout_type, num_layers, outer_dim, inner_dim) = if !config.is_fully_contiguous_layout
(false, shape[1]) {
} else if shape[1] >= config.num_device_blocks { let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks {
(true, shape[0]) (false, shape[1])
} else if shape[1] >= config.num_device_blocks {
(true, shape[0])
} else {
return Err(anyhow::anyhow!(format!(
"Unsupported kv cache layout. Got shape: {:?}",
shape
)));
};
let num_layers = device_tensors.len();
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
device_tensors.len(),
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::LayerSeparate { outer_contiguous },
num_layers,
outer_dim,
inner_dim,
)
} else { } else {
return Err(anyhow::anyhow!(format!( let num_layers = shape[1];
"Unsupported kv cache layout. Got shape: {:?}", let outer_dim = shape[2];
shape let inner_dim = shape[3..].iter().product::<usize>() / config.page_size;
))); tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::FullyContiguous,
num_layers,
outer_dim,
inner_dim,
)
}; };
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size; let bytes_per_block =
num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
device_tensors.len(),
outer_dim,
config.page_size,
inner_dim
);
let mut layout_builder_instance = LayoutConfigBuilder::default(); let mut layout_builder_instance = LayoutConfigBuilder::default();
let layout_builder = layout_builder_instance let layout_builder = layout_builder_instance
.num_layers(device_tensors.len()) .num_layers(num_layers)
.outer_dim(outer_dim) .outer_dim(outer_dim)
.page_size(config.page_size) .page_size(config.page_size)
.inner_dim(inner_dim) .inner_dim(inner_dim)
.dtype_width_bytes(config.dtype_width_bytes); .dtype_width_bytes(config.dtype_width_bytes);
let layout_type = LayoutType::LayerSeparate { outer_contiguous };
let device_layout = layout_builder let device_layout = layout_builder
.num_blocks(config.num_device_blocks) .num_blocks(config.num_device_blocks)
.build()? .build()?
.create_layout(layout_type, device_tensors)?; .create_layout(layout_type, device_tensors)?;
let layout_builder_clone = layout_builder.clone(); let layout_builder = layout_builder.clone();
let (task, handler_rx) = if layout_blocking {
Self::run_blocking_layout_initialization(
config,
bytes_per_block,
device_layout,
layout_builder,
layout_type,
)
.await?
} else {
Self::run_non_blocking_layout_initialization(
config,
bytes_per_block,
device_layout,
layout_builder,
layout_type,
)
.await?
};
// add worker-connector scheduler here Ok(Self {
// let scheduler = KvbmWorkerScheduler::new(config.scheduler.clone()); task: Some(task),
block_transfer_handler_rx: Some(handler_rx),
})
}
async fn run_blocking_layout_initialization(
config: KvbmWorkerConfig,
bytes_per_block: usize,
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
layout_builder: LayoutConfigBuilder,
layout_type: LayoutType,
) -> anyhow::Result<(
CriticalTaskExecutionHandle,
oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
let cancel_token = config.drt.primary_token().clone(); let cancel_token = config.drt.primary_token().clone();
// barrier sync with leader to get the leader data
let leader_data = tokio::task::block_in_place(|| {
// This is now synchronous blocking code
// We need a separate current-thread runtime to block_on async calls here
let rt = tokio::runtime::Handle::current();
rt.block_on(async {
KvbmWorker::leader_barrier_sync(
config.clone(),
cancel_token.clone(),
bytes_per_block,
)
.await
})
})?;
// establish a oneshot channel to get back the raw BlockTransferHandler // establish a oneshot channel to get back the raw BlockTransferHandler
let (handler_tx, handler_rx) = oneshot::channel(); let (handler_tx, handler_rx) = oneshot::channel();
// establish a oneshot channel to block on the main routine to wait for layout allocation readiness
let (layout_ready_tx, layout_ready_rx) = oneshot::channel::<String>();
let scheduler_client = config.scheduler_client.clone(); let scheduler_client = config.scheduler_client.clone();
let worker_config = config.clone();
// start background worker task to do layout allocation for host or disk
let task = CriticalTaskExecutionHandle::new( let task = CriticalTaskExecutionHandle::new(
move |cancel_token| { move |cancel_token| {
KvbmWorker::worker_task( KvbmWorker::worker_task(
device_layout, device_layout,
layout_builder_clone, layout_builder,
leader_data,
layout_type, layout_type,
config, worker_config,
cancel_token, cancel_token,
handler_tx, handler_tx,
layout_ready_tx,
scheduler_client, scheduler_client,
) )
}, },
...@@ -216,12 +305,105 @@ impl KvbmWorker { ...@@ -216,12 +305,105 @@ impl KvbmWorker {
"kvbm-worker-task", "kvbm-worker-task",
)?; )?;
Ok(Self { // waiting for the worker layout allocation ready
task: Some(task), match layout_ready_rx.await {
block_transfer_handler_rx: Some(handler_rx), Ok(_) => tracing::info!("worker layout allocation finished."),
}) Err(_) => tracing::error!("Worker layout dropped without sending"),
}
let worker_config = config.clone();
let cancel_for_barrier = cancel_token.clone();
// wait until the leader finished the initialization of all components
tokio::task::block_in_place(|| {
// This is now synchronous blocking code
// We need a separate current-thread runtime to block_on async calls here
let rt = tokio::runtime::Handle::current();
rt.block_on(async {
KvbmWorker::leader_readiness_sync(worker_config, cancel_for_barrier).await
})
})?;
Ok((task, handler_rx))
} }
/// Set up the worker without blocking the caller on any barrier.
///
/// All ordering-sensitive steps run inside a single critical task named
/// "kvbm-worker-task":
/// 1. the two-phase barrier sync with the leader,
/// 2. spawning `worker_task` as a detached Tokio task (its error is only logged),
/// 3. waiting for the layout-ready signal from that task,
/// 4. the leader-readiness sync.
///
/// Returns the handle of the critical task and the receiver on which the
/// `BlockTransferHandler` is delivered.
///
/// # Errors
/// Returns an error when the critical task cannot be created.
async fn run_non_blocking_layout_initialization(
    config: KvbmWorkerConfig,
    bytes_per_block: usize,
    device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage> + Send + 'static>,
    layout_builder: LayoutConfigBuilder,
    layout_type: LayoutType,
) -> anyhow::Result<(
    CriticalTaskExecutionHandle,
    oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
    // Hands the BlockTransferHandler back to the caller.
    let (handler_tx, handler_rx) = oneshot::channel::<transfer::BlockTransferHandler>();
    // Lets the worker task signal that its layout allocation is complete.
    let (layout_ready_tx, layout_ready_rx) = oneshot::channel::<String>();

    let task_cancel_token = config.drt.primary_token().clone();
    let scheduler_client = config.scheduler_client.clone();

    let task = CriticalTaskExecutionHandle::new(
        move |ct| async move {
            // Step 1: the barrier must complete before the worker task starts.
            let leader_data =
                KvbmWorker::leader_barrier_sync(config.clone(), ct.clone(), bytes_per_block)
                    .await?;

            // Step 2: run the long-lived worker on its own task so this
            // orchestrator can go on to the readiness steps.
            let worker_fut = KvbmWorker::worker_task(
                device_layout,
                layout_builder,
                leader_data,
                layout_type,
                config.clone(),
                ct.clone(),
                handler_tx,
                layout_ready_tx,
                scheduler_client,
            );
            tokio::spawn(async move {
                if let Err(e) = worker_fut.await {
                    tracing::error!("worker_task exited with error: {e:#}");
                }
            });

            // Step 3: wait for the worker to report its layout allocation.
            if layout_ready_rx.await.is_ok() {
                tracing::info!("worker layout allocation finished.");
            } else {
                tracing::warn!("worker layout readiness channel dropped");
            }

            // Step 4: wait for the leader to finish its own initialization.
            KvbmWorker::leader_readiness_sync(config, ct).await?;

            Ok::<(), anyhow::Error>(())
        },
        task_cancel_token,
        "kvbm-worker-task",
    )?;

    Ok((task, handler_rx))
}
/// One-time use method to extract the block transfer handler from the worker. /// One-time use method to extract the block transfer handler from the worker.
/// ///
/// This is a bit of a hack. Improve the API design around this in the future. /// This is a bit of a hack. Improve the API design around this in the future.
...@@ -248,15 +430,11 @@ impl KvbmWorker { ...@@ -248,15 +430,11 @@ impl KvbmWorker {
Ok(blocks) Ok(blocks)
} }
async fn worker_task( async fn leader_barrier_sync(
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
mut layout_builder: LayoutConfigBuilder,
layout_type: LayoutType,
config: KvbmWorkerConfig, config: KvbmWorkerConfig,
cancel_token: CancellationToken, cancel_token: CancellationToken,
handler_tx: oneshot::Sender<BlockTransferHandler>, bytes_per_block: usize,
scheduler_client: Option<TransferSchedulerClient>, ) -> anyhow::Result<KvbmLeaderData> {
) -> anyhow::Result<()> {
let drt = config.drt.clone(); let drt = config.drt.clone();
let worker_id = drt let worker_id = drt
...@@ -266,30 +444,61 @@ impl KvbmWorker { ...@@ -266,30 +444,61 @@ impl KvbmWorker {
))? ))?
.id() as usize; .id() as usize;
let barrier_id_worker_to_leader =
format!("{}{}", config.barrier_id_prefix, "-worker-to-leader");
tracing::info!( tracing::info!(
"Worker {} waiting on barrier {}", "Worker {} waiting on barrier {}",
worker_id, worker_id,
config.barrier_id barrier_id_worker_to_leader
); );
let worker_barrier = WorkerBarrier::<KvbmLeaderData, KvbmWorkerData>::new( let worker_to_leader_barrier = WorkerBarrier::<(), KvbmWorkerData>::new(
config.barrier_id, barrier_id_worker_to_leader,
worker_id.to_string(), worker_id.to_string(),
); );
let worker_data = KvbmWorkerData { let worker_data = KvbmWorkerData {
num_device_blocks: config.num_device_blocks, num_device_blocks: config.num_device_blocks,
bytes_per_block,
}; };
tokio::select! {
_ = cancel_token.cancelled() => {
return Err(anyhow::anyhow!("Cancelled"))
}
_leader_data = worker_to_leader_barrier.sync(&drt, &worker_data) => {
_leader_data
}
}
.map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
tracing::debug!(
"Worker {} sent the worker data in worker to leader phase",
worker_id
);
let barrier_id_leader_to_worker =
format!("{}{}", config.barrier_id_prefix, "-leader-to-worker");
tracing::info!(
"Worker {} waiting on barrier {}",
worker_id,
barrier_id_leader_to_worker
);
let leader_to_worker_barrier = WorkerBarrier::<KvbmLeaderData, ()>::new(
barrier_id_leader_to_worker,
worker_id.to_string(),
);
let leader_data = tokio::select! { let leader_data = tokio::select! {
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
return Ok(()) return Err(anyhow::anyhow!("Cancelled"))
} }
leader_data = worker_barrier.sync(&drt, &worker_data) => { leader_data = leader_to_worker_barrier.sync(&drt, &()) => {
leader_data leader_data
} }
} }
.map_err(|e| anyhow::anyhow!("Failed to sync worker barrier: {:?}", e))?; .map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
tracing::info!( tracing::info!(
"Worker {} received leader data: {:?}", "Worker {} received leader data: {:?}",
...@@ -297,6 +506,68 @@ impl KvbmWorker { ...@@ -297,6 +506,68 @@ impl KvbmWorker {
leader_data leader_data
); );
Ok(leader_data)
}
/// Waits until the `-leader-ready` barrier syncs, or until cancellation.
///
/// The barrier id is `config.barrier_id_prefix` followed by `-leader-ready`, and this
/// worker is identified on the barrier by its primary-lease id. Both barrier payload
/// types are `()`, so nothing is exchanged beyond the synchronization itself.
///
/// # Errors
///
/// Returns an error when:
/// - the runtime has no primary lease,
/// - `cancel_token` is cancelled before the barrier syncs (message `"Cancelled"`), or
/// - the barrier sync itself reports a failure.
async fn leader_readiness_sync(
    config: KvbmWorkerConfig,
    cancel_token: CancellationToken,
) -> anyhow::Result<()> {
    let drt = config.drt.clone();

    // `ok_or_else` builds the error only on the `None` path; `ok_or` would construct
    // (and allocate) it on every call, including the successful one.
    let worker_id = drt
        .primary_lease()
        .ok_or_else(|| {
            anyhow::anyhow!("unable to get primary lease; check that drt is not static")
        })?
        .id() as usize;

    let barrier_id_leader_readiness = format!("{}-leader-ready", config.barrier_id_prefix);
    tracing::info!(
        "Worker {} waiting on barrier {}",
        worker_id,
        barrier_id_leader_readiness
    );

    let leader_readiness_barrier =
        WorkerBarrier::<(), ()>::new(barrier_id_leader_readiness, worker_id.to_string());

    // Race the barrier against cancellation. The barrier's success value carries no
    // information here, so only the failure case is inspected below.
    let sync_outcome = tokio::select! {
        _ = cancel_token.cancelled() => {
            return Err(anyhow::anyhow!("Cancelled"))
        }
        outcome = leader_readiness_barrier.sync(&drt, &()) => outcome,
    };
    sync_outcome
        .map_err(|e| anyhow::anyhow!("Failed to sync leader readiness barrier: {:?}", e))?;

    Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn worker_task(
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
mut layout_builder: LayoutConfigBuilder,
leader_data: KvbmLeaderData,
layout_type: LayoutType,
config: KvbmWorkerConfig,
cancel_token: CancellationToken,
handler_tx: oneshot::Sender<BlockTransferHandler>,
layout_ready_tx: oneshot::Sender<String>,
scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<()> {
let drt = config.drt.clone();
let worker_id = drt
.primary_lease()
.ok_or(anyhow::anyhow!(
"unable to get primary lease; check that drt is not static"
))?
.id() as usize;
let agent = build_agent(worker_id, leader_data.num_disk_blocks > 0)?; let agent = build_agent(worker_id, leader_data.num_disk_blocks > 0)?;
let transfer_context = Arc::new(TransferContext::new( let transfer_context = Arc::new(TransferContext::new(
...@@ -380,6 +651,10 @@ impl KvbmWorker { ...@@ -380,6 +651,10 @@ impl KvbmWorker {
cancel_token.clone(), cancel_token.clone(),
)?; )?;
if layout_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("worker receiver dropped before result was sent");
}
// TODO: Some sort of fancy loop here. // TODO: Some sort of fancy loop here.
// For now, just wait for cancellation. // For now, just wait for cancellation.
cancel_token.cancelled().await; cancel_token.cancelled().await;
......
...@@ -334,9 +334,10 @@ impl FullyContiguousConfig { ...@@ -334,9 +334,10 @@ impl FullyContiguousConfig {
config.validate()?; config.validate()?;
let alignment = config.alignment; let alignment = config.alignment;
let memory_region_size = config.page_size * config.inner_dim * config.dtype_width_bytes; let outer_dim_stride_in_bytes =
let outer_dim_stride_in_bytes = memory_region_size; config.page_size * config.inner_dim * config.dtype_width_bytes;
let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim; let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim;
let memory_region_size = layer_stride_in_bytes;
let natural_block_stride = config.num_layers * layer_stride_in_bytes; let natural_block_stride = config.num_layers * layer_stride_in_bytes;
let block_stride_in_bytes = if alignment > 1 { let block_stride_in_bytes = if alignment > 1 {
......
...@@ -456,6 +456,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPool<S, L, M> ...@@ -456,6 +456,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPool<S, L, M>
&self, &self,
sequence_hashes: &[SequenceHash], sequence_hashes: &[SequenceHash],
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> { ) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
tracing::debug!("find matching for sequence_hashes: {:?}", sequence_hashes);
self._match_sequence_hashes(sequence_hashes)? self._match_sequence_hashes(sequence_hashes)?
.blocking_recv() .blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)? .map_err(|_| BlockPoolError::ProgressEngineShutdown)?
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
## Overview ## Overview
This suite validates determinism properties of the API-backed LLM under fixed sampling parameters and optionally across prefix cache resets. The tests can automatically start a local vLLM server, warm it up, and compare responses for identical prompts over multiple iterations. This suite validates the determinism properties of the API-backed LLM under fixed sampling parameters and, optionally, across prefix cache resets. The tests can automatically start a local LLM server—either a vLLM server or a TensorRT-LLM server—warm it up, and compare responses for identical prompts over multiple iterations. The suite also automatically detects whether the vLLM or TensorRT-LLM wheel is installed and starts the corresponding server.
## Files ## Files
- `test_determinism.py` — comprehensive determinism tests with automatic vLLM server lifecycle and warmup. - `test_determinism.py` — comprehensive determinism tests with automatic LLM server lifecycle and warmup.
- `test_determinism_with_cache_reset` — run test with warmup, reset cache, then run again without warmup to test determinism across cache reset boundary - `test_determinism_with_cache_reset` — run test with warmup, reset cache, then run again without warmup to test determinism across cache reset boundary
- `test_concurrent_determinism_with_ifeval` — send parametrized number of IFEval prompts (default: 120) with controlled concurrency, with warmup, then reset cache and test again without warmup to validate determinism across cache reset - `test_concurrent_determinism_with_ifeval` — send parametrized number of IFEval prompts (default: 120) with controlled concurrency, with warmup, then reset cache and test again without warmup to validate determinism across cache reset
...@@ -19,7 +19,7 @@ This suite validates determinism properties of the API-backed LLM under fixed sa ...@@ -19,7 +19,7 @@ This suite validates determinism properties of the API-backed LLM under fixed sa
## How It Works ## How It Works
- A `VLLMServerManager` fixture (`vllm_server`) launches `vllm serve` with the Dynamo connector and optional cache block overrides. - An `LLMServerManager` fixture (`llm_server`) launches `vllm serve` or `trtllm-serve` with the Dynamo connector and optional cache block overrides.
- A `tester` fixture binds the test client to the running server's base URL. - A `tester` fixture binds the test client to the running server's base URL.
- The test performs a comprehensive warmup across prompts, then executes repeated requests and checks that responses are identical (deterministic). An optional cache reset phase re-validates determinism across the reset boundary. - The test performs a comprehensive warmup across prompts, then executes repeated requests and checks that responses are identical (deterministic). An optional cache reset phase re-validates determinism across the reset boundary.
...@@ -43,8 +43,8 @@ Environment variables control server settings and test load: ...@@ -43,8 +43,8 @@ Environment variables control server settings and test load:
- Server/model - Server/model
- `KVBM_MODEL_ID` (default: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`) - `KVBM_MODEL_ID` (default: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`)
- `KVBM_VLLM_PORT` (default: `8000`) - `KVBM_SERVER_PORT` (default: `8000`)
- `KVBM_VLLM_START_TIMEOUT` (default: `300` seconds) - `KVBM_SERVER_START_TIMEOUT` (default: `300` seconds)
- Cache size overrides - Cache size overrides
- `KVBM_CPU_BLOCKS` (used via test parametrization; default: `10000`) - `KVBM_CPU_BLOCKS` (used via test parametrization; default: `10000`)
...@@ -90,5 +90,5 @@ pytest -v -m "kvbm" -s ...@@ -90,5 +90,5 @@ pytest -v -m "kvbm" -s
- Warmup is critical to avoid initialization effects impacting determinism. - Warmup is critical to avoid initialization effects impacting determinism.
- For faster local iteration, reduce `KVBM_MAX_ITERATIONS` and/or increase intervals. - For faster local iteration, reduce `KVBM_MAX_ITERATIONS` and/or increase intervals.
- Logs are written under the per-test directory created by `tests/conftest.py` and include the vLLM server stdout/stderr. - Logs are written under the per-test directory created by `tests/conftest.py` and include the LLM server stdout/stderr.
- Tests use the static port defined by `KVBM_VLLM_PORT` for vLLM server communication. - Tests use the static port defined by `KVBM_SERVER_PORT` for LLM server communication.
\ No newline at end of file
...@@ -13,6 +13,7 @@ before validation) to avoid server initialization effects that could ...@@ -13,6 +13,7 @@ before validation) to avoid server initialization effects that could
impact determinism measurements. impact determinism measurements.
""" """
import importlib.util
import logging import logging
import os import os
import signal import signal
...@@ -21,8 +22,9 @@ import time ...@@ -21,8 +22,9 @@ import time
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime from datetime import datetime
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, TextIO, Tuple from typing import Any, Dict, List, Optional, TextIO, Tuple
import pytest import pytest
import requests import requests
...@@ -38,8 +40,13 @@ pytestmark = [ ...@@ -38,8 +40,13 @@ pytestmark = [
] ]
class VLLMServerManager: class ServerType(str, Enum):
"""Manages vLLM server lifecycle for determinism testing.""" vllm = "vllm"
trtllm = "trtllm"
class LLMServerManager:
"""Manages LLM server lifecycle for determinism testing."""
def __init__( def __init__(
self, self,
...@@ -48,8 +55,10 @@ class VLLMServerManager: ...@@ -48,8 +55,10 @@ class VLLMServerManager:
cpu_cache_blocks: Optional[int] = None, cpu_cache_blocks: Optional[int] = None,
gpu_cache_blocks: Optional[int] = None, gpu_cache_blocks: Optional[int] = None,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
server_type: Optional[str] = ServerType.vllm,
): ):
self.port = port or int(os.environ.get("KVBM_VLLM_PORT", "8000")) self.server_type = server_type
self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
self.base_url = base_url or f"http://localhost:{self.port}" self.base_url = base_url or f"http://localhost:{self.port}"
self.process: Optional[subprocess.Popen] = None self.process: Optional[subprocess.Popen] = None
self.cpu_cache_blocks = cpu_cache_blocks self.cpu_cache_blocks = cpu_cache_blocks
...@@ -63,11 +72,41 @@ class VLLMServerManager: ...@@ -63,11 +72,41 @@ class VLLMServerManager:
f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}" f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
) )
self.server_log_file = ( self.server_log_file = (
self.log_dir / f"vllm_server_{config_str}_{timestamp}.log" self.log_dir / f"{self.server_type}_server_{config_str}_{timestamp}.log"
) )
self.server_stdout_file: Optional[TextIO] = None self.server_stdout_file: Optional[TextIO] = None
self.server_stderr_file: Optional[TextIO] = None self.server_stderr_file: Optional[TextIO] = None
# Environment for the process
self.env = os.environ.copy()
self.env.update(
{
"RUST_BACKTRACE": "1",
"DYN_LOG": os.environ.get(
"DYN_LOG", "debug,dynamo_llm::block_manager::layout=error"
),
# DynamoConnector connection settings
"NATS_SERVER": "nats://localhost:4222",
"ETCD_ENDPOINTS": "http://localhost:2379",
}
)
# CPU cache blocks override via env
if cpu_cache_blocks is not None:
self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)
if self.server_type == ServerType.vllm:
self._set_up_vllm_config(gpu_cache_blocks)
elif self.server_type == ServerType.trtllm:
self._set_up_trtllm_config(gpu_cache_blocks)
else:
raise ValueError(
f"{self.server_type} is not supported yet in the KVBM test suite"
)
def _set_up_vllm_config(self, gpu_cache_blocks):
self.env["VLLM_SERVER_DEV_MODE"] = "1"
# Construct serve command # Construct serve command
self.server_cmd = [ self.server_cmd = [
"vllm", "vllm",
...@@ -85,27 +124,52 @@ class VLLMServerManager: ...@@ -85,27 +124,52 @@ class VLLMServerManager:
if gpu_cache_blocks is not None: if gpu_cache_blocks is not None:
self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)]) self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)])
# Environment for the process def _set_up_trtllm_config(self, gpu_cache_blocks):
self.env = os.environ.copy() config_path = os.environ.get(
self.env.update( "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml"
{
"RUST_BACKTRACE": "1",
"DYN_LOG": os.environ.get(
"DYN_LOG", "debug,dynamo_llm::block_manager::layout=error"
),
"VLLM_SERVER_DEV_MODE": "1",
# DynamoConnector connection settings
"NATS_SERVER": "nats://localhost:4222",
"ETCD_ENDPOINTS": "http://localhost:2379",
}
) )
llm_api_config: dict[str, Any] = {}
llm_api_config[
"cuda_graph_config"
] = None # explicitly disable CUDA graph since Connector API doesn't support CUDA graph yet in TRTLLM
llm_api_config["kv_cache_config"] = {
"enable_partial_reuse": False,
"free_gpu_memory_fraction": 0.10, # Set a small GPU fraction so that we can evict/reset the on-device kv cache faster
}
llm_api_config["kv_connector_config"] = {
"connector_module": "dynamo.llm.trtllm_integration.connector",
"connector_scheduler_class": "DynamoKVBMConnectorLeader",
"connector_worker_class": "DynamoKVBMConnectorWorker",
}
# CPU cache blocks override via env # GPU blocks override
if cpu_cache_blocks is not None: if gpu_cache_blocks is not None:
self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks) del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"]
llm_api_config["kv_cache_config"]["max_tokens"] = (
int(gpu_cache_blocks) * 32
) # TRTLLM defaults 32 tokens per block
# Construct serve command
self.server_cmd = [
"trtllm-serve",
os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
"--host",
"localhost",
"--port",
str(self.port),
"--backend",
"pytorch",
"--extra_llm_api_options",
config_path,
]
import yaml
with open(config_path, "w") as f:
yaml.dump(llm_api_config, f, default_flow_style=False, sort_keys=False)
def start_server(self, timeout: int = 300) -> bool: def start_server(self, timeout: int = 300) -> bool:
"""Start vLLM server and wait for readiness.""" """Start LLM server and wait for readiness."""
if self.is_server_running(): if self.is_server_running():
self.stop_server() self.stop_server()
time.sleep(2) time.sleep(2)
...@@ -119,7 +183,7 @@ class VLLMServerManager: ...@@ -119,7 +183,7 @@ class VLLMServerManager:
) )
if self.server_stdout_file is not None: if self.server_stdout_file is not None:
self.server_stdout_file.write( self.server_stdout_file.write(
f"=== vLLM Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n" f"=== {self.server_type} Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n"
) )
self.server_stdout_file.flush() self.server_stdout_file.flush()
...@@ -147,7 +211,7 @@ class VLLMServerManager: ...@@ -147,7 +211,7 @@ class VLLMServerManager:
return False return False
def stop_server(self): def stop_server(self):
"""Stop vLLM server and close logs.""" """Stop LLM server and close logs."""
if self.process: if self.process:
try: try:
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
...@@ -205,7 +269,12 @@ class VLLMServerManager: ...@@ -205,7 +269,12 @@ class VLLMServerManager:
class DeterminismTester: class DeterminismTester:
"""Test class for model determinism validation.""" """Test class for model determinism validation."""
def __init__(self, base_url: Optional[str] = None, model_id: Optional[str] = None): def __init__(
self,
base_url: Optional[str] = None,
model_id: Optional[str] = None,
server_type: Optional[str] = ServerType.vllm,
):
# Allow environment override for flexibility in CI/local runs # Allow environment override for flexibility in CI/local runs
self.base_url = ( self.base_url = (
base_url or os.environ.get("DYNAMO_API_BASE_URL") or "http://localhost:8000" base_url or os.environ.get("DYNAMO_API_BASE_URL") or "http://localhost:8000"
...@@ -215,6 +284,7 @@ class DeterminismTester: ...@@ -215,6 +284,7 @@ class DeterminismTester:
or os.environ.get("KVBM_MODEL_ID") or os.environ.get("KVBM_MODEL_ID")
or "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" or "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
) )
self.server_type = server_type
self.shakespeare_file = Path("t8.shakespeare.txt") self.shakespeare_file = Path("t8.shakespeare.txt")
self.max_iterations = int(os.environ.get("KVBM_MAX_ITERATIONS", "500")) self.max_iterations = int(os.environ.get("KVBM_MAX_ITERATIONS", "500"))
...@@ -298,11 +368,28 @@ class DeterminismTester: ...@@ -298,11 +368,28 @@ class DeterminismTester:
def reset_prefix_cache(self): def reset_prefix_cache(self):
"""Reset the prefix cache.""" """Reset the prefix cache."""
print("Resetting prefix cache...") print("Resetting prefix cache...")
response = requests.post( if self.server_type == ServerType.trtllm:
f"{self.base_url}/reset_prefix_cache", # TRTLLM doesn't support reset_prefix_cache endpoint API
timeout=int(os.environ.get("KVBM_HTTP_TIMEOUT", "30")), # 300 shakespeare content could evict the 0.1 x 80G (~1700 blocks) on-device cache
) shakespeare_count = 300
response.raise_for_status() for seq_idx in range(1, shakespeare_count + 1):
start_word = (seq_idx - 1) * self.word_count
content = self.get_shakespeare_content(start_word)
if content:
print(
f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
)
try:
self.make_request(content)
except Exception as e:
print(f"Resetting request failed: {e}")
else:
response = requests.post(
f"{self.base_url}/reset_prefix_cache",
timeout=int(os.environ.get("KVBM_HTTP_TIMEOUT", "30")),
)
response.raise_for_status()
print("Cache reset done") print("Cache reset done")
def warmup_server(self): def warmup_server(self):
...@@ -623,11 +710,11 @@ class DeterminismTester: ...@@ -623,11 +710,11 @@ class DeterminismTester:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def vllm_server(request, runtime_services): def llm_server(request, runtime_services):
"""Start and stop vLLM server for each test with optional cache block overrides. """Start and stop a LLM server for each test with optional cache block overrides.
To parametrize, use: To parametrize, use:
@pytest.mark.parametrize("vllm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True) @pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True)
""" """
logger = logging.getLogger("pytest") logger = logging.getLogger("pytest")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -639,17 +726,27 @@ def vllm_server(request, runtime_services): ...@@ -639,17 +726,27 @@ def vllm_server(request, runtime_services):
# Put logs in the per-test directory set up by tests/conftest.py # Put logs in the per-test directory set up by tests/conftest.py
log_dir = Path(request.node.name) log_dir = Path(request.node.name)
server_manager = VLLMServerManager( if importlib.util.find_spec("vllm") is not None:
server_type = ServerType.vllm
elif importlib.util.find_spec("tensorrt_llm") is not None:
server_type = ServerType.trtllm
else:
raise Exception(
"Neither the vllm nor the tensorrt_llm module is available in the current environment."
)
server_manager = LLMServerManager(
port=port, port=port,
cpu_cache_blocks=cpu_blocks, cpu_cache_blocks=cpu_blocks,
gpu_cache_blocks=gpu_blocks, gpu_cache_blocks=gpu_blocks,
log_dir=log_dir, log_dir=log_dir,
server_type=server_type,
) )
start_timeout = int(os.environ.get("KVBM_VLLM_START_TIMEOUT", "300")) start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout): if not server_manager.start_server(timeout=start_timeout):
pytest.fail( pytest.fail(
f"Failed to start vLLM server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
) )
yield server_manager yield server_manager
...@@ -658,9 +755,11 @@ def vllm_server(request, runtime_services): ...@@ -658,9 +755,11 @@ def vllm_server(request, runtime_services):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def tester(vllm_server): def tester(llm_server):
"""Create determinism tester bound to the running server's base URL.""" """Create determinism tester bound to the running server's base URL."""
t = DeterminismTester(base_url=vllm_server.base_url) t = DeterminismTester(
base_url=llm_server.base_url, server_type=llm_server.server_type
)
t.download_shakespeare_text() t.download_shakespeare_text()
return t return t
...@@ -669,13 +768,13 @@ class TestDeterminism: ...@@ -669,13 +768,13 @@ class TestDeterminism:
"""Test class for determinism validation.""" """Test class for determinism validation."""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"vllm_server", "llm_server",
[ [
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))}, {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
], ],
indirect=True, indirect=True,
) )
def test_determinism_with_cache_reset(self, tester, vllm_server, runtime_services): def test_determinism_with_cache_reset(self, tester, llm_server, runtime_services):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup.""" """Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
print("\n" + "=" * 70) print("\n" + "=" * 70)
print("STARTING DETERMINISM TEST (WITH CACHE RESET)") print("STARTING DETERMINISM TEST (WITH CACHE RESET)")
...@@ -797,7 +896,7 @@ class TestDeterminism: ...@@ -797,7 +896,7 @@ class TestDeterminism:
), f"Model is not deterministic across cache reset: {total_failed} comparisons failed" ), f"Model is not deterministic across cache reset: {total_failed} comparisons failed"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"vllm_server", "llm_server",
[ [
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))}, {"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))},
], ],
...@@ -818,7 +917,7 @@ class TestDeterminism: ...@@ -818,7 +917,7 @@ class TestDeterminism:
def test_concurrent_determinism_with_ifeval( def test_concurrent_determinism_with_ifeval(
self, self,
tester, tester,
vllm_server, llm_server,
runtime_services, runtime_services,
num_concurrent, num_concurrent,
max_tokens, max_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment