Unverified Commit a68c2f8f authored by Richard Huo's avatar Richard Huo Committed by GitHub
Browse files

feat: DIS-373 dynamo KVBM connector API integration with TRTLLM (#2544)


Signed-off-by: default avatarrichardhuo-nv <rihuo@nvidia.com>
parent 43a26958
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Loader for the Rust-based TensorRT-LLM integration objects, using objects from _vllm_integration for now
"""
# Import the Rust-backed KVBM classes from the extension module. When the
# extension cannot be imported, every exported name is bound to None so that
# importing this loader never raises; callers must check for None before use.
try:
    # TODO: use TRTLLM own integration module
    from dynamo._core import _vllm_integration

    # Runtime - dynamically loaded classes from Rust extension
    KvbmRequest = getattr(_vllm_integration, "KvbmRequest")
    KvbmBlockList = getattr(_vllm_integration, "KvbmBlockList")
    BlockState = getattr(_vllm_integration, "BlockState")
    BlockStates = getattr(_vllm_integration, "BlockStates")
    SlotUpdate = getattr(_vllm_integration, "SlotUpdate")
    # The TRT-LLM specific connector classes are re-exported under generic names.
    KvConnectorWorker = getattr(_vllm_integration, "PyTrtllmKvConnectorWorker")
    KvConnectorLeader = getattr(_vllm_integration, "PyTrtllmKvConnectorLeader")
    SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput")
except ImportError:
    print(
        "Failed to import Dynamo KVBM. TensorRT-LLM integration will not be available."
    )
    KvbmRequest = None
    KvbmBlockList = None
    BlockState = None
    BlockStates = None
    SlotUpdate = None
    KvConnectorWorker = None
    KvConnectorLeader = None
    SchedulerOutput = None

__all__ = [
    "KvbmRequest",
    "KvbmBlockList",
    "BlockState",
    "BlockStates",
    "SlotUpdate",
    "KvConnectorWorker",
    "KvConnectorLeader",
    "SchedulerOutput",
]
......@@ -14,7 +14,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadat
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.worker.cache_engine import CacheEngine
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -29,7 +28,7 @@ if TYPE_CHECKING:
# )
# from dynamo.llm.vllm_integration.rust import SchedulerOutput as RustSchedulerOutput
from dynamo.llm import BlockManager, KvbmLeader
from dynamo.llm import KvbmLeader
from dynamo.llm.vllm_integration.kv_cache_utils import (
find_and_set_available_port_from_env,
)
......@@ -64,25 +63,12 @@ class KvConnectorLeader:
self.vllm_config = vllm_config
world_size = vllm_config.parallel_config.world_size
bytes_per_block = CacheEngine.get_cache_block_size(
vllm_config.cache_config,
vllm_config.model_config,
vllm_config.parallel_config,
)
total_bytes = bytes_per_block * world_size
leader = KvbmLeader(total_bytes, world_size, drt=self.drt)
block_manager = BlockManager(
0,
leader,
vllm_config.cache_config.block_size,
disable_device_pool=True,
)
leader = KvbmLeader(world_size, drt=self.drt)
print(f"KvConnectorLeader initialized with engine_id: {engine_id}")
self._connector = RustKvConnectorLeader(
engine_id, self.drt, block_manager, leader
engine_id, self.drt, vllm_config.cache_config.block_size, leader
)
# KV Connector
......
......@@ -8,7 +8,7 @@ mod zmq;
mod leader;
mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig};
pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::BlockTransferHandler;
pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
......@@ -130,21 +130,31 @@ mod tests {
vec![Arc::new(MockTensor::new(vec![2, NUM_BLOCKS, 4096]))];
let config = KvbmWorkerConfig::builder()
.barrier_id(barrier_id.clone())
.barrier_id_prefix(barrier_id.clone())
.num_device_blocks(NUM_BLOCKS)
.tensors(tensors)
.worker_id(i)
.device_id(i)
.build()?;
let worker = KvbmWorker::new(config).await?;
let worker = KvbmWorker::new(config, false).await?;
workers.push(worker);
}
let host_blocks = KvbmLeaderNumBlocksConfig {
cache_size_in_gb: 1.0,
num_blocks_overriden: NUM_BLOCKS,
};
let disk_blocks = KvbmLeaderNumBlocksConfig {
cache_size_in_gb: 1.0,
num_blocks_overriden: NUM_BLOCKS,
};
let leader_config = KvbmLeaderConfig::builder()
.barrier_id(barrier_id)
.barrier_id_prefix(barrier_id)
.world_size(num_workers)
.num_host_blocks(NUM_BLOCKS)
.num_disk_blocks(NUM_BLOCKS)
.host_blocks_config(host_blocks)
.disk_blocks_config(disk_blocks)
.build()?;
// When/if this returns, we know that all the workers were also successful.
......
......@@ -9,12 +9,16 @@ use zmq::*;
use dynamo_runtime::utils::leader_worker_barrier::LeaderBarrier;
use anyhow::{Context, anyhow};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tokio::sync::OnceCell;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tokio::time::sleep;
/// Data that is sent to workers over ETCD to establish a ZMQ connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -25,17 +29,31 @@ pub struct KvbmLeaderData {
pub num_disk_blocks: usize,
}
#[derive(Builder, Clone, Debug)]
pub struct KvbmLeaderConfig {
#[builder(default = "0")]
num_host_blocks: usize,
#[derive(Builder, Clone, Debug, Default)]
pub struct KvbmLeaderNumBlocksConfig {
#[builder(default = "0.0")]
pub cache_size_in_gb: f64,
#[builder(default = "0")]
num_disk_blocks: usize,
pub num_blocks_overriden: usize,
}
/// Resolves how many blocks a cache tier should hold.
///
/// An explicit `num_blocks_overriden` greater than zero always wins. Otherwise
/// the count is derived from `cache_size_in_gb` (decimal gigabytes, i.e.
/// `1e9` bytes each) divided by `bytes_per_block`, truncated toward zero.
///
/// Returns `0` when no override is set and `bytes_per_block` is `0`: dividing
/// by zero would otherwise produce an infinite quotient that saturates to
/// `usize::MAX` when cast, i.e. an absurd block count.
fn compute_num_blocks(
    num_blocks_config: &KvbmLeaderNumBlocksConfig,
    bytes_per_block: usize,
) -> usize {
    if num_blocks_config.num_blocks_overriden > 0 {
        return num_blocks_config.num_blocks_overriden;
    }

    // Guard the division below; a zero-sized block cannot be allocated anyway.
    if bytes_per_block == 0 {
        return 0;
    }

    let cache_size_in_bytes = num_blocks_config.cache_size_in_gb * 1_000_000_000.0;
    // Float-to-int `as` saturates, so negative or NaN sizes collapse to 0.
    (cache_size_in_bytes / bytes_per_block as f64) as usize
}
#[derive(Builder, Clone, Debug)]
pub struct KvbmLeaderConfig {
/// The barrier id to use for syncing with workers.
#[builder(default = "String::from(\"kvbm\")")]
barrier_id: String,
barrier_id_prefix: String,
/// The world size.
#[builder(default = "1")]
......@@ -47,6 +65,12 @@ pub struct KvbmLeaderConfig {
#[builder(setter(strip_option))]
drt: Option<DistributedRuntime>,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
host_blocks_config: KvbmLeaderNumBlocksConfig,
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
disk_blocks_config: KvbmLeaderNumBlocksConfig,
}
impl KvbmLeaderConfig {
......@@ -55,6 +79,15 @@ impl KvbmLeaderConfig {
}
}
/// State shared between the leader and its background initialization tasks.
///
/// All fields are reference-counted so the spawned tasks can hold their own
/// handles. With `Default`, every counter starts at 0 and the flag at `false`.
#[derive(Debug, Default)]
pub struct KvbmLeaderState {
/// Device block count, written back by the barrier task after a successful barrier sync.
pub num_device_blocks: Arc<AtomicUsize>,
/// Host block count, written back by the barrier task after a successful barrier sync.
pub num_host_blocks: Arc<AtomicUsize>,
/// Disk block count, written back by the barrier task after a successful barrier sync.
pub num_disk_blocks: Arc<AtomicUsize>,
/// Set to `true` by the ZMQ task once the ZMQ leader has been created.
pub workers_allocation_ready: Arc<AtomicBool>,
/// Waiters are notified right after `workers_allocation_ready` is set.
pub workers_ready_notify: Arc<Notify>,
}
/// The leader of the KVBM.
///
/// This is responsible for:
......@@ -62,9 +95,13 @@ impl KvbmLeaderConfig {
/// - Syncing the leader barrier with workers.
/// - Sending messages to workers.
pub struct KvbmLeader {
num_device_blocks: usize,
zmq_leader: ZmqActiveMessageLeader,
state: Arc<KvbmLeaderState>,
zmq_leader: Arc<OnceCell<ZmqActiveMessageLeader>>,
config: KvbmLeaderConfig,
//readiness flags
workers_sync_ready: Arc<AtomicBool>,
workers_sync_ready_notify: Arc<Notify>,
workers_sync_done: Arc<AtomicBool>,
}
impl KvbmLeader {
......@@ -76,34 +113,90 @@ impl KvbmLeader {
}
};
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
config.world_size,
config.barrier_id
let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?;
let leader = Self {
state: Arc::new(KvbmLeaderState::default()),
zmq_leader: Arc::new(tokio::sync::OnceCell::new()),
config,
workers_sync_ready: Arc::new(AtomicBool::new(false)),
workers_sync_ready_notify: Arc::new(Notify::new()),
workers_sync_done: Arc::new(AtomicBool::new(false)),
};
let cancel_token = tokio_util::sync::CancellationToken::new();
// The leader_sockets struct cannot be cloned,
// so we use a tuple to "struct" the two urls
let leader_urls = (
leader_sockets.pub_url.clone(),
leader_sockets.ack_url.clone(),
);
leader.spawn_barrier_task(drt, leader_urls);
leader.spawn_zmq_task(leader_sockets, cancel_token);
let leader_sockets = new_leader_sockets("tcp://127.0.0.1")?;
Ok(leader)
}
fn spawn_barrier_task(&self, drt: DistributedRuntime, leader_urls: (String, String)) {
let state = self.state.clone();
let leader_config = self.config.clone();
let ready = Arc::clone(&self.workers_sync_ready);
let notify = Arc::clone(&self.workers_sync_ready_notify);
let done = Arc::clone(&self.workers_sync_done);
let zmq_data = Arc::new(KvbmLeaderData {
pub_url: leader_sockets.pub_url.clone(),
ack_url: leader_sockets.ack_url.clone(),
num_host_blocks: config.num_host_blocks,
num_disk_blocks: config.num_disk_blocks,
tokio::spawn(async move {
match KvbmLeader::run_barrier_sync(drt, leader_urls, leader_config).await {
Ok((num_device_blocks, num_host_blocks, num_disk_blocks)) => {
// write back results
state
.num_device_blocks
.store(num_device_blocks, Ordering::Release);
state
.num_host_blocks
.store(num_host_blocks, Ordering::Release);
state
.num_disk_blocks
.store(num_disk_blocks, Ordering::Release);
ready.store(true, Ordering::Release);
done.store(true, Ordering::Release);
notify.notify_waiters();
}
Err(e) => {
tracing::error!("Barrier sync failed: {e:?}");
done.store(true, Ordering::Release);
notify.notify_waiters();
}
}
});
}
async fn run_barrier_sync(
drt: DistributedRuntime,
leader_urls: (String, String),
leader_config: KvbmLeaderConfig,
) -> anyhow::Result<(usize, usize, usize)> {
let barrier_id_worker_to_leader =
format!("{}{}", leader_config.barrier_id_prefix, "-worker-to-leader");
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_worker_to_leader
);
// Build our leader barrier and publish the data.
// TODO: Use a separate timeout parameter from the ZMQ connection timeout
let leader_barrier: LeaderBarrier<KvbmLeaderData, worker::KvbmWorkerData> =
let worker_to_leader_barrier: LeaderBarrier<(), worker::KvbmWorkerData> =
LeaderBarrier::new(
config.barrier_id.clone(),
config.world_size,
Some(Duration::from_secs(config.leader_init_timeout_secs)),
barrier_id_worker_to_leader.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let worker_data = leader_barrier
.sync(&drt, zmq_data.as_ref())
let worker_data = worker_to_leader_barrier
.sync(&drt, &())
.await
.map_err(|e| anyhow::anyhow!("Failed to sync leader barrier: {:?}", e))?;
.map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
let num_device_blocks = worker_data
.values()
......@@ -111,46 +204,250 @@ impl KvbmLeader {
.min()
.unwrap();
tracing::info!("Leader barrier synced with {} workers", config.world_size);
// TODO: this works for TP, need to redefine bytes_per_block when we enable the DP/PP
let bytes_per_block: usize = worker_data.values().map(|d| d.bytes_per_block).sum();
assert!(
bytes_per_block > 0,
"bytes_per_block must be greater than 0"
);
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
tracing::debug!("Worker data: {:?}", worker_data);
// Now, create our active message leader.
// This also blocks until a ZMQ connection has been established.
let cancel_token = CancellationToken::new();
let zmq_leader = ZmqActiveMessageLeader::new(
leader_sockets,
config.world_size,
Duration::from_secs(config.leader_init_timeout_secs),
cancel_token.clone(),
)
.await?;
Ok(Self {
num_device_blocks,
zmq_leader,
config,
let num_host_blocks =
compute_num_blocks(&leader_config.host_blocks_config, bytes_per_block);
let num_disk_blocks =
compute_num_blocks(&leader_config.disk_blocks_config, bytes_per_block);
// Start the second sync to transfer num_host_blocks and num_disk_blocks to worker
let barrier_id_leader_to_worker =
format!("{}{}", leader_config.barrier_id_prefix, "-leader-to-worker");
tracing::info!(
"Syncing leader barrier with {} workers on barrier id {}",
leader_config.world_size,
barrier_id_leader_to_worker
);
let (leader_pub_url, leader_ack_url) = leader_urls;
let zmq_data_leader_to_worker = Arc::new(KvbmLeaderData {
pub_url: leader_pub_url,
ack_url: leader_ack_url,
num_host_blocks,
num_disk_blocks,
});
let leader_to_worker_barrier: LeaderBarrier<KvbmLeaderData, ()> = LeaderBarrier::new(
barrier_id_leader_to_worker.clone(),
leader_config.world_size,
Some(Duration::from_secs(leader_config.leader_init_timeout_secs)),
);
let _worker_data = leader_to_worker_barrier
.sync(&drt, zmq_data_leader_to_worker.as_ref())
.await
.map_err(|e| anyhow::anyhow!("Failed to sync leader to worker barrier: {:?}", e))?;
tracing::info!(
"Worker to leader barrier synced with {} workers",
leader_config.world_size
);
Ok((num_device_blocks, num_host_blocks, num_disk_blocks))
}
fn spawn_zmq_task(
&self,
leader_sockets: LeaderSockets,
cancel: tokio_util::sync::CancellationToken,
) {
let cell = self.zmq_leader.clone();
let state = self.state.clone();
let world_size = self.config.world_size;
let timeout = self.config.leader_init_timeout_secs;
tokio::spawn(async move {
let res = ZmqActiveMessageLeader::new(
leader_sockets,
world_size,
std::time::Duration::from_secs(timeout),
cancel,
)
.await;
match res {
Ok(zmq) => {
let _ = cell.set(zmq);
// mark ready
state
.workers_allocation_ready
.store(true, Ordering::Release);
state.workers_ready_notify.notify_waiters();
}
Err(e) => {
tracing::error!("ZMQ init failed: {e:?}");
}
}
});
}
/// Spawns the leader readiness barrier on the runtime's primary handle
/// without blocking the caller.
///
/// This is supposed to be used in non-blocking leader initialization. The
/// spawned task waits, for at most `leader_init_timeout_secs`, until the ZMQ
/// task raises `workers_allocation_ready`, then runs the leader readiness
/// barrier. Nothing is returned to the caller; every failure is only logged.
pub fn spawn_leader_readiness_barrier(&self, drt: DistributedRuntime) {
    let timeout_secs = self.config.leader_init_timeout_secs;
    let state = self.state.clone();
    let leader_config = self.config.clone();
    let handle = drt.runtime().primary();

    handle.spawn(async move {
        // Create the future *before* checking the flag to avoid a lost-notify race:
        // `notify_waiters()` only wakes futures that already exist, so a notification
        // landing between the flag check and the future's creation would be missed
        // and this task would wait out the whole timeout.
        let notified = state.workers_ready_notify.notified();

        if !state.workers_allocation_ready.load(Ordering::Acquire) {
            // Wait until ZMQ marks ready or we time out.
            let waited =
                tokio::time::timeout(Duration::from_secs(timeout_secs), notified).await;
            if waited.is_err() {
                tracing::error!(
                    "leader readiness barrier wait timed out after {timeout_secs} seconds"
                );
                return;
            }

            // Double-check the flag (Acquire) after wakeup.
            if !state.workers_allocation_ready.load(Ordering::Acquire) {
                tracing::error!("leader readiness notify fired but flag not set; aborting");
                return;
            }
        }

        match KvbmLeader::run_leader_readiness(drt, leader_config).await {
            Ok(()) => {
                tracing::info!("leader readiness barrier synced!");
            }
            Err(e) => {
                tracing::error!("leader readiness barrier failed: {e:?}");
            }
        }
    });
}
/// Runs the leader readiness barrier and blocks the calling thread until it
/// completes.
///
/// This is supposed to be used in blocking leader initialization. It first
/// waits, for at most `leader_init_timeout_secs`, until the ZMQ task raises
/// `workers_allocation_ready`, then runs the leader readiness barrier.
///
/// # Errors
/// Returns an error (with the context "leader readiness barrier failed") if
/// the wait times out, if the wake-up arrives without the flag being set, or
/// if the barrier itself fails.
pub fn run_leader_readiness_barrier_blocking(
    &self,
    drt: DistributedRuntime,
) -> anyhow::Result<()> {
    let shared = self.state.clone();
    let timeout_secs = self.config.leader_init_timeout_secs;
    let leader_config = self.config.clone();

    // Nothing runs until `block_on` polls this future below.
    let wait_then_sync = async move {
        // Create the future *before* checking the flag to avoid a lost-notify race.
        let ready_signal = shared.workers_ready_notify.notified();

        let already_ready = shared.workers_allocation_ready.load(Ordering::Acquire);
        if !already_ready {
            // Wait (with timeout) until ZMQ task marks ready.
            tokio::time::timeout(Duration::from_secs(timeout_secs), ready_signal)
                .await
                .map_err(|_| anyhow!("timed out waiting for workers_allocation_ready after {timeout_secs} seconds"))?;

            // Double-check after wake to ensure the flag is actually set.
            let ready_after_wake = shared.workers_allocation_ready.load(Ordering::Acquire);
            if !ready_after_wake {
                return Err(anyhow!(
                    "notified but workers_allocation_ready is still false"
                ));
            }
        }

        KvbmLeader::run_leader_readiness(drt, leader_config).await
    };

    let outcome = tokio::task::block_in_place(move || {
        tokio::runtime::Handle::current().block_on(wait_then_sync)
    });
    outcome.context("leader readiness barrier failed")
}
/// Syncs the `<barrier_id_prefix>-leader-ready` barrier from the leader side.
///
/// The barrier carries no payload in either direction; it is built with the
/// configured world size and `leader_init_timeout_secs` as its timeout.
///
/// # Errors
/// Returns an error if the barrier sync fails.
async fn run_leader_readiness(
    drt: DistributedRuntime,
    leader_config: KvbmLeaderConfig,
) -> anyhow::Result<()> {
    let ready_barrier_id = format!("{}-leader-ready", leader_config.barrier_id_prefix);

    tracing::info!(
        "Syncing leader readiness barrier with {} workers on barrier id {}",
        leader_config.world_size,
        ready_barrier_id
    );

    let sync_timeout = Duration::from_secs(leader_config.leader_init_timeout_secs);
    let readiness_barrier: LeaderBarrier<(), ()> = LeaderBarrier::new(
        ready_barrier_id,
        leader_config.world_size,
        Some(sync_timeout),
    );

    // The data returned by the workers is irrelevant here; only completion matters.
    let sync_outcome = readiness_barrier.sync(&drt, &()).await;
    let _ = sync_outcome.map_err(|e| {
        anyhow::anyhow!("Failed to sync leader readiness barrier on leader: {:?}", e)
    })?;

    Ok(())
}
pub async fn transfer_blocks_request(
&self,
request: BlockTransferRequest,
) -> anyhow::Result<oneshot::Receiver<()>> {
let zmq = self
.zmq_leader
.get()
.ok_or_else(|| anyhow::anyhow!("ZMQ leader not ready"))?;
let data = vec![serde_json::to_vec(&request)?];
self.zmq_leader
.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data)
.await
zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await
}
pub fn is_worker_sync_ready(&self) -> bool {
self.workers_sync_ready.load(Ordering::Acquire)
}
pub fn is_worker_sync_done(&self) -> bool {
self.workers_sync_done.load(Ordering::Acquire)
}
pub fn num_device_blocks(&self) -> usize {
self.num_device_blocks
self.state.num_device_blocks.load(Ordering::Acquire)
}
pub fn num_host_blocks(&self) -> usize {
self.config.num_host_blocks
self.state.num_host_blocks.load(Ordering::Acquire)
}
pub fn num_disk_blocks(&self) -> usize {
self.config.num_disk_blocks
self.state.num_disk_blocks.load(Ordering::Acquire)
}
pub async fn wait_worker_sync_ready(&self) -> bool {
if self.is_worker_sync_ready() {
return true;
}
if self.is_worker_sync_done() {
return false;
}
let notified = self.workers_sync_ready_notify.notified();
if self.is_worker_sync_ready() {
return true;
}
if self.is_worker_sync_done() {
return false;
}
// bounded wait
tokio::select! {
_ = notified => {
self.is_worker_sync_ready()
}
_ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false,
}
}
}
......@@ -35,6 +35,7 @@ use dynamo_runtime::{
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KvbmWorkerData {
pub num_device_blocks: usize,
pub bytes_per_block: usize,
}
pub fn load_and_validate_tensors(
......@@ -82,7 +83,7 @@ pub fn load_and_validate_tensors(
Ok((device_tensors, shape.unwrap()))
}
#[derive(Builder)]
#[derive(Builder, Clone)]
#[builder(pattern = "owned")]
pub struct KvbmWorkerConfig {
drt: DistributedRuntime,
......@@ -101,8 +102,11 @@ pub struct KvbmWorkerConfig {
#[builder(default = "2")]
dtype_width_bytes: usize,
#[builder(default = false)]
is_fully_contiguous_layout: bool,
#[builder(default = "String::from(\"kvbm\")")]
barrier_id: String,
barrier_id_prefix: String,
#[builder(default = "None")]
scheduler_client: Option<TransferSchedulerClient>,
......@@ -132,7 +136,7 @@ pub struct KvbmWorker {
}
impl KvbmWorker {
pub async fn new(config: KvbmWorkerConfig) -> anyhow::Result<Self> {
pub async fn new(config: KvbmWorkerConfig, layout_blocking: bool) -> anyhow::Result<Self> {
tracing::info!(
"Initializing KvbmWorker with params: num_device_blocks={}, page_size={}, dtype_width_bytes={}",
config.num_device_blocks,
......@@ -153,62 +157,147 @@ impl KvbmWorker {
)));
}
let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks {
(false, shape[1])
} else if shape[1] >= config.num_device_blocks {
(true, shape[0])
let (layout_type, num_layers, outer_dim, inner_dim) = if !config.is_fully_contiguous_layout
{
let (outer_contiguous, outer_dim) = if shape[0] >= config.num_device_blocks {
(false, shape[1])
} else if shape[1] >= config.num_device_blocks {
(true, shape[0])
} else {
return Err(anyhow::anyhow!(format!(
"Unsupported kv cache layout. Got shape: {:?}",
shape
)));
};
let num_layers = device_tensors.len();
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
device_tensors.len(),
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::LayerSeparate { outer_contiguous },
num_layers,
outer_dim,
inner_dim,
)
} else {
return Err(anyhow::anyhow!(format!(
"Unsupported kv cache layout. Got shape: {:?}",
shape
)));
let num_layers = shape[1];
let outer_dim = shape[2];
let inner_dim = shape[3..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
num_layers,
outer_dim,
config.page_size,
inner_dim
);
(
LayoutType::FullyContiguous,
num_layers,
outer_dim,
inner_dim,
)
};
let inner_dim = shape[2..].iter().product::<usize>() / config.page_size;
tracing::info!(
"Inferred layout: num_layers={}, outer_dim={}, page_size={}, inner_dim={}",
device_tensors.len(),
outer_dim,
config.page_size,
inner_dim
);
let bytes_per_block =
num_layers * outer_dim * config.page_size * inner_dim * config.dtype_width_bytes;
let mut layout_builder_instance = LayoutConfigBuilder::default();
let layout_builder = layout_builder_instance
.num_layers(device_tensors.len())
.num_layers(num_layers)
.outer_dim(outer_dim)
.page_size(config.page_size)
.inner_dim(inner_dim)
.dtype_width_bytes(config.dtype_width_bytes);
let layout_type = LayoutType::LayerSeparate { outer_contiguous };
let device_layout = layout_builder
.num_blocks(config.num_device_blocks)
.build()?
.create_layout(layout_type, device_tensors)?;
let layout_builder_clone = layout_builder.clone();
let layout_builder = layout_builder.clone();
let (task, handler_rx) = if layout_blocking {
Self::run_blocking_layout_initialization(
config,
bytes_per_block,
device_layout,
layout_builder,
layout_type,
)
.await?
} else {
Self::run_non_blocking_layout_initialization(
config,
bytes_per_block,
device_layout,
layout_builder,
layout_type,
)
.await?
};
// add worker-connector scheduler here
// let scheduler = KvbmWorkerScheduler::new(config.scheduler.clone());
Ok(Self {
task: Some(task),
block_transfer_handler_rx: Some(handler_rx),
})
}
async fn run_blocking_layout_initialization(
config: KvbmWorkerConfig,
bytes_per_block: usize,
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
layout_builder: LayoutConfigBuilder,
layout_type: LayoutType,
) -> anyhow::Result<(
CriticalTaskExecutionHandle,
oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
let cancel_token = config.drt.primary_token().clone();
// barrier sync with leader to get the leader data
let leader_data = tokio::task::block_in_place(|| {
// Inside `block_in_place` this thread is allowed to block.
// Reuse the current runtime's handle to drive the async barrier sync to completion.
let rt = tokio::runtime::Handle::current();
rt.block_on(async {
KvbmWorker::leader_barrier_sync(
config.clone(),
cancel_token.clone(),
bytes_per_block,
)
.await
})
})?;
// establish a oneshot channel to get back the raw BlockTransferHandler
let (handler_tx, handler_rx) = oneshot::channel();
// establish a oneshot channel to block on the main routine to wait for layout allocation readiness
let (layout_ready_tx, layout_ready_rx) = oneshot::channel::<String>();
let scheduler_client = config.scheduler_client.clone();
let worker_config = config.clone();
// start background worker task to do layout allocation for host or disk
let task = CriticalTaskExecutionHandle::new(
move |cancel_token| {
KvbmWorker::worker_task(
device_layout,
layout_builder_clone,
layout_builder,
leader_data,
layout_type,
config,
worker_config,
cancel_token,
handler_tx,
layout_ready_tx,
scheduler_client,
)
},
......@@ -216,12 +305,105 @@ impl KvbmWorker {
"kvbm-worker-task",
)?;
Ok(Self {
task: Some(task),
block_transfer_handler_rx: Some(handler_rx),
})
// waiting for the worker layout allocation ready
match layout_ready_rx.await {
Ok(_) => tracing::info!("worker layout allocation finished."),
Err(_) => tracing::error!("Worker layout dropped without sending"),
}
let worker_config = config.clone();
let cancel_for_barrier = cancel_token.clone();
// wait until the leader finished the initialization of all components
tokio::task::block_in_place(|| {
// Inside `block_in_place` this thread is allowed to block.
// Reuse the current runtime's handle to drive the async readiness sync to completion.
let rt = tokio::runtime::Handle::current();
rt.block_on(async {
KvbmWorker::leader_readiness_sync(worker_config, cancel_for_barrier).await
})
})?;
Ok((task, handler_rx))
}
/// Sets up worker initialization so that it runs entirely inside a background task.
///
/// A single `CriticalTaskExecutionHandle` task named "kvbm-worker-task" is
/// created that performs, in order: the leader barrier sync, spawning of
/// `worker_task`, waiting for the layout-readiness signal, and the leader
/// readiness sync.
///
/// Returns the task handle together with the receiving end of the oneshot
/// channel whose sender is handed to `worker_task`.
///
/// # Errors
/// Returns an error if `CriticalTaskExecutionHandle::new` fails. Errors from
/// the barrier syncs surface inside the background task, not here.
async fn run_non_blocking_layout_initialization(
config: KvbmWorkerConfig,
bytes_per_block: usize,
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage> + Send + 'static>,
layout_builder: LayoutConfigBuilder,
layout_type: LayoutType,
) -> anyhow::Result<(
CriticalTaskExecutionHandle,
oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
// The task is created with a clone of the runtime's primary cancellation token.
let cancel_token = config.drt.primary_token().clone();
let scheduler_client = config.scheduler_client.clone();
// channel to get BlockTransferHandler back to the caller
let (handler_tx, handler_rx) = oneshot::channel::<transfer::BlockTransferHandler>();
// channel that the worker will use to signal layout readiness
let (layout_ready_tx, layout_ready_rx) = oneshot::channel::<String>();
// clone what we need inside the orchestrator
let worker_config = config.clone();
let cancel_token_for_task = cancel_token.clone();
// Single task that orchestrates everything in-order.
let task = CriticalTaskExecutionHandle::new(
move |ct| {
let cfg = worker_config.clone();
let scheduler = scheduler_client.clone();
async move {
// 1) barrier (must finish before worker_task starts)
// A failure here ends the orchestrator before worker_task is ever created.
let leader_data =
KvbmWorker::leader_barrier_sync(cfg.clone(), ct.clone(), bytes_per_block)
.await?;
// 2) start the long-running worker (after barrier)
// Spawn it so the orchestrator can continue with readiness + waiting.
let dev_layout = device_layout; // moved in
let lb = layout_builder; // moved in
let lt = layout_type; // moved in
let worker_fut = KvbmWorker::worker_task(
dev_layout,
lb,
leader_data,
lt,
cfg.clone(),
ct.clone(),
handler_tx,
layout_ready_tx,
scheduler,
);
// If worker_task returns Result, handle/log it inside the spawned task.
// The JoinHandle is dropped, so the worker runs detached from the orchestrator.
tokio::spawn(async move {
if let Err(e) = worker_fut.await {
tracing::error!("worker_task exited with error: {e:#}");
}
});
// 3) wait for the worker's layout allocation readiness
// A dropped sender is only logged; the orchestrator still proceeds to step 4.
match layout_ready_rx.await {
Ok(_) => tracing::info!("worker layout allocation finished."),
Err(_) => tracing::warn!("worker layout readiness channel dropped"),
}
// 4) wait for leader to finish its side of initialization
KvbmWorker::leader_readiness_sync(cfg.clone(), ct.clone()).await?;
Ok::<(), anyhow::Error>(())
}
},
cancel_token_for_task,
"kvbm-worker-task",
)?;
Ok((task, handler_rx))
}
/// One-time use method to extract the block transfer handler from the worker.
///
/// This is a bit of a hack. Improve the API design around this in the future.
......@@ -248,15 +430,11 @@ impl KvbmWorker {
Ok(blocks)
}
async fn worker_task(
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
mut layout_builder: LayoutConfigBuilder,
layout_type: LayoutType,
async fn leader_barrier_sync(
config: KvbmWorkerConfig,
cancel_token: CancellationToken,
handler_tx: oneshot::Sender<BlockTransferHandler>,
scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<()> {
bytes_per_block: usize,
) -> anyhow::Result<KvbmLeaderData> {
let drt = config.drt.clone();
let worker_id = drt
......@@ -266,30 +444,61 @@ impl KvbmWorker {
))?
.id() as usize;
let barrier_id_worker_to_leader =
format!("{}{}", config.barrier_id_prefix, "-worker-to-leader");
tracing::info!(
"Worker {} waiting on barrier {}",
worker_id,
config.barrier_id
barrier_id_worker_to_leader
);
let worker_barrier = WorkerBarrier::<KvbmLeaderData, KvbmWorkerData>::new(
config.barrier_id,
let worker_to_leader_barrier = WorkerBarrier::<(), KvbmWorkerData>::new(
barrier_id_worker_to_leader,
worker_id.to_string(),
);
let worker_data = KvbmWorkerData {
num_device_blocks: config.num_device_blocks,
bytes_per_block,
};
tokio::select! {
_ = cancel_token.cancelled() => {
return Err(anyhow::anyhow!("Cancelled"))
}
_leader_data = worker_to_leader_barrier.sync(&drt, &worker_data) => {
_leader_data
}
}
.map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
tracing::debug!(
"Worker {} sent the worker data in worker to leader phase",
worker_id
);
let barrier_id_leader_to_worker =
format!("{}{}", config.barrier_id_prefix, "-leader-to-worker");
tracing::info!(
"Worker {} waiting on barrier {}",
worker_id,
barrier_id_leader_to_worker
);
let leader_to_worker_barrier = WorkerBarrier::<KvbmLeaderData, ()>::new(
barrier_id_leader_to_worker,
worker_id.to_string(),
);
let leader_data = tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(())
return Err(anyhow::anyhow!("Cancelled"))
}
leader_data = worker_barrier.sync(&drt, &worker_data) => {
leader_data = leader_to_worker_barrier.sync(&drt, &()) => {
leader_data
}
}
.map_err(|e| anyhow::anyhow!("Failed to sync worker barrier: {:?}", e))?;
.map_err(|e| anyhow::anyhow!("Failed to sync worker to leader barrier: {:?}", e))?;
tracing::info!(
"Worker {} received leader data: {:?}",
......@@ -297,6 +506,68 @@ impl KvbmWorker {
leader_data
);
Ok(leader_data)
}
/// Worker side of the `<barrier_id_prefix>-leader-ready` barrier.
///
/// Joins the barrier under this worker's primary lease id and waits until the
/// sync completes or `cancel_token` fires. No payload is exchanged.
///
/// # Errors
/// Returns an error if the runtime has no primary lease, if the token is
/// cancelled first ("Cancelled"), or if the barrier sync fails.
async fn leader_readiness_sync(
    config: KvbmWorkerConfig,
    cancel_token: CancellationToken,
) -> anyhow::Result<()> {
    let drt = config.drt.clone();

    let lease = drt.primary_lease().ok_or_else(|| {
        anyhow::anyhow!("unable to get primary lease; check that drt is not static")
    })?;
    let worker_id = lease.id() as usize;

    let ready_barrier_id = format!("{}-leader-ready", config.barrier_id_prefix);

    tracing::info!(
        "Worker {} waiting on barrier {}",
        worker_id,
        ready_barrier_id
    );

    let readiness_barrier =
        WorkerBarrier::<(), ()>::new(ready_barrier_id, worker_id.to_string());

    // leader_data is not important in the leader readiness case
    let sync_outcome = tokio::select! {
        _ = cancel_token.cancelled() => return Err(anyhow::anyhow!("Cancelled")),
        outcome = readiness_barrier.sync(&drt, &()) => outcome,
    };
    sync_outcome
        .map_err(|e| anyhow::anyhow!("Failed to sync leader readiness barrier: {:?}", e))?;

    Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn worker_task(
device_layout: Box<dyn NixlLayout<StorageType = DeviceStorage>>,
mut layout_builder: LayoutConfigBuilder,
leader_data: KvbmLeaderData,
layout_type: LayoutType,
config: KvbmWorkerConfig,
cancel_token: CancellationToken,
handler_tx: oneshot::Sender<BlockTransferHandler>,
layout_ready_tx: oneshot::Sender<String>,
scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<()> {
let drt = config.drt.clone();
let worker_id = drt
.primary_lease()
.ok_or(anyhow::anyhow!(
"unable to get primary lease; check that drt is not static"
))?
.id() as usize;
let agent = build_agent(worker_id, leader_data.num_disk_blocks > 0)?;
let transfer_context = Arc::new(TransferContext::new(
......@@ -380,6 +651,10 @@ impl KvbmWorker {
cancel_token.clone(),
)?;
if layout_ready_tx.send("finished".to_string()).is_err() {
tracing::error!("worker receiver dropped before result was sent");
}
// TODO: Some sort of fancy loop here.
// For now, just wait for cancellation.
cancel_token.cancelled().await;
......
......@@ -334,9 +334,10 @@ impl FullyContiguousConfig {
config.validate()?;
let alignment = config.alignment;
let memory_region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;
let outer_dim_stride_in_bytes = memory_region_size;
let outer_dim_stride_in_bytes =
config.page_size * config.inner_dim * config.dtype_width_bytes;
let layer_stride_in_bytes = outer_dim_stride_in_bytes * config.outer_dim;
let memory_region_size = layer_stride_in_bytes;
let natural_block_stride = config.num_layers * layer_stride_in_bytes;
let block_stride_in_bytes = if alignment > 1 {
......
......@@ -456,6 +456,7 @@ impl<S: Storage, L: LocalityProvider, M: BlockMetadata> BlockPool<S, L, M>
&self,
sequence_hashes: &[SequenceHash],
) -> BlockPoolResult<ImmutableBlocks<S, L, M>> {
tracing::debug!("find matching for sequence_hashes: {:?}", sequence_hashes);
self._match_sequence_hashes(sequence_hashes)?
.blocking_recv()
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?
......
......@@ -2,11 +2,11 @@
## Overview
This suite validates determinism properties of the API-backed LLM under fixed sampling parameters and optionally across prefix cache resets. The tests can automatically start a local vLLM server, warm it up, and compare responses for identical prompts over multiple iterations.
This suite validates the determinism properties of the API-backed LLM under fixed sampling parameters and, optionally, across prefix cache resets. The tests can automatically start a local LLM server—either a vLLM server or a TensorRT-LLM server—warm it up, and compare responses for identical prompts over multiple iterations. The suite automatically detects whether the vLLM or TensorRT-LLM wheel is installed and starts the corresponding server; if both are installed, vLLM is used.
## Files
- `test_determinism.py` — comprehensive determinism tests with automatic vLLM server lifecycle and warmup.
- `test_determinism.py` — comprehensive determinism tests with automatic LLM server lifecycle and warmup.
- `test_determinism_with_cache_reset` — run test with warmup, reset cache, then run again without warmup to test determinism across cache reset boundary
- `test_concurrent_determinism_with_ifeval` — send parametrized number of IFEval prompts (default: 120) with controlled concurrency, with warmup, then reset cache and test again without warmup to validate determinism across cache reset
......@@ -19,7 +19,7 @@ This suite validates determinism properties of the API-backed LLM under fixed sa
## How It Works
- A `VLLMServerManager` fixture (`vllm_server`) launches `vllm serve` with the Dynamo connector and optional cache block overrides.
- An `LLMServerManager` fixture (`llm_server`) launches `vllm serve` or `trtllm-serve` with the Dynamo connector and optional cache block overrides.
- A `tester` fixture binds the test client to the running server's base URL.
- The test performs a comprehensive warmup across prompts, then executes repeated requests and checks that responses are identical (deterministic). An optional cache reset phase re-validates determinism across the reset boundary.
......@@ -43,8 +43,8 @@ Environment variables control server settings and test load:
- Server/model
- `KVBM_MODEL_ID` (default: `deepseek-ai/DeepSeek-R1-Distill-Llama-8B`)
- `KVBM_VLLM_PORT` (default: `8000`)
- `KVBM_VLLM_START_TIMEOUT` (default: `300` seconds)
- `KVBM_SERVER_PORT` (default: `8000`)
- `KVBM_SERVER_START_TIMEOUT` (default: `300` seconds)
- Cache size overrides
- `KVBM_CPU_BLOCKS` (used via test parametrization; default: `10000`)
......@@ -90,5 +90,5 @@ pytest -v -m "kvbm" -s
- Warmup is critical to avoid initialization effects impacting determinism.
- For faster local iteration, reduce `KVBM_MAX_ITERATIONS` and/or increase intervals.
- Logs are written under the per-test directory created by `tests/conftest.py` and include the vLLM server stdout/stderr.
- Tests use the static port defined by `KVBM_VLLM_PORT` for vLLM server communication.
\ No newline at end of file
- Logs are written under the per-test directory created by `tests/conftest.py` and include the LLM server stdout/stderr.
- Tests use the static port defined by `KVBM_SERVER_PORT` for LLM server communication.
......@@ -13,6 +13,7 @@ before validation) to avoid server initialization effects that could
impact determinism measurements.
"""
import importlib.util
import logging
import os
import signal
......@@ -21,8 +22,9 @@ import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, TextIO, Tuple
from typing import Any, Dict, List, Optional, TextIO, Tuple
import pytest
import requests
......@@ -38,8 +40,13 @@ pytestmark = [
]
class VLLMServerManager:
"""Manages vLLM server lifecycle for determinism testing."""
class ServerType(str, Enum):
    """Inference server backends the determinism suite can launch.

    Members are also plain ``str`` instances, so they compare equal to their
    raw values (``ServerType.vllm == "vllm"``).
    """

    vllm = "vllm"
    trtllm = "trtllm"

    def __str__(self) -> str:
        # On recent Python versions (3.11+), format()/f-strings of a
        # (str, Enum) mix-in follow Enum.__str__ and render "ServerType.vllm".
        # Return the raw value so log file names and messages consistently
        # read "vllm" / "trtllm".
        return self.value
class LLMServerManager:
"""Manages LLM server lifecycle for determinism testing."""
def __init__(
self,
......@@ -48,8 +55,10 @@ class VLLMServerManager:
cpu_cache_blocks: Optional[int] = None,
gpu_cache_blocks: Optional[int] = None,
log_dir: Optional[Path] = None,
server_type: Optional[str] = ServerType.vllm,
):
self.port = port or int(os.environ.get("KVBM_VLLM_PORT", "8000"))
self.server_type = server_type
self.port = port or int(os.environ.get("KVBM_SERVER_PORT", "8000"))
self.base_url = base_url or f"http://localhost:{self.port}"
self.process: Optional[subprocess.Popen] = None
self.cpu_cache_blocks = cpu_cache_blocks
......@@ -63,11 +72,41 @@ class VLLMServerManager:
f"cpu{cpu_cache_blocks or 'default'}_gpu{gpu_cache_blocks or 'default'}"
)
self.server_log_file = (
self.log_dir / f"vllm_server_{config_str}_{timestamp}.log"
self.log_dir / f"{self.server_type}_server_{config_str}_{timestamp}.log"
)
self.server_stdout_file: Optional[TextIO] = None
self.server_stderr_file: Optional[TextIO] = None
# Environment for the process
self.env = os.environ.copy()
self.env.update(
{
"RUST_BACKTRACE": "1",
"DYN_LOG": os.environ.get(
"DYN_LOG", "debug,dynamo_llm::block_manager::layout=error"
),
# DynamoConnector connection settings
"NATS_SERVER": "nats://localhost:4222",
"ETCD_ENDPOINTS": "http://localhost:2379",
}
)
# CPU cache blocks override via env
if cpu_cache_blocks is not None:
self.env["DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS"] = str(cpu_cache_blocks)
if self.server_type == ServerType.vllm:
self._set_up_vllm_config(gpu_cache_blocks)
elif self.server_type == ServerType.trtllm:
self._set_up_trtllm_config(gpu_cache_blocks)
else:
raise ValueError(
f"{self.server_type} is not supported yet in the KVBM test suite"
)
def _set_up_vllm_config(self, gpu_cache_blocks):
self.env["VLLM_SERVER_DEV_MODE"] = "1"
# Construct serve command
self.server_cmd = [
"vllm",
......@@ -85,27 +124,52 @@ class VLLMServerManager:
if gpu_cache_blocks is not None:
self.server_cmd.extend(["--num-gpu-blocks-override", str(gpu_cache_blocks)])
# Environment for the process
self.env = os.environ.copy()
self.env.update(
{
"RUST_BACKTRACE": "1",
"DYN_LOG": os.environ.get(
"DYN_LOG", "debug,dynamo_llm::block_manager::layout=error"
),
"VLLM_SERVER_DEV_MODE": "1",
# DynamoConnector connection settings
"NATS_SERVER": "nats://localhost:4222",
"ETCD_ENDPOINTS": "http://localhost:2379",
}
def _set_up_trtllm_config(self, gpu_cache_blocks):
    """Write the TensorRT-LLM LLM API config file and build the ``trtllm-serve`` command.

    The config path is read from ``KVBM_TRTLLM_LLMAPI_CONFIG_PATH`` (default
    ``/tmp/kvbm_llm_api_config.yaml``); the file is (re)written on every call
    and handed to the server through ``--extra_llm_api_options``.

    Args:
        gpu_cache_blocks: Optional number of GPU KV cache blocks. When given,
            the default ``free_gpu_memory_fraction`` is replaced by a
            ``max_tokens`` budget of ``gpu_cache_blocks * 32``.
    """
    config_path = os.environ.get(
        "KVBM_TRTLLM_LLMAPI_CONFIG_PATH", "/tmp/kvbm_llm_api_config.yaml"
    )

    llm_api_config: dict[str, Any] = {}
    # Explicitly disable CUDA graph since the Connector API doesn't support
    # CUDA graph yet in TRTLLM.
    llm_api_config["cuda_graph_config"] = None
    llm_api_config["kv_cache_config"] = {
        "enable_partial_reuse": False,
        # Set a small GPU fraction so that we can evict/reset the on-device
        # kv cache faster.
        "free_gpu_memory_fraction": 0.10,
    }
    llm_api_config["kv_connector_config"] = {
        "connector_module": "dynamo.llm.trtllm_integration.connector",
        "connector_scheduler_class": "DynamoKVBMConnectorLeader",
        "connector_worker_class": "DynamoKVBMConnectorWorker",
    }

    # The CPU cache block override is intentionally not handled here:
    # `cpu_cache_blocks` is not in scope in this method.

    # GPU blocks override
    if gpu_cache_blocks is not None:
        del llm_api_config["kv_cache_config"]["free_gpu_memory_fraction"]
        # TRTLLM defaults to 32 tokens per block.
        llm_api_config["kv_cache_config"]["max_tokens"] = int(gpu_cache_blocks) * 32

    # Construct serve command
    self.server_cmd = [
        "trtllm-serve",
        os.environ.get("KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"),
        "--host",
        "localhost",
        "--port",
        str(self.port),
        "--backend",
        "pytorch",
        "--extra_llm_api_options",
        config_path,
    ]

    # Imported lazily so the module still loads when PyYAML is only needed
    # for the TensorRT-LLM path.
    import yaml

    with open(config_path, "w") as f:
        yaml.dump(llm_api_config, f, default_flow_style=False, sort_keys=False)
def start_server(self, timeout: int = 300) -> bool:
"""Start vLLM server and wait for readiness."""
"""Start LLM server and wait for readiness."""
if self.is_server_running():
self.stop_server()
time.sleep(2)
......@@ -119,7 +183,7 @@ class VLLMServerManager:
)
if self.server_stdout_file is not None:
self.server_stdout_file.write(
f"=== vLLM Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n"
f"=== {self.server_type} Server Started at {datetime.now()} ===\nCommand: {' '.join(self.server_cmd)}\n"
)
self.server_stdout_file.flush()
......@@ -147,7 +211,7 @@ class VLLMServerManager:
return False
def stop_server(self):
"""Stop vLLM server and close logs."""
"""Stop LLM server and close logs."""
if self.process:
try:
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
......@@ -205,7 +269,12 @@ class VLLMServerManager:
class DeterminismTester:
"""Test class for model determinism validation."""
def __init__(self, base_url: Optional[str] = None, model_id: Optional[str] = None):
def __init__(
self,
base_url: Optional[str] = None,
model_id: Optional[str] = None,
server_type: Optional[str] = ServerType.vllm,
):
# Allow environment override for flexibility in CI/local runs
self.base_url = (
base_url or os.environ.get("DYNAMO_API_BASE_URL") or "http://localhost:8000"
......@@ -215,6 +284,7 @@ class DeterminismTester:
or os.environ.get("KVBM_MODEL_ID")
or "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
self.server_type = server_type
self.shakespeare_file = Path("t8.shakespeare.txt")
self.max_iterations = int(os.environ.get("KVBM_MAX_ITERATIONS", "500"))
......@@ -298,11 +368,28 @@ class DeterminismTester:
def reset_prefix_cache(self):
    """Reset the server's prefix cache.

    For a TensorRT-LLM server there is no reset endpoint, so the on-device
    cache is evicted by sending a run of distinct Shakespeare prompts;
    individual request failures are printed and do not abort the reset.
    For any other server type, ``POST {base_url}/reset_prefix_cache`` is
    issued with a timeout taken from ``KVBM_HTTP_TIMEOUT`` (default 30 s).

    Raises:
        requests.HTTPError: If the reset endpoint returns an error status.
    """
    print("Resetting prefix cache...")
    if self.server_type == ServerType.trtllm:
        # TRTLLM doesn't support reset_prefix_cache endpoint API
        # 300 shakespeare content could evict the 0.1 x 80G (~1700 blocks) on-device cache
        shakespeare_count = 300
        for seq_idx in range(1, shakespeare_count + 1):
            start_word = (seq_idx - 1) * self.word_count
            content = self.get_shakespeare_content(start_word)
            if content:
                print(
                    f"Resetting Shakespeare sequence {seq_idx} (words {start_word}-{start_word + self.word_count - 1})..."
                )
                try:
                    self.make_request(content)
                except Exception as e:
                    # Best-effort eviction: keep going with the next prompt.
                    print(f"Resetting request failed: {e}")
    else:
        response = requests.post(
            f"{self.base_url}/reset_prefix_cache",
            timeout=int(os.environ.get("KVBM_HTTP_TIMEOUT", "30")),
        )
        response.raise_for_status()
    print("Cache reset done")
def warmup_server(self):
......@@ -623,11 +710,11 @@ class DeterminismTester:
@pytest.fixture(scope="function")
def vllm_server(request, runtime_services):
"""Start and stop vLLM server for each test with optional cache block overrides.
def llm_server(request, runtime_services):
"""Start and stop a LLM server for each test with optional cache block overrides.
To parametrize, use:
@pytest.mark.parametrize("vllm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True)
@pytest.mark.parametrize("llm_server", [{"cpu_blocks": 10000, "gpu_blocks": 2048}], indirect=True)
"""
logger = logging.getLogger("pytest")
logger.setLevel(logging.INFO)
......@@ -639,17 +726,27 @@ def vllm_server(request, runtime_services):
# Put logs in the per-test directory set up by tests/conftest.py
log_dir = Path(request.node.name)
server_manager = VLLMServerManager(
if importlib.util.find_spec("vllm") is not None:
server_type = ServerType.vllm
elif importlib.util.find_spec("tensorrt_llm") is not None:
server_type = ServerType.trtllm
else:
raise Exception(
"Neither the vllm nor the tensorrt_llm module is available in the current environment."
)
server_manager = LLMServerManager(
port=port,
cpu_cache_blocks=cpu_blocks,
gpu_cache_blocks=gpu_blocks,
log_dir=log_dir,
server_type=server_type,
)
start_timeout = int(os.environ.get("KVBM_VLLM_START_TIMEOUT", "300"))
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout):
pytest.fail(
f"Failed to start vLLM server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
)
yield server_manager
......@@ -658,9 +755,11 @@ def vllm_server(request, runtime_services):
@pytest.fixture(scope="function")
def tester(llm_server):
    """Create a determinism tester bound to the running server.

    The tester takes its base URL and server type from the `llm_server`
    fixture, then downloads the Shakespeare text before being returned.
    """
    t = DeterminismTester(
        base_url=llm_server.base_url, server_type=llm_server.server_type
    )
    t.download_shakespeare_text()
    return t
......@@ -669,13 +768,13 @@ class TestDeterminism:
"""Test class for determinism validation."""
@pytest.mark.parametrize(
"vllm_server",
"llm_server",
[
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
],
indirect=True,
)
def test_determinism_with_cache_reset(self, tester, vllm_server, runtime_services):
def test_determinism_with_cache_reset(self, tester, llm_server, runtime_services):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
print("\n" + "=" * 70)
print("STARTING DETERMINISM TEST (WITH CACHE RESET)")
......@@ -797,7 +896,7 @@ class TestDeterminism:
), f"Model is not deterministic across cache reset: {total_failed} comparisons failed"
@pytest.mark.parametrize(
"vllm_server",
"llm_server",
[
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "20000"))},
],
......@@ -818,7 +917,7 @@ class TestDeterminism:
def test_concurrent_determinism_with_ifeval(
self,
tester,
vllm_server,
llm_server,
runtime_services,
num_concurrent,
max_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment