"docs/vscode:/vscode.git/clone" did not exist on "ba51aea65e0d2a7afca3f25caba01500fa84d655"
Unverified Commit 827b8c3e authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KVBM V2 transfer (#4068)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent 38242c8d
...@@ -118,11 +118,11 @@ impl TorchTensor for VllmTensor { ...@@ -118,11 +118,11 @@ impl TorchTensor for VllmTensor {
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub struct BlockTransferHandler { pub struct BlockTransferHandler {
_impl: Arc<RustBlockTransferHandler>, _impl: Arc<dyn RustBlockTransferHandler>,
} }
impl BlockTransferHandler { impl BlockTransferHandler {
pub fn get_handler(&self) -> Arc<RustBlockTransferHandler> { pub fn get_handler(&self) -> Arc<dyn RustBlockTransferHandler> {
self._impl.clone() self._impl.clone()
} }
} }
......
...@@ -9,7 +9,7 @@ mod leader; ...@@ -9,7 +9,7 @@ mod leader;
mod worker; mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig}; pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::BlockTransferHandler; pub use transfer::{BlockTransferHandler, BlockTransferHandlerV1, BlockTransferHandlerV2};
pub use utils::{ pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType, BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
}; };
......
...@@ -11,9 +11,10 @@ use zmq::*; ...@@ -11,9 +11,10 @@ use zmq::*;
use BlockTransferPool::*; use BlockTransferPool::*;
use crate::block_manager::{ use crate::block_manager::{
BasicMetadata, Storage, Storage,
block::{ block::{
Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock, BasicMetadata, Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock,
WritableBlock,
data::local::LocalBlockData, data::local::LocalBlockData,
locality, locality,
transfer::{TransferContext, WriteTo, WriteToStrategy}, transfer::{TransferContext, WriteTo, WriteToStrategy},
...@@ -21,6 +22,10 @@ use crate::block_manager::{ ...@@ -21,6 +22,10 @@ use crate::block_manager::{
connector::scheduler::{SchedulingDecision, TransferSchedulerClient}, connector::scheduler::{SchedulingDecision, TransferSchedulerClient},
offload::MAX_TRANSFER_BATCH_SIZE, offload::MAX_TRANSFER_BATCH_SIZE,
storage::{DeviceStorage, DiskStorage, Local, PinnedStorage}, storage::{DeviceStorage, DiskStorage, Local, PinnedStorage},
v2::physical::{
layout::PhysicalLayout, manager::TransportManager, transfer::LayoutHandle,
transfer::options::TransferOptions,
},
}; };
use anyhow::Result; use anyhow::Result;
...@@ -44,9 +49,9 @@ impl ConnectorTransferBatcher { ...@@ -44,9 +49,9 @@ impl ConnectorTransferBatcher {
} }
} }
pub async fn execute_batched_transfer( pub async fn execute_batched_transfer<T: BlockTransferDirectHandler>(
&self, &self,
handler: &BlockTransferHandler, handler: &T,
request: BlockTransferRequest, request: BlockTransferRequest,
) -> Result<()> { ) -> Result<()> {
let blocks = request.blocks(); let blocks = request.blocks();
...@@ -83,9 +88,21 @@ impl ConnectorTransferBatcher { ...@@ -83,9 +88,21 @@ impl ConnectorTransferBatcher {
} }
} }
#[async_trait]
pub trait BlockTransferHandler: Send + Sync {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()>;
fn scheduler_client(&self) -> Option<TransferSchedulerClient>;
}
#[async_trait]
pub trait BlockTransferDirectHandler {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()>;
}
/// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s. /// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s.
#[derive(Clone)] #[derive(Clone)]
pub struct BlockTransferHandler { pub struct BlockTransferHandlerV1 {
device: Option<LocalBlockDataList<DeviceStorage>>, device: Option<LocalBlockDataList<DeviceStorage>>,
host: Option<LocalBlockDataList<PinnedStorage>>, host: Option<LocalBlockDataList<PinnedStorage>>,
disk: Option<LocalBlockDataList<DiskStorage>>, disk: Option<LocalBlockDataList<DiskStorage>>,
...@@ -95,7 +112,46 @@ pub struct BlockTransferHandler { ...@@ -95,7 +112,46 @@ pub struct BlockTransferHandler {
// add worker-connector scheduler client here // add worker-connector scheduler client here
} }
impl BlockTransferHandler { #[async_trait]
impl BlockTransferHandler for BlockTransferHandlerV1 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await
}
fn scheduler_client(&self) -> Option<TransferSchedulerClient> {
self.scheduler_client.clone()
}
}
#[async_trait]
impl BlockTransferDirectHandler for BlockTransferHandlerV1 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}");
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await,
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await,
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await,
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await,
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await,
_ => {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
}?;
notify.await?;
Ok(())
}
}
impl BlockTransferHandlerV1 {
pub fn new( pub fn new(
device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>, device_blocks: Option<Vec<LocalBlock<DeviceStorage, BasicMetadata>>>,
host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>, host_blocks: Option<Vec<LocalBlock<PinnedStorage, BasicMetadata>>>,
...@@ -178,41 +234,94 @@ impl BlockTransferHandler { ...@@ -178,41 +234,94 @@ impl BlockTransferHandler {
} }
} }
} }
}
#[derive(Clone)]
pub struct BlockTransferHandlerV2 {
device_handle: Option<LayoutHandle>,
host_handle: Option<LayoutHandle>,
disk_handle: Option<LayoutHandle>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
}
/// Execute transfer with batching to prevent resource exhaustion impl BlockTransferHandlerV2 {
pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> { pub fn new(
device_layout: Option<PhysicalLayout>,
host_layout: Option<PhysicalLayout>,
disk_layout: Option<PhysicalLayout>,
transport_manager: TransportManager,
scheduler_client: Option<TransferSchedulerClient>,
) -> Result<Self> {
Ok(Self {
device_handle: device_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
host_handle: host_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
disk_handle: disk_layout
.map(|layout| transport_manager.register_layout(layout).unwrap()),
transport_manager,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
})
}
}
#[async_trait]
impl BlockTransferHandler for BlockTransferHandlerV2 {
async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> {
self.batcher.execute_batched_transfer(self, request).await self.batcher.execute_batched_transfer(self, request).await
} }
/// Execute transfer directly without batching (used by the batcher) fn scheduler_client(&self) -> Option<TransferSchedulerClient> {
pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> { self.scheduler_client.clone()
tracing::debug!( }
"Performing transfer of {} blocks from {:?} to {:?}", }
request.blocks().len(),
request.from_pool(),
request.to_pool()
);
tracing::debug!("request: {request:#?}"); #[async_trait]
impl BlockTransferDirectHandler for BlockTransferHandlerV2 {
async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
let (src, dst) = match (request.from_pool(), request.to_pool()) {
(Device, Host) => (self.device_handle.as_ref(), self.host_handle.as_ref()),
(Device, Disk) => (self.device_handle.as_ref(), self.disk_handle.as_ref()),
(Host, Device) => (self.host_handle.as_ref(), self.device_handle.as_ref()),
(Host, Disk) => (self.host_handle.as_ref(), self.disk_handle.as_ref()),
(Disk, Device) => (self.disk_handle.as_ref(), self.device_handle.as_ref()),
_ => return Err(anyhow::anyhow!("Invalid transfer type.")),
};
let notify = match (request.from_pool(), request.to_pool()) { if let (Some(src), Some(dst)) = (src, dst) {
(Device, Host) => self.begin_transfer(&self.device, &self.host, request).await, let src_block_ids = request
(Device, Disk) => self.begin_transfer(&self.device, &self.disk, request).await, .blocks()
(Host, Device) => self.begin_transfer(&self.host, &self.device, request).await, .iter()
(Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await, .map(|(from, _)| *from)
(Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await, .collect::<Vec<_>>();
_ => { let dst_block_ids = request
return Err(anyhow::anyhow!("Invalid transfer type.")); .blocks()
} .iter()
}?; .map(|(_, to)| *to)
.collect::<Vec<_>>();
self.transport_manager
.execute_transfer(
*src,
&src_block_ids,
*dst,
&dst_block_ids,
TransferOptions::default(),
)?
.await?;
} else {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
notify.await?;
Ok(()) Ok(())
} }
} }
#[async_trait] #[async_trait]
impl Handler for BlockTransferHandler { impl<T: ?Sized + BlockTransferHandler> Handler for T {
async fn handle(&self, mut message: MessageHandle) -> Result<()> { async fn handle(&self, mut message: MessageHandle) -> Result<()> {
if message.data.len() != 1 { if message.data.len() != 1 {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
...@@ -232,10 +341,8 @@ impl Handler for BlockTransferHandler { ...@@ -232,10 +341,8 @@ impl Handler for BlockTransferHandler {
); );
let client = self let client = self
.scheduler_client .scheduler_client()
.as_ref() .expect("scheduler client is required");
.expect("scheduler client is required")
.clone();
let handle = client.schedule_transfer(req).await?; let handle = client.schedule_transfer(req).await?;
......
...@@ -18,6 +18,12 @@ use crate::block_manager::{ ...@@ -18,6 +18,12 @@ use crate::block_manager::{
layout::LayoutType, layout::LayoutType,
offload::{MAX_CONCURRENT_TRANSFERS, MAX_TRANSFER_BATCH_SIZE}, offload::{MAX_CONCURRENT_TRANSFERS, MAX_TRANSFER_BATCH_SIZE},
storage::{DeviceAllocator, DeviceStorage, DiskAllocator, PinnedAllocator, torch::TorchTensor}, storage::{DeviceAllocator, DeviceStorage, DiskAllocator, PinnedAllocator, torch::TorchTensor},
v2::memory::DeviceStorage as DeviceStorageV2,
v2::physical::{
layout::{BlockDimension, LayoutConfig as LayoutConfigV2, builder::PhysicalLayoutBuilder},
manager::TransportManager,
transfer::{NixlAgent as NixlAgentV2, TransferCapabilities},
},
}; };
use derive_builder::Builder; use derive_builder::Builder;
...@@ -111,70 +117,162 @@ async fn perform_allocation_and_build_handler( ...@@ -111,70 +117,162 @@ async fn perform_allocation_and_build_handler(
worker_id: usize, worker_id: usize,
device_id: usize, device_id: usize,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<BlockTransferHandler> { ) -> anyhow::Result<Arc<dyn BlockTransferHandler>> {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?; let use_v2_transfer = std::env::var("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL")
let pool_config = PoolConfig { .unwrap_or("0".to_string())
enable_pool: true, .parse::<usize>()
max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS, .map(|v| v > 0)
max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE, .unwrap_or(false);
num_outer_components: device_layout.config().outer_dim,
num_layers: device_layout.config().num_layers, if use_v2_transfer {
}; tracing::warn!("Using V2 transfer handler. This is experimental. Use at your own risk.");
let transfer_context = Arc::new(TransferContext::new( let backends = if leader_meta.num_disk_blocks > 0 {
Arc::new(Some(agent)), vec!["POSIX", "GDS_MT"]
DeviceAllocator::new(device_id)?.ctx().new_stream()?, } else {
Handle::current(), vec!["POSIX"]
Some(pool_config), };
));
let agent = NixlAgentV2::new_with_backends(worker_id.to_string().as_str(), &backends)?;
// device
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>( let mut layout_config = LayoutConfigV2::builder()
device_layout, .num_blocks(device_layout.config().num_blocks)
transfer_context.nixl_agent().as_ref(), .num_layers(device_layout.config().num_layers)
0, .outer_dim(device_layout.config().outer_dim)
worker_id, .inner_dim(device_layout.config().inner_dim)
)?); .page_size(device_layout.config().page_size)
// host .alignment(device_layout.config().alignment)
let host_blocks = if leader_meta.num_host_blocks > 0 { .dtype_width_bytes(device_layout.config().dtype_width_bytes)
let host_allocator = Arc::new(PinnedAllocator::default()); .build()?;
let host_layout = layout_builder
.num_blocks(leader_meta.num_host_blocks) let v2_device_layout =
.build()? PhysicalLayoutBuilder::new(agent.clone()).with_config(layout_config.clone());
.allocate_layout(worker_config.host_layout_type, host_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>( let v2_device_layout =
if let LayoutType::LayerSeparate { outer_contiguous } = device_layout.layout_type() {
v2_device_layout.layer_separate(if outer_contiguous {
BlockDimension::BlockIsSecondDim
} else {
BlockDimension::BlockIsFirstDim
})
} else {
v2_device_layout.fully_contiguous()
};
let regions = device_layout
.storage()
.iter()
.map(|s| DeviceStorageV2::from_v1(s).unwrap())
.collect::<Vec<_>>();
let v2_device_layout = v2_device_layout.with_memory_regions(regions)?.build()?;
let host_layout = if leader_meta.num_host_blocks > 0 {
layout_config.num_blocks = leader_meta.num_host_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config.clone())
.fully_contiguous()
.allocate_pinned(true)
.build()?,
)
} else {
None
};
let disk_layout = if leader_meta.num_disk_blocks > 0 {
layout_config.num_blocks = leader_meta.num_disk_blocks;
Some(
PhysicalLayoutBuilder::new(agent.clone())
.with_config(layout_config)
.fully_contiguous()
.allocate_disk(None)
.build()?,
)
} else {
None
};
let transport_manager = TransportManager::builder()
.capabilities(TransferCapabilities::default().with_gds(true))
.worker_id(worker_id as u64)
.nixl_agent(agent)
.cuda_device_id(device_id)
.build()?;
let handler = BlockTransferHandlerV2::new(
Some(v2_device_layout),
host_layout, host_layout,
transfer_context.nixl_agent().as_ref(),
1,
worker_id,
)?)
} else {
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
.build()?
.allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
disk_layout, disk_layout,
transport_manager,
scheduler_client,
)?;
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
} else {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?;
let pool_config = PoolConfig {
enable_pool: true,
max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS,
max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE,
num_outer_components: device_layout.config().outer_dim,
num_layers: device_layout.config().num_layers,
};
let transfer_context = Arc::new(TransferContext::new(
Arc::new(Some(agent)),
DeviceAllocator::new(device_id)?.ctx().new_stream()?,
Handle::current(),
Some(pool_config),
));
// device
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
device_layout,
transfer_context.nixl_agent().as_ref(), transfer_context.nixl_agent().as_ref(),
2, 0,
worker_id, worker_id,
)?) )?);
} else { // host
None let host_blocks = if leader_meta.num_host_blocks > 0 {
}; let host_allocator = Arc::new(PinnedAllocator::default());
let host_layout = layout_builder
let handler = BlockTransferHandler::new( .num_blocks(leader_meta.num_host_blocks)
device_blocks, .build()?
host_blocks, .allocate_layout(worker_config.host_layout_type, host_allocator)?;
disk_blocks, Some(KvbmWorker::make_layout::<_, BasicMetadata>(
transfer_context, host_layout,
scheduler_client, transfer_context.nixl_agent().as_ref(),
)?; 1,
Ok(handler) worker_id,
)?)
} else {
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
.build()?
.allocate_layout(worker_config.disk_layout_type, disk_allocator)?;
Some(KvbmWorker::make_layout::<_, BasicMetadata>(
disk_layout,
transfer_context.nixl_agent().as_ref(),
2,
worker_id,
)?)
} else {
None
};
let handler = BlockTransferHandlerV1::new(
device_blocks,
host_blocks,
disk_blocks,
transfer_context,
scheduler_client,
)?;
Ok(Arc::new(handler) as Arc<dyn BlockTransferHandler>)
}
} }
struct WorkerMetadataHandler { struct WorkerMetadataHandler {
...@@ -199,6 +297,8 @@ impl Handler for WorkerMetadataHandler { ...@@ -199,6 +297,8 @@ impl Handler for WorkerMetadataHandler {
} }
} }
type TransferHandlerSender = Mutex<Option<oneshot::Sender<Arc<dyn BlockTransferHandler>>>>;
// Leader sends allocation config -> allocate -> publish handler -> mark ready -> ACK // Leader sends allocation config -> allocate -> publish handler -> mark ready -> ACK
struct LeaderMetadataHandler { struct LeaderMetadataHandler {
state: Arc<WorkerState>, state: Arc<WorkerState>,
...@@ -208,8 +308,8 @@ struct LeaderMetadataHandler { ...@@ -208,8 +308,8 @@ struct LeaderMetadataHandler {
worker_id: usize, worker_id: usize,
device_id: usize, device_id: usize,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
handler_cell: Arc<RwLock<Option<BlockTransferHandler>>>, handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>,
handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>, handler_tx: Arc<TransferHandlerSender>,
started: AtomicBool, started: AtomicBool,
} }
...@@ -344,7 +444,7 @@ impl Handler for GatedPing { ...@@ -344,7 +444,7 @@ impl Handler for GatedPing {
// Transfer dispatcher that waits until block transfer handler exists // Transfer dispatcher that waits until block transfer handler exists
struct BlockTransferDispatch { struct BlockTransferDispatch {
cell: Arc<RwLock<Option<BlockTransferHandler>>>, cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>>,
} }
#[async_trait] #[async_trait]
...@@ -405,7 +505,7 @@ impl KvbmWorkerConfig { ...@@ -405,7 +505,7 @@ impl KvbmWorkerConfig {
pub struct KvbmWorker { pub struct KvbmWorker {
task: Option<CriticalTaskExecutionHandle>, task: Option<CriticalTaskExecutionHandle>,
block_transfer_handler_rx: Option<oneshot::Receiver<transfer::BlockTransferHandler>>, block_transfer_handler_rx: Option<oneshot::Receiver<Arc<dyn BlockTransferHandler>>>,
} }
impl KvbmWorker { impl KvbmWorker {
...@@ -529,7 +629,7 @@ impl KvbmWorker { ...@@ -529,7 +629,7 @@ impl KvbmWorker {
layout_type: LayoutType, layout_type: LayoutType,
) -> anyhow::Result<( ) -> anyhow::Result<(
CriticalTaskExecutionHandle, CriticalTaskExecutionHandle,
oneshot::Receiver<transfer::BlockTransferHandler>, oneshot::Receiver<Arc<dyn BlockTransferHandler>>,
)> { )> {
let cancel_token = config.cancel_token.clone(); let cancel_token = config.cancel_token.clone();
...@@ -580,13 +680,13 @@ impl KvbmWorker { ...@@ -580,13 +680,13 @@ impl KvbmWorker {
layout_type: LayoutType, layout_type: LayoutType,
) -> anyhow::Result<( ) -> anyhow::Result<(
CriticalTaskExecutionHandle, CriticalTaskExecutionHandle,
oneshot::Receiver<transfer::BlockTransferHandler>, oneshot::Receiver<Arc<dyn BlockTransferHandler>>,
)> { )> {
let cancel_token = config.cancel_token.clone(); let cancel_token = config.cancel_token.clone();
let scheduler_client = config.scheduler_client.clone(); let scheduler_client = config.scheduler_client.clone();
// channel to get BlockTransferHandler back to the caller // channel to get BlockTransferHandler back to the caller
let (handler_tx, handler_rx) = oneshot::channel::<transfer::BlockTransferHandler>(); let (handler_tx, handler_rx) = oneshot::channel::<Arc<dyn BlockTransferHandler>>();
let handler_tx_cell = Arc::new(Mutex::new(Some(handler_tx))); let handler_tx_cell = Arc::new(Mutex::new(Some(handler_tx)));
// channel that the worker will use to signal layout readiness // channel that the worker will use to signal layout readiness
...@@ -649,7 +749,7 @@ impl KvbmWorker { ...@@ -649,7 +749,7 @@ impl KvbmWorker {
/// This is a bit of a hack. Improve the API design around this in the future. /// This is a bit of a hack. Improve the API design around this in the future.
pub fn block_transfer_handler_rx( pub fn block_transfer_handler_rx(
&mut self, &mut self,
) -> Option<tokio::sync::oneshot::Receiver<BlockTransferHandler>> { ) -> Option<tokio::sync::oneshot::Receiver<Arc<dyn BlockTransferHandler>>> {
self.block_transfer_handler_rx.take() self.block_transfer_handler_rx.take()
} }
...@@ -677,7 +777,7 @@ impl KvbmWorker { ...@@ -677,7 +777,7 @@ impl KvbmWorker {
_device_layout_type: LayoutType, _device_layout_type: LayoutType,
config: KvbmWorkerConfig, config: KvbmWorkerConfig,
cancel_token: CancellationToken, cancel_token: CancellationToken,
handler_tx: Arc<Mutex<Option<oneshot::Sender<BlockTransferHandler>>>>, handler_tx: Arc<TransferHandlerSender>,
layout_ready_tx: tokio::sync::Mutex<Option<oneshot::Sender<String>>>, layout_ready_tx: tokio::sync::Mutex<Option<oneshot::Sender<String>>>,
scheduler_client: Option<TransferSchedulerClient>, scheduler_client: Option<TransferSchedulerClient>,
bytes_per_block: usize, bytes_per_block: usize,
...@@ -687,7 +787,7 @@ impl KvbmWorker { ...@@ -687,7 +787,7 @@ impl KvbmWorker {
let state = Arc::new(WorkerState::new()); let state = Arc::new(WorkerState::new());
// Cell to publish the transfer handler // Cell to publish the transfer handler
let transfer_handler_cell: Arc<RwLock<Option<BlockTransferHandler>>> = let transfer_handler_cell: Arc<RwLock<Option<Arc<dyn BlockTransferHandler>>>> =
Arc::new(RwLock::new(None)); Arc::new(RwLock::new(None));
// Build handlers map // Build handlers map
......
...@@ -315,8 +315,8 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator { ...@@ -315,8 +315,8 @@ impl StorageAllocator<PinnedStorage> for PinnedAllocator {
/// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that /// When building a [`DeviceStorage`] from a torch tensor, we need to ensure that
/// the torch tensor is not GCed until the [`DeviceStorage`] is dropped. /// the torch tensor is not GCed until the [`DeviceStorage`] is dropped.
/// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`] /// Because of this, we need to store a reference to the torch tensor in the [`DeviceStorage`]
#[derive(Debug)] #[derive(Clone, Debug)]
enum DeviceStorageType { pub enum DeviceStorageType {
Owned, // Memory that we allocated ourselves. Owned, // Memory that we allocated ourselves.
Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor. Torch { _tensor: Arc<dyn TorchTensor> }, // Memory that came from a torch tensor.
} }
...@@ -328,7 +328,7 @@ pub struct DeviceStorage { ...@@ -328,7 +328,7 @@ pub struct DeviceStorage {
size: usize, size: usize,
ctx: Arc<CudaContext>, ctx: Arc<CudaContext>,
handles: RegistrationHandles, handles: RegistrationHandles,
_storage_type: DeviceStorageType, storage_type: DeviceStorageType,
} }
impl Local for DeviceStorage {} impl Local for DeviceStorage {}
...@@ -345,7 +345,7 @@ impl DeviceStorage { ...@@ -345,7 +345,7 @@ impl DeviceStorage {
size, size,
ctx: ctx.clone(), ctx: ctx.clone(),
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Owned, storage_type: DeviceStorageType::Owned,
}) })
} }
...@@ -373,7 +373,7 @@ impl DeviceStorage { ...@@ -373,7 +373,7 @@ impl DeviceStorage {
size, size,
ctx: ctx.clone(), ctx: ctx.clone(),
handles: RegistrationHandles::new(), handles: RegistrationHandles::new(),
_storage_type: DeviceStorageType::Torch { _tensor: tensor }, storage_type: DeviceStorageType::Torch { _tensor: tensor },
}) })
} }
...@@ -381,6 +381,10 @@ impl DeviceStorage { ...@@ -381,6 +381,10 @@ impl DeviceStorage {
pub fn context(&self) -> &Arc<CudaContext> { pub fn context(&self) -> &Arc<CudaContext> {
&self.ctx &self.ctx
} }
pub fn device_storage_type(&self) -> &DeviceStorageType {
&self.storage_type
}
} }
impl Storage for DeviceStorage { impl Storage for DeviceStorage {
...@@ -414,7 +418,7 @@ impl CudaContextProivder for DeviceStorage { ...@@ -414,7 +418,7 @@ impl CudaContextProivder for DeviceStorage {
impl Drop for DeviceStorage { impl Drop for DeviceStorage {
fn drop(&mut self) { fn drop(&mut self) {
self.handles.release(); self.handles.release();
match &self._storage_type { match &self.storage_type {
DeviceStorageType::Owned => { DeviceStorageType::Owned => {
unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap() unsafe { cudarc::driver::result::free_sync(self.ptr as _) }.unwrap()
} }
......
...@@ -3,8 +3,13 @@ ...@@ -3,8 +3,13 @@
//! CUDA device memory storage. //! CUDA device memory storage.
use crate::block_manager::DeviceStorage as V1DeviceStorage;
use crate::block_manager::Storage as V1Storage;
use crate::block_manager::storage::cuda::DeviceStorageType as V1DeviceStorageType;
use super::{MemoryRegion, Result, StorageError, StorageKind}; use super::{MemoryRegion, Result, StorageError, StorageKind};
use cudarc::driver::CudaContext; use cudarc::driver::CudaContext;
use nixl_sys::NixlDescriptor;
use std::any::Any; use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
...@@ -30,6 +35,8 @@ pub struct DeviceStorage { ...@@ -30,6 +35,8 @@ pub struct DeviceStorage {
ptr: u64, ptr: u64,
device_id: u32, device_id: u32,
len: usize, len: usize,
// TODO: This is a bit ugly. We need to translate our v1 device layout to v2.
device_storage_type: V1DeviceStorageType,
} }
unsafe impl Send for DeviceStorage {} unsafe impl Send for DeviceStorage {}
...@@ -57,6 +64,7 @@ impl DeviceStorage { ...@@ -57,6 +64,7 @@ impl DeviceStorage {
ptr, ptr,
device_id, device_id,
len, len,
device_storage_type: V1DeviceStorageType::Owned,
}) })
} }
...@@ -69,18 +77,51 @@ impl DeviceStorage { ...@@ -69,18 +77,51 @@ impl DeviceStorage {
pub fn device_id(&self) -> u32 { pub fn device_id(&self) -> u32 {
self.device_id self.device_id
} }
pub fn from_v1(v1_storage: &V1DeviceStorage) -> Result<Self> {
let device_id = v1_storage.device_id() as u32;
let ctx = cuda_context(device_id)?;
let ptr;
unsafe {
ptr = v1_storage.as_ptr() as u64;
}
let len = v1_storage.size();
if !matches!(
v1_storage.device_storage_type(),
V1DeviceStorageType::Torch { .. }
) {
return Err(StorageError::Unsupported(
"Unable to convert owned device tensors.".into(),
));
}
Ok(Self {
ctx,
ptr,
device_id,
len,
device_storage_type: v1_storage.device_storage_type().clone(),
})
}
} }
impl Drop for DeviceStorage { impl Drop for DeviceStorage {
fn drop(&mut self) { fn drop(&mut self) {
if let Err(e) = self.ctx.bind_to_thread() { match self.device_storage_type {
tracing::debug!("failed to bind CUDA context for free: {e}"); V1DeviceStorageType::Owned => {
} if let Err(e) = self.ctx.bind_to_thread() {
unsafe { tracing::debug!("failed to bind CUDA context for free: {e}");
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) { }
tracing::debug!("failed to free device memory: {e}"); unsafe {
if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
tracing::debug!("failed to free device memory: {e}");
}
}
} }
}; V1DeviceStorageType::Torch { .. } => {} // Do nothing.
}
} }
} }
......
...@@ -199,6 +199,7 @@ markers = [ ...@@ -199,6 +199,7 @@ markers = [
"slow: marks tests as known to be slow", "slow: marks tests as known to be slow",
"h100: marks tests to run on H100", "h100: marks tests to run on H100",
"kvbm: marks tests for KV behavior and model determinism", "kvbm: marks tests for KV behavior and model determinism",
"kvbm_v2: marks tests using KVBM V2",
"model: model id used by a test or parameter", "model: model id used by a test or parameter",
"custom_build: marks tests that require custom builds or special setup (e.g., MoE models)", "custom_build: marks tests that require custom builds or special setup (e.g., MoE models)",
"k8s: marks tests as requiring Kubernetes", "k8s: marks tests as requiring Kubernetes",
......
...@@ -182,6 +182,20 @@ class LLMServerManager: ...@@ -182,6 +182,20 @@ class LLMServerManager:
) )
self.server_stdout_file.flush() self.server_stdout_file.flush()
# Try to download the model.
model = os.environ.get(
"KVBM_MODEL_ID", "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
print("Attempting model download...")
try:
subprocess.run(
f"pip install hf_transfer && HF_HUB_ENABLE_HF_TRANSFER=1 hf download {model}",
check=True,
shell=True,
)
except subprocess.CalledProcessError:
print("Model download failed. Is this a locally stored model?")
# Launch # Launch
self.process = subprocess.Popen( self.process = subprocess.Popen(
self.server_cmd, self.server_cmd,
...@@ -334,7 +348,7 @@ def llm_server(request, runtime_services): ...@@ -334,7 +348,7 @@ def llm_server(request, runtime_services):
server_type=server_type, server_type=server_type,
) )
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300")) start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600"))
if not server_manager.start_server(timeout=start_timeout): if not server_manager.start_server(timeout=start_timeout):
pytest.fail( pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
...@@ -374,6 +388,24 @@ class TestDeterminismAgg(BaseTestDeterminism): ...@@ -374,6 +388,24 @@ class TestDeterminismAgg(BaseTestDeterminism):
tester, llm_server, runtime_services tester, llm_server, runtime_services
) )
@pytest.mark.parametrize(
"llm_server",
[
{"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000"))},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_agg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester, llm_server, runtime_services
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"llm_server", "llm_server",
[ [
......
...@@ -429,7 +429,7 @@ def llm_server(request, runtime_services): ...@@ -429,7 +429,7 @@ def llm_server(request, runtime_services):
server_type=server_type, server_type=server_type,
) )
start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "300")) start_timeout = int(os.environ.get("KVBM_SERVER_START_TIMEOUT", "600"))
if not server_manager.start_server(timeout=start_timeout): if not server_manager.start_server(timeout=start_timeout):
pytest.fail( pytest.fail(
f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})" f"Failed to start {server_type} server (cpu_blocks={cpu_blocks}, gpu_blocks={gpu_blocks}, port={server_manager.port})"
...@@ -475,6 +475,30 @@ class TestDeterminismDisagg(BaseTestDeterminism): ...@@ -475,6 +475,30 @@ class TestDeterminismDisagg(BaseTestDeterminism):
success_rate_threshold=SUCCESS_RATE_THRESHOLD, success_rate_threshold=SUCCESS_RATE_THRESHOLD,
) )
@pytest.mark.parametrize(
"llm_server",
[
{
"cpu_blocks": int(os.environ.get("KVBM_CPU_BLOCKS", "10000")),
"gpu_blocks": int(os.environ.get("KVBM_GPU_BLOCKS", "1000")),
},
],
indirect=True,
)
@pytest.mark.kvbm_v2
def test_determinism_disagg_with_cache_reset_v2(
self, tester, llm_server, runtime_services, monkeypatch
):
"""Test determinism across cache reset: run test with warmup, reset cache, run again without warmup."""
monkeypatch.setenv("DYN_KVBM_USE_V2_TRANSFER_EXPERIMENTAL", "1")
# Call the base class implementation
super().base_test_determinism_with_cache_reset(
tester,
llm_server,
runtime_services,
success_rate_threshold=SUCCESS_RATE_THRESHOLD,
)
if __name__ == "__main__": if __name__ == "__main__":
# Allow running as script # Allow running as script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment