Unverified Commit 9ca2923d authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

chore: Revert KVBM v2 transfer integration (#5406)

parent 91a8d07f
......@@ -118,11 +118,11 @@ impl TorchTensor for VllmTensor {
#[pyclass]
#[derive(Clone)]
pub struct BlockTransferHandler {
_impl: Arc<dyn RustBlockTransferHandler>,
_impl: Arc<RustBlockTransferHandler>,
}
impl BlockTransferHandler {
pub fn get_handler(&self) -> Arc<dyn RustBlockTransferHandler> {
pub fn get_handler(&self) -> Arc<RustBlockTransferHandler> {
self._impl.clone()
}
}
......
......@@ -9,7 +9,7 @@ mod leader;
mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::{BlockTransferHandler, BlockTransferHandlerV1, BlockTransferHandlerV2};
pub use transfer::BlockTransferHandler;
pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
};
......
......@@ -11,10 +11,9 @@ use zmq::*;
use BlockTransferPool::*;
use crate::block_manager::{
Storage,
BasicMetadata, Storage,
block::{
BasicMetadata, Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock,
WritableBlock,
Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock,
data::local::LocalBlockData,
locality,
transfer::{TransferContext, WriteTo, WriteToStrategy},
......@@ -22,10 +21,6 @@ use crate::block_manager::{
connector::scheduler::{SchedulingDecision, TransferSchedulerClient},
offload::MAX_TRANSFER_BATCH_SIZE,
storage::{DeviceStorage, DiskStorage, Local, PinnedStorage},
v2::physical::{
layout::PhysicalLayout, manager::TransportManager, transfer::LayoutHandle,
transfer::options::TransferOptions,
},
};
use anyhow::Result;
......@@ -49,9 +44,9 @@ impl ConnectorTransferBatcher {
}
}
pub async fn execute_batched_transfer<T: BlockTransferDirectHandler>(
pub async fn execute_batched_transfer(
&self,
handler: &T,
handler: &BlockTransferHandler,
request: BlockTransferRequest,
) -> Result<()> {
let blocks = request.blocks();
......@@ -88,21 +83,9 @@ impl ConnectorTransferBatcher {
}
}
#[async_trait]
pub trait BlockTransferHandler: Send + Sync {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()>;
fn scheduler_client(&self) -> Option<TransferSchedulerClient>;
}
#[async_trait]
pub trait BlockTransferDirectHandler {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()>;
}
/// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s.
#[derive(Clone)]
pub struct BlockTransferHandlerV1 {
pub struct BlockTransferHandler {
device: Option<LocalBlockDataList<DeviceStorage>>,
host: Option<LocalBlockDataList<PinnedStorage>>,
disk: Option<LocalBlockDataList<DiskStorage>>,
......@@ -112,46 +95,7 @@ pub struct BlockTransferHandlerV1 {
// add worker-connector scheduler client here
}
#[async_trait]
impl BlockTransferHandler for BlockTransferHandlerV1 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
fn scheduler_client(&self) -> Option<TransferSchedulerClient> {
self.scheduler_client.clone()
}
}
#[async_trait]
impl BlockTransferDirectHandler for BlockTransferHandlerV1 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}");
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
_ => {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
}?;
notify.await?;
Ok(())
}
}
impl BlockTransferHandlerV1 {
impl BlockTransferHandler {
pub fn new(
device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>,
host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>,
......@@ -234,94 +178,41 @@ impl BlockTransferHandlerV1 {
}
}
}
}
#[derive(Clone)]
pub struct BlockTransferHandlerV2 {
device_handle: Option<LayoutHandle>,
host_handle: Option<LayoutHandle>,
disk_handle: Option<LayoutHandle>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
}
impl BlockTransferHandlerV2 {
pub fn new(
device_layout: Option<PhysicalLayout>,
host_layout: Option<PhysicalLayout>,
disk_layout: Option<PhysicalLayout>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
) -> Result<Self> {
Ok(Self {
device_handle: device_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
host_handle: host_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
disk_handle: disk_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
transport_manager,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
})
}
}
#[async_trait]
impl BlockTransferHandler for BlockTransferHandlerV2 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
/// Execute transfer with batching to prevent resource exhaustion
pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
fn scheduler_client(&self) -> Option<TransferSchedulerClient> {
self.scheduler_client.clone()
}
}
/// Execute transfer directly without batching (used by the batcher)
pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
#[async_trait]
impl BlockTransferDirectHandler for BlockTransferHandlerV2 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
let (src, dst) = match (request.from_pool(), request.to_pool()) {
(Device, Host) => (self.device_handle.as_ref(), self.host_handle.as_ref()),
(Device, Disk) => (self.device_handle.as_ref(), self.disk_handle.as_ref()),
(Host, Device) => (self.host_handle.as_ref(), self.device_handle.as_ref()),
(Host, Disk) => (self.host_handle.as_ref(), self.disk_handle.as_ref()),
(Disk, Device) => (self.disk_handle.as_ref(), self.device_handle.as_ref()),
_ => return Err(anyhow::anyhow!("Invalid transfer type.")),
};
tracing::debug!("request: {request:#?}");
if let (Some(src), Some(dst)) = (src, dst) {
let src_block_ids = request
.blocks()
.iter()
.map(|(from, _)| *from)
.collect::<Vec<_>>();
let dst_block_ids = request
.blocks()
.iter()
.map(|(_, to)| *to)
.collect::<Vec<_>>();
self.transport_manager
.execute_transfer(
*src,
&src_block_ids,
*dst,
&dst_block_ids,
TransferOptions::default(),
)?
.await?;
} else {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
_ => {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
}?;
notify.await?;
Ok(())
}
}
#[async_trait]
impl<T: ?Sized + BlockTransferHandler> Handler for T {
impl Handler for BlockTransferHandler {
async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if message.data.len() != 1 {
return Err(anyhow::anyhow!(
......@@ -341,8 +232,10 @@ impl<T: ?Sized + BlockTransferHandler> Handler for T {
);
let client = self
.scheduler_client()
.expect("scheduler client is required");
.scheduler_client
.as_ref()
.expect("scheduler client is required")
.clone();
let handle = client.schedule_transfer(req).await?;
......
......@@ -18,12 +18,6 @@ use crate::block_manager::{
layout::LayoutType,
offload::{MAX_CONCURRENT_TRANSFERS, MAX_TRANSFER_BATCH_SIZE},
storage::{DeviceAllocator, DeviceStorage, DiskAllocator, PinnedAllocator, torch::TorchTensor},
v2::memory::DeviceStorage as DeviceStorageV2,
v2::physical::{
layout::{BlockDimension, LayoutConfig as LayoutConfigV2, builder::PhysicalLayoutBuilder},
manager::TransportManager,
transfer::{NixlAgent as NixlAgentV2, TransferCapabilities},
},
};
use derive_builder::Builder;
......@@ -117,162 +111,70 @@ async fn perform_allocation_and_build_handler(
worker_id: usize,
device_id: usize,
scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<Arc<dyn BlockTransferHandler>> {
let use_v2_transfer = std::env::var("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL")
.unwrap_or("0".to_string())
.parse::<usize>()
.map(|v| v > 0)
.unwrap_or(false);
if use_v2_transfer {
tracing::warn!("Using V2 transfer handler. This is experimental. Use at your own risk.");
let backends = if leader_meta.num_disk_blocks > 0 {
vec!["POSIX", "GDS_MT"]
} else {
vec!["POSIX"]
};
let agent = NixlAgentV2::new_with_backends(worker_id.to_string().as_str(), &backends)?;
let mut layout_config = LayoutConfigV2::builder()
.num_blocks(device_layout.config().num_blocks)
.num_layers(device_layout.config().num_layers)
.outer_dim(device_layout.config().outer_dim)
.inner_dim(device_layout.config().inner_dim)
.page_size(device_layout.config().page_size)
.alignment(device_layout.config().alignment)
.dtype_width_bytes(device_layout.config().dtype_width_bytes)
.build()?;
let v2_device_layout =
PhysicalLayoutBuilder::new(agent.clone()).with_config(layout_config.clone());
let v2_device_layout =
if let LayoutType::LayerSeparate { outer_contiguous } = device_layout.layout_type() {
v2_device_layout.layer_separate(if outer_contiguous {
BlockDimension::BlockIsSecondDim
} else {
BlockDimension::BlockIsFirstDim
})
} else {
v2_device_layout.fully_contiguous()
};
let regions = device_layout
.storage()
.iter()
.map(|s| DeviceStorageV2::from_v1(s).unwrap())
.collect::<Vec<_>>();
let v2_device_layout = v2_device_layout.with_memory_regions(regions)?.build()?;
let host_layout = if leader_meta.num_host_blocks > 0 {
layout_config.num_blocks = leader_meta.num_host_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config.clone())
.fully_contiguous()
.allocate_pinned(true)
.build()?,
)
} else {
None
};
let disk_layout = if leader_meta.num_disk_blocks > 0 {
layout_config.num_blocks = leader_meta.num_disk_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config)
.fully_contiguous()
.allocate_disk(None)
.build()?,
)
} else {
None
};
let transport_manager = TransportManager::builder()
.capabilities(TransferCapabilities::default().with_gds(true))
.worker_id(worker_id as u64)
.nixl_agent(agent)
.cuda_device_id(device_id)
.build()?;
let handler = BlockTransferHandlerV2::new(
Some(v2_device_layout),
) -> anyhow::Result<BlockTransferHandler> {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?;
let pool_config = PoolConfig {
enable_pool: true,
max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS,
max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE,
num_outer_components: device_layout.config().outer_dim,
num_layers: device_layout.config().num_layers,
};
let transfer_context = Arc::new(TransferContext::new(
Arc::new(Some(agent)),
DeviceAllocator::new(device_id)?.ctx().new_stream()?,
Handle::current(),
Some(pool_config),
));
// device
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
device_layout,
transfer_context.nixl_agent().as_ref(),
0,
worker_id,
)?);
// host
let host_blocks = if leader_meta.num_host_blocks > 0 {
let host_allocator = Arc::new(PinnedAllocator::default());
let host_layout = layout_builder
.num_blocks(leader_meta.num_host_blocks)
.build()?
.allocate_layout(worker_config.host_layout_type, host_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
host_layout,
disk_layout,
transport_manager,
scheduler_client,
)?;
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
transfer_context.nixl_agent().as_ref(),
1,
worker_id,
)?)
} else {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?;
let pool_config = PoolConfig {
enable_pool: true,
max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS,
max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE,
num_outer_components: device_layout.config().outer_dim,
num_layers: device_layout.config().num_layers,
};
let transfer_context = Arc::new(TransferContext::new(
Arc::new(Some(agent)),
DeviceAllocator::new(device_id)?.ctx().new_stream()?,
Handle::current(),
Some(pool_config),
));
// device
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
device_layout,
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
.build()?
.allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
disk_layout,
transfer_context.nixl_agent().as_ref(),
0,
2,
worker_id,
)?);
// host
let host_blocks = if leader_meta.num_host_blocks > 0 {
let host_allocator = Arc::new(PinnedAllocator::default());
let host_layout = layout_builder
.num_blocks(leader_meta.num_host_blocks)
.build()?
.allocate_layout(worker_config.host_layout_type, host_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
host_layout,
transfer_context.nixl_agent().as_ref(),
1,
worker_id,
)?)
} else {
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
.build()?
.allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
disk_layout,
transfer_context.nixl_agent().as_ref(),
2,
worker_id,
)?)
} else {
None
};
let handler = BlockTransferHandlerV1::new(
device_blocks,
host_blocks,
disk_blocks,
transfer_context,
scheduler_client,
)?;
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
}
)?)
} else {
None
};
let handler = BlockTransferHandler::new(
device_blocks,
host_blocks,
disk_blocks,
transfer_context,
scheduler_client,
)?;
Ok(handler)
}
struct WorkerMetadataHandler {
......@@ -297,8 +199,6 @@ impl Handler for WorkerMetadataHandler {
}
}
type TransferHandlerSender = Mutex<Option<oneshot::Sender<Arc<dyn BlockTransferHandler>>>>;
// Leader sends allocation config -> allocate -> publish handler -> mark ready -> ACK
struct LeaderMetadataHandler {
state: Arc<WorkerState>,
......@@ -308,8 +208,8 @@ struct LeaderMetadataHandler {
worker_id: usize,
device_id: usize,
scheduler_client: Option<TransferSchedulerClient>,
handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>,
handler_tx: Arc<TransferHandlerSender>,
handler_cell: Arc<RwLock<Option<BlockTransferHandler>>>,
handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>,
started: AtomicBool,
}
......@@ -447,7 +347,7 @@ impl Handler for GatedPing {
// Transfer dispatcher that waits until block transfer handler exists
struct BlockTransferDispatch {
cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>,
cell: Arc<RwLock<Option<BlockTransferHandler>>>,
}
#[async_trait]
......@@ -508,7 +408,7 @@ impl KvbmWorkerConfig {
pub struct KvbmWorker {
task: Option<CriticalTaskExecutionHandle>,
block_transfer_handler_rx: Option<oneshot::Receiver<Arc<dyn BlockTransferHandler>>>,
block_transfer_handler_rx: Option<oneshot::Receiver<transfer::BlockTransferHandler>>,
}
impl KvbmWorker {
......@@ -632,7 +532,7 @@ impl KvbmWorker {
layout_type: LayoutType,
) -> anyhow::Result<(
CriticalTaskExecutionHandle,
oneshot::Receiver<Arc<dyn BlockTransferHandler>>,
oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
let cancel_token = config.cancel_token.clone();
......@@ -683,13 +583,13 @@ impl KvbmWorker {
layout_type: LayoutType,
) -> anyhow::Result<(
CriticalTaskExecutionHandle,
oneshot::Receiver<Arc<dyn BlockTransferHandler>>,
oneshot::Receiver<transfer::BlockTransferHandler>,
)> {
let cancel_token = config.cancel_token.clone();
let scheduler_client = config.scheduler_client.clone();
// channel to get BlockTransferHandler back to the caller
let (handler_tx, handler_rx) = oneshot::channel::<Arc<dyn BlockTransferHandler>>();
let (handler_tx, handler_rx) = oneshot::channel::<transfer::BlockTransferHandler>();
let handler_tx_cell = Arc::new(Mutex::new(Some(handler_tx)));
// channel that the worker will use to signal layout readiness
......@@ -752,7 +652,7 @@ impl KvbmWorker {
/// This is a bit of a hack. Improve the API design around this in the future.
pub fn block_transfer_handler_rx(
&mut self,
) -> Option<tokio::sync::oneshot::Receiver<Arc<dyn BlockTransferHandler>>> {
) -> Option<tokio::sync::oneshot::Receiver<BlockTransferHandler>> {
self.block_transfer_handler_rx.take()
}
......@@ -780,7 +680,7 @@ impl KvbmWorker {
_device_layout_type: LayoutType,
config: KvbmWorkerConfig,
cancel_token: CancellationToken,
handler_tx: Arc<TransferHandlerSender>,
handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>,
layout_ready_tx: tokio::sync::Mutex<Option<oneshot::Sender<String>>>,
scheduler_client: Option<TransferSchedulerClient>,
bytes_per_block: usize,
......@@ -790,7 +690,7 @@ impl KvbmWorker {
let state = Arc::new(WorkerState::new());
// Cell to publish the transfer handler
let transfer_handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>> =
let transfer_handler_cell: Arc<RwLock<Option<BlockTransferHandler>>> =
Arc::new(RwLock::new(None));
// Build handlers map
......
......@@ -337,8 +337,8 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator {
/// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that
/// the torch tensor is not GCed until the [`DeviceStorage`] is dropped.
/// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`]
#[derive(Clone, Debug)]
pub enum DeviceStorageType {
#[derive(Debug)]
enum DeviceStorageType {
Owned, // Memory that we allocated ourselves.
Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor.
}
......@@ -350,7 +350,7 @@ pub struct DeviceStorage {
size: usize,
ctx: Arc<CudaContext>,
handles: RegistrationHandles,
storage_type: DeviceStorageType,
_storage_type: DeviceStorageType,
}
impl Local for DeviceStorage {}
......@@ -367,7 +367,7 @@ impl DeviceStorage {
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
storage_type: DeviceStorageType::Owned,
_storage_type: DeviceStorageType::Owned,
})
}
......@@ -395,7 +395,7 @@ impl DeviceStorage {
size,
ctx: ctx.clone(),
handles: RegistrationHandles::new(),
storage_type: DeviceStorageType::Torch { _tensor: tensor },
_storage_type: DeviceStorageType::Torch { _tensor: tensor },
})
}
......@@ -403,10 +403,6 @@ impl DeviceStorage {
pub fn context(&self) -> &Arc<CudaContext> {
&self.ctx
}
pub fn device_storage_type(&self) -> &DeviceStorageType {
&self.storage_type
}
}
impl Storage for DeviceStorage {
......@@ -440,7 +436,7 @@ impl CudaContextProivder for DeviceStorage {
impl Drop for DeviceStorage {
fn drop(&mut self) {
self.handles.release();
match &self.storage_type {
match &self._storage_type {
DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
}
......
......@@ -3,13 +3,8 @@
//! CUDA device memory storage.
use crate::block_manager::DeviceStorage as V1DeviceStorage;
use crate::block_manager::Storage as V1Storage;
use crate::block_manager::storage::cuda::DeviceStorageType as V1DeviceStorageType;
use super::{MemoryRegion, Result, StorageError, StorageKind};
use cudarc::driver::CudaContext;
use nixl_sys::NixlDescriptor;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
......@@ -35,8 +30,6 @@ pub struct DeviceStorage {
ptr: u64,
device_id: u32,
len: usize,
// TODO: This is a bit ugly. We need to translate our v1 device layout to v2.
device_storage_type: V1DeviceStorageType,
}
unsafe impl Send for DeviceStorage {}
......@@ -64,7 +57,6 @@ impl DeviceStorage {
ptr,
device_id,
len,
device_storage_type: V1DeviceStorageType::Owned,
})
}
......@@ -77,51 +69,18 @@ impl DeviceStorage {
pub fn device_id(&self) -> u32 {
self.device_id
}
pub fn from_v1(v1_storage: &V1DeviceStorage) -> Result<Self> {
let device_id = v1_storage.device_id() as u32;
let ctx = cuda_context(device_id)?;
let ptr;
unsafe {
ptr = v1_storage.as_ptr() as u64;
}
let len = v1_storage.size();
if !matches!(
v1_storage.device_storage_type(),
V1DeviceStorageType::Torch { .. }
) {
return Err(StorageError::Unsupported(
"Unable to convert owned device tensors.".into(),
));
}
Ok(Self {
ctx,
ptr,
device_id,
len,
device_storage_type: v1_storage.device_storage_type().clone(),
})
}
}
impl Drop for DeviceStorage {
fn drop(&mut self) {
match self.device_storage_type {
V1DeviceStorageType::Owned => {
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
}
}
V1DeviceStorageType::Torch { .. } => {} // Do nothing.
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
};
}
}
......
......@@ -228,7 +228,6 @@ markers = [
"router: marks tests for router component",
"planner: marks tests for planner component",
"kvbm: marks tests for KV behavior and model determinism",
"kvbm_v2: marks tests using KVBM V2",
"model: model id used by a test or parameter",
"custom_build: marks tests that require custom builds or special setup (e.g., MoE models)",
"k8s: marks tests as requiring Kubernetes",
......
......@@ -183,20 +183,6 @@ class LLMServerManager:
)
self.server_stdout_file.flush()
# Try to download the model.
model = os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
print("Attempting model download...")
try:
subprocess.run(
f"pip install hf_transfer && HF_HUB_ENABLE_HF_TRANSFER=1 hf download {model}",
check=True,
shell=True,
)
except subprocess.CalledProcessError:
print("Model download failed. Is this a locally stored model?")
# Launch
self.process = subprocess.Popen(
self.server_cmd,
......@@ -349,7 +335,7 @@ def llm_server(request, runtime_services):
server_type=server_type,
)
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600"))
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout):
pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
......@@ -390,24 +376,6 @@ class TestDeterminismAgg(BaseTestDeterminism):
tester, llm_server, runtime_services
)
@pytest.mark.parametrize(
"llm_server",
[
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_agg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester, llm_server, runtime_services
)
@pytest.mark.parametrize(
"llm_server",
[
......
......@@ -430,7 +430,7 @@ def llm_server(request, runtime_services):
server_type=server_type,
)
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600"))
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout):
pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
......@@ -477,30 +477,6 @@ class TestDeterminismDisagg(BaseTestDeterminism):
success_rate_threshold=SUCCESS_RATE_THRESHOLD,
)
@pytest.mark.parametrize(
"llm_server",
[
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "1000")),
},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_disagg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester,
llm_server,
runtime_services,
success_rate_threshold=SUCCESS_RATE_THRESHOLD,
)
if __name__ == "__main__":
# Allow running as script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment