"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "7a3847936672935e720d5a63be8dd61eaacde7e8"
Unverified Commit 9ca2923d authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

chore: Revert KVBM v2 transfer integration (#5406)

parent 91a8d07f
...@@ -118,11 +118,11 @@ impl TorchTensor for VllmTensor { ...@@ -118,11 +118,11 @@ impl TorchTensor for VllmTensor {
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub struct BlockTransferHandler { pub struct BlockTransferHandler {
_impl: Arc<dyn RustBlockTransferHandler>, _impl: Arc<RustBlockTransferHandler>,
} }
impl BlockTransferHandler { impl BlockTransferHandler {
pub fn get_handler(&self) -> Arc<dyn RustBlockTransferHandler> { pub fn get_handler(&self) -> Arc<RustBlockTransferHandler> {
self._impl.clone() self._impl.clone()
} }
} }
......
...@@ -9,7 +9,7 @@ mod leader; ...@@ -9,7 +9,7 @@ mod leader;
mod worker; mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig}; pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::{BlockTransferHandler, BlockTransferHandlerV1, BlockTransferHandlerV2}; pub use transfer::BlockTransferHandler;
pub use utils::{ pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType, BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
}; };
......
...@@ -11,10 +11,9 @@ use zmq::*; ...@@ -11,10 +11,9 @@ use zmq::*;
use BlockTransferPool::*; use BlockTransferPool::*;
use crate::block_manager::{ use crate::block_manager::{
Storage, BasicMetadata, Storage,
block::{ block::{
BasicMetadata, Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock,
WritableBlock,
data::local::LocalBlockData, data::local::LocalBlockData,
locality, locality,
transfer::{TransferContext, WriteTo, WriteToStrategy}, transfer::{TransferContext, WriteTo, WriteToStrategy},
...@@ -22,10 +21,6 @@ use crate::block_manager::{ ...@@ -22,10 +21,6 @@ use crate::block_manager::{
connector::scheduler::{SchedulingDecision, TransferSchedulerClient}, connector::scheduler::{SchedulingDecision, TransferSchedulerClient},
offload::MAX_TRANSFER_BATCH_SIZE, offload::MAX_TRANSFER_BATCH_SIZE,
storage::{DeviceStorage, DiskStorage, Local, PinnedStorage}, storage::{DeviceStorage, DiskStorage, Local, PinnedStorage},
v2::physical::{
layout::PhysicalLayout, manager::TransportManager, transfer::LayoutHandle,
transfer::options::TransferOptions,
},
}; };
use anyhow::Result; use anyhow::Result;
...@@ -49,9 +44,9 @@ impl ConnectorTransferBatcher { ...@@ -49,9 +44,9 @@ impl ConnectorTransferBatcher {
} }
} }
pub async fn execute_batched_transfer<T: BlockTransferDirectHandler>( pub async fn execute_batched_transfer(
&self, &self,
handler: &T, handler: &BlockTransferHandler,
request: BlockTransferRequest, request: BlockTransferRequest,
) -> Result<()> { ) -> Result<()> {
let blocks = request.blocks(); let blocks = request.blocks();
...@@ -88,21 +83,9 @@ impl ConnectorTransferBatcher { ...@@ -88,21 +83,9 @@ impl ConnectorTransferBatcher {
} }
} }
#[async_trait]
pub trait BlockTransferHandler: Send + Sync {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()>;
fn scheduler_client(&self) -> Option<TransferSchedulerClient>;
}
#[async_trait]
pub trait BlockTransferDirectHandler {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()>;
}
/// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s. /// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s.
#[derive(Clone)] #[derive(Clone)]
pub struct BlockTransferHandlerV1 { pub struct BlockTransferHandler {
device: Option<LocalBlockDataList<DeviceStorage>>, device: Option<LocalBlockDataList<DeviceStorage>>,
host: Option<LocalBlockDataList<PinnedStorage>>, host: Option<LocalBlockDataList<PinnedStorage>>,
disk: Option<LocalBlockDataList<DiskStorage>>, disk: Option<LocalBlockDataList<DiskStorage>>,
...@@ -112,46 +95,7 @@ pub struct BlockTransferHandlerV1 { ...@@ -112,46 +95,7 @@ pub struct BlockTransferHandlerV1 {
// add worker-connector scheduler client here // add worker-connector scheduler client here
} }
#[async_trait] impl BlockTransferHandler {
impl BlockTransferHandler for BlockTransferHandlerV1 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
fn scheduler_client(&self) -> Option<TransferSchedulerClient> {
self.scheduler_client.clone()
}
}
#[async_trait]
impl BlockTransferDirectHandler for BlockTransferHandlerV1 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}");
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
_ => {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
}?;
notify.await?;
Ok(())
}
}
impl BlockTransferHandlerV1 {
pub fn new( pub fn new(
device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>, device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>,
host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>, host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>,
...@@ -234,94 +178,41 @@ impl BlockTransferHandlerV1 { ...@@ -234,94 +178,41 @@ impl BlockTransferHandlerV1 {
} }
} }
} }
}
#[derive(Clone)]
pub struct BlockTransferHandlerV2 {
device_handle: Option<LayoutHandle>,
host_handle: Option<LayoutHandle>,
disk_handle: Option<LayoutHandle>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
}
impl BlockTransferHandlerV2 { /// Execute transfer with batching to prevent resource exhaustion
pub fn new( pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
device_layout: Option<PhysicalLayout>,
host_layout: Option<PhysicalLayout>,
disk_layout: Option<PhysicalLayout>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
) -> Result<Self> {
Ok(Self {
device_handle: device_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
host_handle: host_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
disk_handle: disk_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
transport_manager,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
})
}
}
#[async_trait]
impl BlockTransferHandler for BlockTransferHandlerV2 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await self.batcher.execute_batched_transfer(self, request).await
} }
fn scheduler_client(&self) -> Option<TransferSchedulerClient> { /// Execute transfer directly without batching (used by the batcher)
self.scheduler_client.clone() pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
} tracing::debug!(
} "Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
#[async_trait] tracing::debug!("request: {request:#?}");
impl BlockTransferDirectHandler for BlockTransferHandlerV2 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
let (src, dst) = match (request.from_pool(), request.to_pool()) {
(Device, Host) => (self.device_handle.as_ref(), self.host_handle.as_ref()),
(Device, Disk) => (self.device_handle.as_ref(), self.disk_handle.as_ref()),
(Host, Device) => (self.host_handle.as_ref(), self.device_handle.as_ref()),
(Host, Disk) => (self.host_handle.as_ref(), self.disk_handle.as_ref()),
(Disk, Device) => (self.disk_handle.as_ref(), self.device_handle.as_ref()),
_ => return Err(anyhow::anyhow!("Invalid transfer type.")),
};
if let (Some(src), Some(dst)) = (src, dst) { let notify = match (request.from_pool(), request.to_pool()) {
let src_block_ids = request (Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
.blocks() (Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
.iter() (Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
.map(|(from, _)| *from) (Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
.collect::<Vec<_>>(); (Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
let dst_block_ids = request _ => {
.blocks() return Err(anyhow::anyhow!("Invalid transfer type."));
.iter() }
.map(|(_, to)| *to) }?;
.collect::<Vec<_>>();
self.transport_manager
.execute_transfer(
*src,
&src_block_ids,
*dst,
&dst_block_ids,
TransferOptions::default(),
)?
.await?;
} else {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
notify.await?;
Ok(()) Ok(())
} }
} }
#[async_trait] #[async_trait]
impl<T: ?Sized + BlockTransferHandler> Handler for T { impl Handler for BlockTransferHandler {
async fn handle(&self, mut message: MessageHandle) -> Result<()> { async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if message.data.len() != 1 { if message.data.len() != 1 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
...@@ -341,8 +232,10 @@ impl<T: ?Sized + BlockTransferHandler> Handler for T { ...@@ -341,8 +232,10 @@ impl<T: ?Sized + BlockTransferHandler> Handler for T {
); );
let client = self let client = self
.scheduler_client() .scheduler_client
.expect("scheduler client is required"); .as_ref()
.expect("scheduler client is required")
.clone();
let handle = client.schedule_transfer(req).await?; let handle = client.schedule_transfer(req).await?;
......
...@@ -18,12 +18,6 @@ use crate::block_manager::{ ...@@ -18,12 +18,6 @@ use crate::block_manager::{
layout::LayoutType, layout::LayoutType,
offload::{MAX_CONCURRENT_TRANSFERS, MAX_TRANSFER_BATCH_SIZE}, offload::{MAX_CONCURRENT_TRANSFERS, MAX_TRANSFER_BATCH_SIZE},
storage::{DeviceAllocator, DeviceStorage, DiskAllocator, PinnedAllocator, torch::TorchTensor}, storage::{DeviceAllocator, DeviceStorage, DiskAllocator, PinnedAllocator, torch::TorchTensor},
v2::memory::DeviceStorage as DeviceStorageV2,
v2::physical::{
layout::{BlockDimension, LayoutConfig as LayoutConfigV2, builder::PhysicalLayoutBuilder},
manager::TransportManager,
transfer::{NixlAgent as NixlAgentV2, TransferCapabilities},
},
}; };
use derive_builder::Builder; use derive_builder::Builder;
...@@ -117,162 +111,70 @@ async fn perform_allocation_and_build_handler( ...@@ -117,162 +111,70 @@ async fn perform_allocation_and_build_handler(
worker_id: usize, worker_id: usize,
device_id: usize, device_id: usize,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<Arc<dyn BlockTransferHandler>> { ) -> anyhow::Result<BlockTransferHandler> {
let use_v2_transfer = std::env::var("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL") let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?;
.unwrap_or("0".to_string()) let pool_config = PoolConfig {
.parse::<usize>() enable_pool: true,
.map(|v| v > 0) max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS,
.unwrap_or(false); max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE,
num_outer_components: device_layout.config().outer_dim,
if use_v2_transfer { num_layers: device_layout.config().num_layers,
tracing::warn!("Using V2 transfer handler. This is experimental. Use at your own risk."); };
let backends = if leader_meta.num_disk_blocks > 0 { let transfer_context = Arc::new(TransferContext::new(
vec!["POSIX", "GDS_MT"] Arc::new(Some(agent)),
} else { DeviceAllocator::new(device_id)?.ctx().new_stream()?,
vec!["POSIX"] Handle::current(),
}; Some(pool_config),
));
let agent = NixlAgentV2::new_with_backends(worker_id.to_string().as_str(), &backends)?;
// device
let mut layout_config = LayoutConfigV2::builder() let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
.num_blocks(device_layout.config().num_blocks) device_layout,
.num_layers(device_layout.config().num_layers) transfer_context.nixl_agent().as_ref(),
.outer_dim(device_layout.config().outer_dim) 0,
.inner_dim(device_layout.config().inner_dim) worker_id,
.page_size(device_layout.config().page_size) )?);
.alignment(device_layout.config().alignment) // host
.dtype_width_bytes(device_layout.config().dtype_width_bytes) let host_blocks = if leader_meta.num_host_blocks > 0 {
.build()?; let host_allocator = Arc::new(PinnedAllocator::default());
let host_layout = layout_builder
let v2_device_layout = .num_blocks(leader_meta.num_host_blocks)
PhysicalLayoutBuilder::new(agent.clone()).with_config(layout_config.clone()); .build()?
.allocate_layout(worker_config.host_layout_type, host_allocator)?;
let v2_device_layout = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
if let LayoutType::LayerSeparate { outer_contiguous } = device_layout.layout_type() {
v2_device_layout.layer_separate(if outer_contiguous {
BlockDimension::BlockIsSecondDim
} else {
BlockDimension::BlockIsFirstDim
})
} else {
v2_device_layout.fully_contiguous()
};
let regions = device_layout
.storage()
.iter()
.map(|s| DeviceStorageV2::from_v1(s).unwrap())
.collect::<Vec<_>>();
let v2_device_layout = v2_device_layout.with_memory_regions(regions)?.build()?;
let host_layout = if leader_meta.num_host_blocks > 0 {
layout_config.num_blocks = leader_meta.num_host_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config.clone())
.fully_contiguous()
.allocate_pinned(true)
.build()?,
)
} else {
None
};
let disk_layout = if leader_meta.num_disk_blocks > 0 {
layout_config.num_blocks = leader_meta.num_disk_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config)
.fully_contiguous()
.allocate_disk(None)
.build()?,
)
} else {
None
};
let transport_manager = TransportManager::builder()
.capabilities(TransferCapabilities::default().with_gds(true))
.worker_id(worker_id as u64)
.nixl_agent(agent)
.cuda_device_id(device_id)
.build()?;
let handler = BlockTransferHandlerV2::new(
Some(v2_device_layout),
host_layout, host_layout,
disk_layout, transfer_context.nixl_agent().as_ref(),
transport_manager, 1,
scheduler_client, worker_id,
)?; )?)
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
} else { } else {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?; None
let pool_config = PoolConfig { };
enable_pool: true, // disk
max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS, let disk_blocks = if leader_meta.num_disk_blocks > 0 {
max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, let disk_allocator = Arc::new(DiskAllocator);
num_outer_components: device_layout.config().outer_dim, let disk_layout = layout_builder
num_layers: device_layout.config().num_layers, .num_blocks(leader_meta.num_disk_blocks)
}; .build()?
let transfer_context = Arc::new(TransferContext::new( .allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Arc::new(Some(agent)), Some(KvbmWorker::make_layout::<_, BasicMetadata>(
DeviceAllocator::new(device_id)?.ctx().new_stream()?, disk_layout,
Handle::current(),
Some(pool_config),
));
// device
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
device_layout,
transfer_context.nixl_agent().as_ref(), transfer_context.nixl_agent().as_ref(),
0, 2,
worker_id, worker_id,
)?); )?)
// host } else {
let host_blocks = if leader_meta.num_host_blocks > 0 { None
let host_allocator = Arc::new(PinnedAllocator::default()); };
let host_layout = layout_builder
.num_blocks(leader_meta.num_host_blocks) let handler = BlockTransferHandler::new(
.build()? device_blocks,
.allocate_layout(worker_config.host_layout_type, host_allocator)?; host_blocks,
Some(KvbmWorker::make_layout::<_, BasicMetadata>( disk_blocks,
host_layout, transfer_context,
transfer_context.nixl_agent().as_ref(), scheduler_client,
1, )?;
worker_id, Ok(handler)
)?)
} else {
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
.build()?
.allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
disk_layout,
transfer_context.nixl_agent().as_ref(),
2,
worker_id,
)?)
} else {
None
};
let handler = BlockTransferHandlerV1::new(
device_blocks,
host_blocks,
disk_blocks,
transfer_context,
scheduler_client,
)?;
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
}
} }
struct WorkerMetadataHandler { struct WorkerMetadataHandler {
...@@ -297,8 +199,6 @@ impl Handler for WorkerMetadataHandler { ...@@ -297,8 +199,6 @@ impl Handler for WorkerMetadataHandler {
} }
} }
type TransferHandlerSender = Mutex<Option<oneshot::Sender<Arc<dyn BlockTransferHandler>>>>;
// Leader sends allocation config -> allocate -> publish handler -> mark ready -> ACK // Leader sends allocation config -> allocate -> publish handler -> mark ready -> ACK
struct LeaderMetadataHandler { struct LeaderMetadataHandler {
state: Arc<WorkerState>, state: Arc<WorkerState>,
...@@ -308,8 +208,8 @@ struct LeaderMetadataHandler { ...@@ -308,8 +208,8 @@ struct LeaderMetadataHandler {
worker_id: usize, worker_id: usize,
device_id: usize, device_id: usize,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>, handler_cell: Arc<RwLock<Option<BlockTransferHandler>>>,
handler_tx: Arc<TransferHandlerSender>, handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>,
started: AtomicBool, started: AtomicBool,
} }
...@@ -447,7 +347,7 @@ impl Handler for GatedPing { ...@@ -447,7 +347,7 @@ impl Handler for GatedPing {
// Transfer dispatcher that waits until block transfer handler exists // Transfer dispatcher that waits until block transfer handler exists
struct BlockTransferDispatch { struct BlockTransferDispatch {
cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>, cell: Arc<RwLock<Option<BlockTransferHandler>>>,
} }
#[async_trait] #[async_trait]
...@@ -508,7 +408,7 @@ impl KvbmWorkerConfig { ...@@ -508,7 +408,7 @@ impl KvbmWorkerConfig {
pub struct KvbmWorker { pub struct KvbmWorker {
task: Option<CriticalTaskExecutionHandle>, task: Option<CriticalTaskExecutionHandle>,
block_transfer_handler_rx: Option<oneshot::Receiver<Arc<dyn BlockTransferHandler>>>, block_transfer_handler_rx: Option<oneshot::Receiver<transfer::BlockTransferHandler>>,
} }
impl KvbmWorker { impl KvbmWorker {
...@@ -632,7 +532,7 @@ impl KvbmWorker { ...@@ -632,7 +532,7 @@ impl KvbmWorker {
layout_type: LayoutType, layout_type: LayoutType,
) -> anyhow::Result<( ) -> anyhow::Result<(
CriticalTaskExecutionHandle, CriticalTaskExecutionHandle,
oneshot::Receiver<Arc<dyn BlockTransferHandler>>, oneshot::Receiver<transfer::BlockTransferHandler>,
)> { )> {
let cancel_token = config.cancel_token.clone(); let cancel_token = config.cancel_token.clone();
...@@ -683,13 +583,13 @@ impl KvbmWorker { ...@@ -683,13 +583,13 @@ impl KvbmWorker {
layout_type: LayoutType, layout_type: LayoutType,
) -> anyhow::Result<( ) -> anyhow::Result<(
CriticalTaskExecutionHandle, CriticalTaskExecutionHandle,
oneshot::Receiver<Arc<dyn BlockTransferHandler>>, oneshot::Receiver<transfer::BlockTransferHandler>,
)> { )> {
let cancel_token = config.cancel_token.clone(); let cancel_token = config.cancel_token.clone();
let scheduler_client = config.scheduler_client.clone(); let scheduler_client = config.scheduler_client.clone();
// channel to get BlockTransferHandler back to the caller // channel to get BlockTransferHandler back to the caller
let (handler_tx, handler_rx) = oneshot::channel::<Arc<dyn BlockTransferHandler>>(); let (handler_tx, handler_rx) = oneshot::channel::<transfer::BlockTransferHandler>();
let handler_tx_cell = Arc::new(Mutex::new(Some(handler_tx))); let handler_tx_cell = Arc::new(Mutex::new(Some(handler_tx)));
// channel that the worker will use to signal layout readiness // channel that the worker will use to signal layout readiness
...@@ -752,7 +652,7 @@ impl KvbmWorker { ...@@ -752,7 +652,7 @@ impl KvbmWorker {
/// This is a bit of a hack. Improve the API design around this in the future. /// This is a bit of a hack. Improve the API design around this in the future.
pub fn block_transfer_handler_rx( pub fn block_transfer_handler_rx(
&mut self, &mut self,
) -> Option<tokio::sync::oneshot::Receiver<Arc<dyn BlockTransferHandler>>> { ) -> Option<tokio::sync::oneshot::Receiver<BlockTransferHandler>> {
self.block_transfer_handler_rx.take() self.block_transfer_handler_rx.take()
} }
...@@ -780,7 +680,7 @@ impl KvbmWorker { ...@@ -780,7 +680,7 @@ impl KvbmWorker {
_device_layout_type: LayoutType, _device_layout_type: LayoutType,
config: KvbmWorkerConfig, config: KvbmWorkerConfig,
cancel_token: CancellationToken, cancel_token: CancellationToken,
handler_tx: Arc<TransferHandlerSender>, handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>,
layout_ready_tx: tokio::sync::Mutex<Option<oneshot::Sender<String>>>, layout_ready_tx: tokio::sync::Mutex<Option<oneshot::Sender<String>>>,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
bytes_per_block: usize, bytes_per_block: usize,
...@@ -790,7 +690,7 @@ impl KvbmWorker { ...@@ -790,7 +690,7 @@ impl KvbmWorker {
let state = Arc::new(WorkerState::new()); let state = Arc::new(WorkerState::new());
// Cell to publish the transfer handler // Cell to publish the transfer handler
let transfer_handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>> = let transfer_handler_cell: Arc<RwLock<Option<BlockTransferHandler>>> =
Arc::new(RwLock::new(None)); Arc::new(RwLock::new(None));
// Build handlers map // Build handlers map
......
...@@ -337,8 +337,8 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator { ...@@ -337,8 +337,8 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator {
/// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that /// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that
/// the torch tensor is not GCed until the [`DeviceStorage`] is dropped. /// the torch tensor is not GCed until the [`DeviceStorage`] is dropped.
/// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`] /// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`]
#[derive(Clone, Debug)] #[derive(Debug)]
pub enum DeviceStorageType { enum DeviceStorageType {
Owned, // Memory that we allocated ourselves. Owned, // Memory that we allocated ourselves.
Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor. Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor.
} }
...@@ -350,7 +350,7 @@ pub struct DeviceStorage { ...@@ -350,7 +350,7 @@ pub struct DeviceStorage {
size: usize, size: usize,
ctx: Arc<CudaContext>, ctx: Arc<CudaContext>,
handles: RegistrationHandles, handles: RegistrationHandles,
storage_type: DeviceStorageType, _storage_type: DeviceStorageType,
} }
impl Local for DeviceStorage {} impl Local for DeviceStorage {}
...@@ -367,7 +367,7 @@ impl DeviceStorage { ...@@ -367,7 +367,7 @@ impl DeviceStorage {
size, size,
ctx: ctx.clone(), ctx: ctx.clone(),
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
storage_type: DeviceStorageType::Owned, _storage_type: DeviceStorageType::Owned,
}) })
} }
...@@ -395,7 +395,7 @@ impl DeviceStorage { ...@@ -395,7 +395,7 @@ impl DeviceStorage {
size, size,
ctx: ctx.clone(), ctx: ctx.clone(),
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
storage_type: DeviceStorageType::Torch { _tensor: tensor }, _storage_type: DeviceStorageType::Torch { _tensor: tensor },
}) })
} }
...@@ -403,10 +403,6 @@ impl DeviceStorage { ...@@ -403,10 +403,6 @@ impl DeviceStorage {
pub fn context(&self) -> &Arc<CudaContext> { pub fn context(&self) -> &Arc<CudaContext> {
&self.ctx &self.ctx
} }
pub fn device_storage_type(&self) -> &DeviceStorageType {
&self.storage_type
}
} }
impl Storage for DeviceStorage { impl Storage for DeviceStorage {
...@@ -440,7 +436,7 @@ impl CudaContextProivder for DeviceStorage { ...@@ -440,7 +436,7 @@ impl CudaContextProivder for DeviceStorage {
impl Drop for DeviceStorage { impl Drop for DeviceStorage {
fn drop(&mut self) { fn drop(&mut self) {
self.handles.release(); self.handles.release();
match &self.storage_type { match &self._storage_type {
DeviceStorageType::Owned => { DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap() unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
} }
......
...@@ -3,13 +3,8 @@ ...@@ -3,13 +3,8 @@
//! CUDA device memory storage. //! CUDA device memory storage.
use crate::block_manager::DeviceStorage as V1DeviceStorage;
use crate::block_manager::Storage as V1Storage;
use crate::block_manager::storage::cuda::DeviceStorageType as V1DeviceStorageType;
use super::{MemoryRegion, Result, StorageError, StorageKind}; use super::{MemoryRegion, Result, StorageError, StorageKind};
use cudarc::driver::CudaContext; use cudarc::driver::CudaContext;
use nixl_sys::NixlDescriptor;
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
...@@ -35,8 +30,6 @@ pub struct DeviceStorage { ...@@ -35,8 +30,6 @@ pub struct DeviceStorage {
ptr: u64, ptr: u64,
device_id: u32, device_id: u32,
len: usize, len: usize,
// TODO: This is a bit ugly. We need to translate our v1 device layout to v2.
device_storage_type: V1DeviceStorageType,
} }
unsafe impl Send for DeviceStorage {} unsafe impl Send for DeviceStorage {}
...@@ -64,7 +57,6 @@ impl DeviceStorage { ...@@ -64,7 +57,6 @@ impl DeviceStorage {
ptr, ptr,
device_id, device_id,
len, len,
device_storage_type: V1DeviceStorageType::Owned,
}) })
} }
...@@ -77,51 +69,18 @@ impl DeviceStorage { ...@@ -77,51 +69,18 @@ impl DeviceStorage {
pub fn device_id(&self) -> u32 { pub fn device_id(&self) -> u32 {
self.device_id self.device_id
} }
pub fn from_v1(v1_storage: &V1DeviceStorage) -> Result<Self> {
let device_id = v1_storage.device_id() as u32;
let ctx = cuda_context(device_id)?;
let ptr;
unsafe {
ptr = v1_storage.as_ptr() as u64;
}
let len = v1_storage.size();
if !matches!(
v1_storage.device_storage_type(),
V1DeviceStorageType::Torch { .. }
) {
return Err(StorageError::Unsupported(
"Unable to convert owned device tensors.".into(),
));
}
Ok(Self {
ctx,
ptr,
device_id,
len,
device_storage_type: v1_storage.device_storage_type().clone(),
})
}
} }
impl Drop for DeviceStorage { impl Drop for DeviceStorage {
fn drop(&mut self) { fn drop(&mut self) {
match self.device_storage_type { if let Err(e) = self.ctx.bind_to_thread() {
V1DeviceStorageType::Owned => { tracing::debug!("failed to bind CUDA context for free: {e}");
if let Err(e) = self.ctx.bind_to_thread() {
tracing::debug!("failed to bind CUDA context for free: {e}");
}
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
}
}
V1DeviceStorageType::Torch { .. } => {} // Do nothing.
} }
unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
};
} }
} }
......
...@@ -228,7 +228,6 @@ markers = [ ...@@ -228,7 +228,6 @@ markers = [
"router: marks tests for router component", "router: marks tests for router component",
"planner: marks tests for planner component", "planner: marks tests for planner component",
"kvbm: marks tests for KV behavior and model determinism", "kvbm: marks tests for KV behavior and model determinism",
"kvbm_v2: marks tests using KVBM V2",
"model: model id used by a test or parameter", "model: model id used by a test or parameter",
"custom_build: marks tests that require custom builds or special setup (e.g., MoE models)", "custom_build: marks tests that require custom builds or special setup (e.g., MoE models)",
"k8s: marks tests as requiring Kubernetes", "k8s: marks tests as requiring Kubernetes",
......
...@@ -183,20 +183,6 @@ class LLMServerManager: ...@@ -183,20 +183,6 @@ class LLMServerManager:
) )
self.server_stdout_file.flush() self.server_stdout_file.flush()
# Try to download the model.
model = os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
print("Attempting model download...")
try:
subprocess.run(
f"pip install hf_transfer && HF_HUB_ENABLE_HF_TRANSFER=1 hf download {model}",
check=True,
shell=True,
)
except subprocess.CalledProcessError:
print("Model download failed. Is this a locally stored model?")
# Launch # Launch
self.process = subprocess.Popen( self.process = subprocess.Popen(
self.server_cmd, self.server_cmd,
...@@ -349,7 +335,7 @@ def llm_server(request, runtime_services): ...@@ -349,7 +335,7 @@ def llm_server(request, runtime_services):
server_type=server_type, server_type=server_type,
) )
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600")) start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout): if not server_manager.start_server(timeout=start_timeout):
pytest.fail( pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
...@@ -390,24 +376,6 @@ class TestDeterminismAgg(BaseTestDeterminism): ...@@ -390,24 +376,6 @@ class TestDeterminismAgg(BaseTestDeterminism):
tester, llm_server, runtime_services tester, llm_server, runtime_services
) )
@pytest.mark.parametrize(
"llm_server",
[
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_agg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester, llm_server, runtime_services
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"llm_server", "llm_server",
[ [
......
...@@ -430,7 +430,7 @@ def llm_server(request, runtime_services): ...@@ -430,7 +430,7 @@ def llm_server(request, runtime_services):
server_type=server_type, server_type=server_type,
) )
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600")) start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300"))
if not server_manager.start_server(timeout=start_timeout): if not server_manager.start_server(timeout=start_timeout):
pytest.fail( pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
...@@ -477,30 +477,6 @@ class TestDeterminismDisagg(BaseTestDeterminism): ...@@ -477,30 +477,6 @@ class TestDeterminismDisagg(BaseTestDeterminism):
success_rate_threshold=SUCCESS_RATE_THRESHOLD, success_rate_threshold=SUCCESS_RATE_THRESHOLD,
) )
@pytest.mark.parametrize(
"llm_server",
[
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "1000")),
},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_disagg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester,
llm_server,
runtime_services,
success_rate_threshold=SUCCESS_RATE_THRESHOLD,
)
if __name__ == "__main__": if __name__ == "__main__":
# Allow running as script # Allow running as script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment