"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "c3bfbd2040032d9b5f6e52fa8e1bf7e327533e98"
Unverified Commit 5d5080ba authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Various KVBM improvements (#1134)

parent d3b0cae1
...@@ -192,11 +192,21 @@ mod tests { ...@@ -192,11 +192,21 @@ mod tests {
fn create_reference_block_manager() -> ReferenceBlockManager { fn create_reference_block_manager() -> ReferenceBlockManager {
let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst); let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
// Check if we're already in a Tokio runtime context
let async_runtime = if tokio::runtime::Handle::try_current().is_ok() {
None // If we're already in a runtime, don't create a new one
} else {
// Only create a new runtime if not already in one
Some(Arc::new(tokio::runtime::Runtime::new().unwrap()))
};
let config = KvBlockManagerConfig::builder() let config = KvBlockManagerConfig::builder()
.runtime( .runtime(
KvManagerRuntimeConfig::builder() KvManagerRuntimeConfig::builder()
.worker_id(worker_id) .worker_id(worker_id)
.enable_nixl() .enable_nixl()
.async_runtime(async_runtime)
.build() .build()
.unwrap(), .unwrap(),
) )
......
...@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync + ...@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
/// Resets the metadata to the default value /// Resets the metadata to the default value
/// If called, the [BlockMetadata::is_reset()] should return true /// If called, the [BlockMetadata::is_reset()] should return true
fn reset_metadata(&mut self); fn reset_metadata(&mut self);
/// The offload priority of the block. Higher priority blocks are offloaded first.
/// If the block should not be offloaded, return None.
fn offload_priority(&self) -> Option<u64>;
} }
/// Marker trait for types that are mutable blocks /// Marker trait for types that are mutable blocks
...@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata { ...@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
fn reset_metadata(&mut self) { fn reset_metadata(&mut self) {
self.priority = 0; self.priority = 0;
} }
fn offload_priority(&self) -> Option<u64> {
Some(self.priority as u64)
}
} }
/// Collection that holds shared storage and layout /// Collection that holds shared storage and layout
#[derive(Debug)] #[derive(Debug)]
......
...@@ -133,7 +133,7 @@ where ...@@ -133,7 +133,7 @@ where
pub trait WriteTo<Target> { pub trait WriteTo<Target> {
fn write_to( fn write_to(
&self, &self,
dst: &mut Target, dst: &mut Vec<Target>,
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<(), TransferError>; ) -> Result<(), TransferError>;
...@@ -143,31 +143,44 @@ pub trait WriteTo<Target> { ...@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
/// Returns a future that will complete when the transfer is complete. /// Returns a future that will complete when the transfer is complete.
fn nixl_write_to( fn nixl_write_to(
&self, &self,
dst: &mut Target, dst: &mut Vec<Target>,
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>; ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>>
where where
RB: WriteToStrategy<WB> + Local, RB: WriteToStrategy<WB> + Local,
{ {
fn write_to( fn write_to(
&self, &self,
dst: &mut WB, dst: &mut Vec<WB>,
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<(), TransferError> { ) -> Result<(), TransferError> {
match Self::write_to_strategy() { match RB::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst), TransferStrategy::Memcpy => {
for (src, dst) in self.iter().zip(dst.iter_mut()) {
memcpy::copy_block(src.as_ref(), dst)?;
}
Ok(())
}
TransferStrategy::CudaAsyncH2D TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H | TransferStrategy::CudaAsyncD2H
| TransferStrategy::CudaAsyncD2D => { | TransferStrategy::CudaAsyncD2D => {
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy()) for (src, dst) in self.iter().zip(dst.iter_mut()) {
cuda::copy_block(
src.as_ref(),
dst,
ctx.stream().as_ref(),
RB::write_to_strategy(),
)?;
}
Ok(())
} }
TransferStrategy::NixlWrite => { TransferStrategy::NixlWrite => {
std::mem::drop(nixl::write_block_to(self, dst, ctx, notify)?); std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?);
Ok(()) Ok(())
} }
_ => Err(TransferError::IncompatibleTypes(format!( _ => Err(TransferError::IncompatibleTypes(format!(
...@@ -175,17 +188,16 @@ where ...@@ -175,17 +188,16 @@ where
RB::write_to_strategy() RB::write_to_strategy()
))), ))),
} }
// dispatch_copy_to(self, dst, self.transfer_context())
} }
fn nixl_write_to( fn nixl_write_to(
&self, &self,
dst: &mut WB, dst: &mut Vec<WB>,
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> { ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() { if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
Ok(nixl::write_block_to(self, dst, ctx, notify)?) Ok(nixl::write_blocks_to(self, dst, ctx, notify)?)
} else { } else {
Err(TransferError::IncompatibleTypes(format!( Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}", "Expected NIXL transfer strategy, got: {:?}",
......
...@@ -18,16 +18,14 @@ use super::*; ...@@ -18,16 +18,14 @@ use super::*;
use anyhow::Result; use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp}; use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::future::{poll_fn, Future}; use std::future::{poll_fn, Future};
use std::ops::Range;
use std::task::Poll; use std::task::Poll;
/// Copy a block from a source to a destination using CUDA memcpy fn append_xfer_request<Source, Destination>(
pub fn write_block_to<'a, Source, Destination>( src: &Arc<Source>,
src: &'a Source, dst: &mut Destination,
dst: &'a mut Destination, src_dl: &mut XferDescList,
ctx: Arc<TransferContext>, dst_dl: &mut XferDescList,
notify: Option<String>, ) -> Result<()>
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
...@@ -36,17 +34,6 @@ where ...@@ -36,17 +34,6 @@ where
let dst_data = dst.block_data_mut(private::PrivateToken); let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() { if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
// Keep the arc to use in the returned future.
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
let src_desc = src_data.block_view()?.as_nixl_descriptor(); let src_desc = src_data.block_view()?.as_nixl_descriptor();
let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut(); let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
...@@ -64,45 +51,42 @@ where ...@@ -64,45 +51,42 @@ where
)?; )?;
} }
let xfer_req = nixl_agent Ok(())
.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
.unwrap();
let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
// Return a future that completes when the transfer is complete.
// TODO: How efficient is this? Can we do better?
Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
}
})))
} else { } else {
assert_eq!(src_data.num_layers(), dst_data.num_layers()); assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify) for layer_idx in 0..src_data.num_layers() {
for outer_idx in 0..src_data.num_outer_dims() {
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
debug_assert_eq!(src_view.size(), dst_view.size());
let src_desc = src_view.as_nixl_descriptor();
let dst_desc = dst_view.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
}
}
Ok(())
} }
} }
/// Copy a range of layers from a source to a destination using CUDA memcpy /// Copy a block from a source to a destination using CUDA memcpy
pub fn write_layers_to<'a, Source, Destination>( pub fn write_blocks_to<Source, Destination>(
layer_range: Range<usize>, src: &[Arc<Source>],
src: &'a Source, dst: &mut [Destination],
dst: &'a mut Destination,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
notify: Option<String>, notify: Option<String>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
...@@ -110,8 +94,10 @@ where ...@@ -110,8 +94,10 @@ where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
{ {
let src_data = src.block_data(private::PrivateToken); if src.is_empty() || dst.is_empty() {
let dst_data = dst.block_data_mut(private::PrivateToken); return Ok(Box::new(std::future::ready(())));
}
assert_eq!(src.len(), dst.len());
let nixl_agent_arc = ctx.as_ref().nixl_agent(); let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc let nixl_agent = nixl_agent_arc
...@@ -119,44 +105,31 @@ where ...@@ -119,44 +105,31 @@ where
.as_ref() .as_ref()
.expect("NIXL agent not found"); .expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string(); let src_mem_type = src
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?; .first()
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?; .unwrap()
.block_data(private::PrivateToken)
// #[cfg(debug_assertions)] .storage_type()
// { .nixl_mem_type();
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy< let dst_mem_type = dst
// Destination::StorageType, .first()
// >>::write_to_strategy(); .unwrap()
// assert_eq!(strategy, expected_strategy); .block_data(private::PrivateToken)
// } .storage_type()
.nixl_mem_type();
for layer_idx in layer_range {
for outer_idx in 0..src_data.num_outer_dims() { let mut src_dl = XferDescList::new(src_mem_type)?;
let src_view = src_data.layer_view(layer_idx, outer_idx)?; let mut dst_dl = XferDescList::new(dst_mem_type)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
for (src, dst) in src.iter().zip(dst.iter_mut()) {
debug_assert_eq!(src_view.size(), dst_view.size()); append_xfer_request(src, dst, &mut src_dl, &mut dst_dl)?;
let src_desc = src_view.as_nixl_descriptor();
let dst_desc = dst_view.as_nixl_descriptor_mut();
unsafe {
src_dl.add_desc(
src_desc.as_ptr() as usize,
src_desc.size(),
src_desc.device_id(),
)?;
dst_dl.add_desc(
dst_desc.as_ptr() as usize,
dst_desc.size(),
dst_desc.device_id(),
)?;
}
}
} }
debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?;
let mut xfer_args = OptArgs::new()?; let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify { if let Some(notify) = notify {
...@@ -164,14 +137,6 @@ where ...@@ -164,14 +137,6 @@ where
xfer_args.set_notification_message(notify.as_bytes())?; xfer_args.set_notification_message(notify.as_bytes())?;
} }
let xfer_req = nixl_agent.create_xfer_req(
XferOp::Write,
&src_dl,
&dst_dl,
&remote_worker_id,
Some(&xfer_args),
)?;
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?; let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
Ok(Box::new(poll_fn(move |_cx| { Ok(Box::new(poll_fn(move |_cx| {
...@@ -179,6 +144,8 @@ where ...@@ -179,6 +144,8 @@ where
.as_ref() .as_ref()
.as_ref() .as_ref()
.expect("NIXL agent not found"); .expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() { if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(()) Poll::Ready(())
} else { } else {
......
...@@ -65,18 +65,20 @@ use std::collections::BTreeSet; ...@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
mod pending; mod pending;
pub mod request; pub mod request;
use pending::{CudaTransferManager, DiskTransferManager, PendingTransfer, TransferManager}; use pending::{
CudaTransferManager, DiskTransferManager, PendingTransfer, TransferBatcher, TransferManager,
};
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest}; use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
// TODO: This should be dynamic const MAX_CONCURRENT_TRANSFERS: usize = 4;
const MAX_OFFLOAD_STREAM_DEPTH: usize = 4; const MAX_TRANSFER_BATCH_SIZE: usize = 16;
/// The offload manager handles all block transfers between different cache levels. /// The offload manager handles all block transfers between different cache levels.
pub struct OffloadManager<Metadata: BlockMetadata> { pub struct OffloadManager<Metadata: BlockMetadata> {
// Handles to the device, host, and disk pools. // Handles to the device, host, and disk pools.
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>, disk: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>, host: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>, device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
/// Queue of offloading requests. /// Queue of offloading requests.
device_offload_tx: mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Metadata>>, device_offload_tx: mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Metadata>>,
...@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> { ...@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
impl<Metadata: BlockMetadata> OffloadManager<Metadata> { impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
pub fn new( pub fn new(
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>, disk: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>, host: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>, device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
nixl_agent: Arc<Option<NixlAgent>>, nixl_agent: Arc<Option<NixlAgent>>,
async_rt_handle: Handle, async_rt_handle: Handle,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
...@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone(); let device_clone = this.device.clone();
let host_clone = this.host.clone(); let host_clone = this.host.clone();
async_rt_handle.spawn(async move { async_rt_handle.spawn(async move {
OffloadManager::offload_worker( let res = OffloadManager::offload_worker(
device_clone, device_clone,
host_clone, host_clone,
device_offload_rx, device_offload_rx,
Arc::new(CudaTransferManager::new( Arc::new(TransferBatcher::new(
device_offload_transfer_ctx, CudaTransferManager::new(device_offload_transfer_ctx, MAX_CONCURRENT_TRANSFERS),
MAX_OFFLOAD_STREAM_DEPTH, MAX_TRANSFER_BATCH_SIZE,
)), )),
) )
.await .await;
.unwrap() tracing::warn!("Offload worker terminated: {:?}", res);
}); });
let transfer_ctx = Arc::new(TransferContext::new( let transfer_ctx = Arc::new(TransferContext::new(
...@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let disk_clone = this.disk.clone(); let disk_clone = this.disk.clone();
let transfer_ctx_clone = transfer_ctx.clone(); let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move { async_rt_handle.spawn(async move {
OffloadManager::offload_worker( let res = OffloadManager::offload_worker(
host_clone, host_clone,
disk_clone, disk_clone,
host_offload_rx, host_offload_rx,
Arc::new(DiskTransferManager::new( Arc::new(TransferBatcher::new(
transfer_ctx_clone, DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_OFFLOAD_STREAM_DEPTH, MAX_TRANSFER_BATCH_SIZE,
)), )),
) )
.await .await;
.unwrap() tracing::warn!("Offload worker terminated: {:?}", res);
}); });
// Host -> Device onboarding // Host -> Device onboarding
...@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone(); let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone(); let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move { async_rt_handle.spawn(async move {
OffloadManager::onboard_worker( let res = OffloadManager::onboard_worker(
host_clone, host_clone,
device_clone, device_clone,
host_onboard_rx, host_onboard_rx,
Arc::new(CudaTransferManager::new(transfer_ctx_clone, 16384)), Arc::new(TransferBatcher::new(
CudaTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
) )
.await .await;
.unwrap() tracing::warn!("Onboard worker terminated: {:?}", res);
}); });
// Disk -> Device onboarding // Disk -> Device onboarding
...@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone(); let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone(); let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move { async_rt_handle.spawn(async move {
OffloadManager::onboard_worker( let res = OffloadManager::onboard_worker(
disk_clone, disk_clone,
device_clone, device_clone,
disk_onboard_rx, disk_onboard_rx,
Arc::new(DiskTransferManager::new(transfer_ctx_clone, 16384)), Arc::new(TransferBatcher::new(
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
) )
.await .await;
.unwrap() tracing::warn!("Onboard worker terminated: {:?}", res);
}); });
Ok(this_clone) Ok(this_clone)
} }
async fn offload_worker<Source: Storage, Target: Storage>( async fn offload_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>, source_pool: Option<Arc<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>, target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>, mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>, transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> { ) -> Result<()> {
if source_pool_arc.is_none() || target_pool_arc.is_none() { if source_pool.is_none() || target_pool.is_none() {
return Ok(()); return Ok(());
} }
let source_pool = source_pool_arc.as_ref().as_ref().unwrap(); let source_pool = source_pool.as_ref().unwrap();
let target_pool = target_pool_arc.as_ref().as_ref().unwrap(); let target_pool = target_pool.as_ref().unwrap();
let mut queue = BTreeSet::new(); let mut queue = BTreeSet::new();
...@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
} }
// Allocate a block from the host pool. // Allocate a block from the host pool.
// TODO: The most likely error here is that the host pool is full. // TODO: The most likely error here is that the target pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime. // It's probably not a good idea to keep consuming queue elements in the meantime.
let target_blocks = match target_pool.allocate_blocks(1).await { let target_blocks = match target_pool.allocate_blocks(1).await {
Ok(blocks) => blocks, Ok(blocks) => blocks,
...@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
if let Some(target_block) = target_blocks.into_iter().next() { if let Some(target_block) = target_blocks.into_iter().next() {
transfer_manager transfer_manager
.begin_transfer(PendingTransfer::new( .enqueue_transfer(PendingTransfer::new(
vec![block], vec![block],
vec![target_block], vec![target_block],
None, None,
target_pool_arc.clone(), target_pool.clone(),
)) ))
.await?; .await?;
} }
...@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
} }
async fn onboard_worker<Source: Storage, Target: Storage>( async fn onboard_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>, source_pool: Option<Arc<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>, target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>, mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>, transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> { ) -> Result<()> {
if source_pool_arc.is_none() || target_pool_arc.is_none() { if source_pool.is_none() || target_pool.is_none() {
return Ok(()); return Ok(());
} }
let target_pool = target_pool_arc.as_ref().as_ref().unwrap(); let target_pool = target_pool.as_ref().unwrap();
// Loop on incoming requests // Loop on incoming requests
while let Some(request) = onboard_rx.recv().await { while let Some(request) = onboard_rx.recv().await {
...@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
.collect(); .collect();
transfer_manager transfer_manager
.begin_transfer(PendingTransfer::new( .enqueue_transfer(PendingTransfer::new(
sources, sources,
target_blocks, target_blocks,
Some(request.response_tx), Some(request.response_tx),
target_pool_arc.clone(), target_pool.clone(),
)) ))
.await?; .await?;
} }
...@@ -478,9 +486,9 @@ mod tests { ...@@ -478,9 +486,9 @@ mod tests {
const BLOCK_SIZE: usize = 4; const BLOCK_SIZE: usize = 4;
type DevicePool = Arc<Option<BlockPool<DeviceStorage, BasicMetadata>>>; type DevicePool = Option<Arc<BlockPool<DeviceStorage, BasicMetadata>>>;
type HostPool = Arc<Option<BlockPool<PinnedStorage, BasicMetadata>>>; type HostPool = Option<Arc<BlockPool<PinnedStorage, BasicMetadata>>>;
type DiskPool = Arc<Option<BlockPool<DiskStorage, BasicMetadata>>>; type DiskPool = Option<Arc<BlockPool<DiskStorage, BasicMetadata>>>;
lazy_static::lazy_static! { lazy_static::lazy_static! {
static ref NIXL_AGENT: Arc<Option<NixlAgent>> = { static ref NIXL_AGENT: Arc<Option<NixlAgent>> = {
...@@ -521,16 +529,18 @@ mod tests { ...@@ -521,16 +529,18 @@ mod tests {
device.nixl_register(agent, None)?; device.nixl_register(agent, None)?;
let device_blocks = Blocks::<_, BasicMetadata>::new(device, 42, 0)?.into_blocks()?; let device_blocks = Blocks::<_, BasicMetadata>::new(device, 42, 0)?.into_blocks()?;
let device_pool = Arc::new(Some(BlockPool::builder().blocks(device_blocks).build()?)); let device_pool = Some(Arc::new(
BlockPool::builder().blocks(device_blocks).build()?,
));
let host_pool = if let Some(host_blocks) = host_blocks { let host_pool = if let Some(host_blocks) = host_blocks {
config.num_blocks = host_blocks; config.num_blocks = host_blocks;
let mut host = FullyContiguous::allocate(config.clone(), &PinnedAllocator::default())?; let mut host = FullyContiguous::allocate(config.clone(), &PinnedAllocator::default())?;
host.nixl_register(agent, None)?; host.nixl_register(agent, None)?;
let host_blocks = Blocks::<_, BasicMetadata>::new(host, 42, 0)?.into_blocks()?; let host_blocks = Blocks::<_, BasicMetadata>::new(host, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(host_blocks).build()?)) Some(Arc::new(BlockPool::builder().blocks(host_blocks).build()?))
} else { } else {
Arc::new(None) None
}; };
let disk_pool = if let Some(disk_blocks) = disk_blocks { let disk_pool = if let Some(disk_blocks) = disk_blocks {
...@@ -538,9 +548,9 @@ mod tests { ...@@ -538,9 +548,9 @@ mod tests {
let mut disk = FullyContiguous::allocate(config, &DiskAllocator)?; let mut disk = FullyContiguous::allocate(config, &DiskAllocator)?;
disk.nixl_register(agent, None)?; disk.nixl_register(agent, None)?;
let disk_blocks = Blocks::<_, BasicMetadata>::new(disk, 42, 0)?.into_blocks()?; let disk_blocks = Blocks::<_, BasicMetadata>::new(disk, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(disk_blocks).build()?)) Some(Arc::new(BlockPool::builder().blocks(disk_blocks).build()?))
} else { } else {
Arc::new(None) None
}; };
let async_rt_handle = Handle::current(); let async_rt_handle = Handle::current();
...@@ -558,7 +568,7 @@ mod tests { ...@@ -558,7 +568,7 @@ mod tests {
/// Create a block in the 'RESET' state. /// Create a block in the 'RESET' state.
async fn get_block<S: Storage, Metadata: BlockMetadata>( async fn get_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>, pool: &Arc<BlockPool<S, Metadata>>,
) -> Result<MutableBlock<S, Metadata>> { ) -> Result<MutableBlock<S, Metadata>> {
pool.allocate_blocks(1) pool.allocate_blocks(1)
.await? .await?
...@@ -569,7 +579,7 @@ mod tests { ...@@ -569,7 +579,7 @@ mod tests {
/// Create a block in the 'PARTIAL' state. /// Create a block in the 'PARTIAL' state.
async fn partial_block<S: Storage, Metadata: BlockMetadata>( async fn partial_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>, pool: &Arc<BlockPool<S, Metadata>>,
token: u32, token: u32,
) -> Result<MutableBlock<S, Metadata>> { ) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?; let mut block = get_block(pool).await?;
...@@ -580,7 +590,7 @@ mod tests { ...@@ -580,7 +590,7 @@ mod tests {
/// Create a block in the 'COMPLETED' state. /// Create a block in the 'COMPLETED' state.
async fn completed_block<S: Storage, Metadata: BlockMetadata>( async fn completed_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>, pool: &Arc<BlockPool<S, Metadata>>,
tokens: [u32; BLOCK_SIZE], tokens: [u32; BLOCK_SIZE],
) -> Result<MutableBlock<S, Metadata>> { ) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?; let mut block = get_block(pool).await?;
...@@ -666,7 +676,7 @@ mod tests { ...@@ -666,7 +676,7 @@ mod tests {
async fn test_offload_invalid_blocks() -> Result<()> { async fn test_offload_invalid_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
// Check blocks in the 'RESET' state. // Check blocks in the 'RESET' state.
let immutable_block = ImmutableBlock::new(Arc::new(get_block(device_pool).await?)); let immutable_block = ImmutableBlock::new(Arc::new(get_block(device_pool).await?));
...@@ -699,8 +709,8 @@ mod tests { ...@@ -699,8 +709,8 @@ mod tests {
async fn test_offload_registered_blocks() -> Result<()> { async fn test_offload_registered_blocks() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
// Create a block and register it with the offload manager // Create a block and register it with the offload manager
let block = completed_block(device_pool, [0, 1, 2, 3]).await?; let block = completed_block(device_pool, [0, 1, 2, 3]).await?;
...@@ -742,8 +752,8 @@ mod tests { ...@@ -742,8 +752,8 @@ mod tests {
async fn test_no_host_blocks_available() -> Result<()> { async fn test_no_host_blocks_available() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let host_blocks = host_pool.allocate_blocks(4).await?; let host_blocks = host_pool.allocate_blocks(4).await?;
assert_eq!(host_blocks.len(), 4); assert_eq!(host_blocks.len(), 4);
...@@ -790,8 +800,8 @@ mod tests { ...@@ -790,8 +800,8 @@ mod tests {
async fn test_onboard() -> Result<()> { async fn test_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
// Allocate and fill a block on the host. // Allocate and fill a block on the host.
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?; let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
...@@ -844,8 +854,8 @@ mod tests { ...@@ -844,8 +854,8 @@ mod tests {
async fn test_offload_onboard() -> Result<()> { async fn test_offload_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?; let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool let immutable_device_block = device_pool
...@@ -913,8 +923,8 @@ mod tests { ...@@ -913,8 +923,8 @@ mod tests {
async fn test_onboard_err_handling() -> Result<()> { async fn test_onboard_err_handling() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?; let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool let immutable_host_block = host_pool
...@@ -942,7 +952,7 @@ mod tests { ...@@ -942,7 +952,7 @@ mod tests {
async fn test_offload_onboard_no_host_blocks() -> Result<()> { async fn test_offload_onboard_no_host_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, None, None)?; let (offload_manager, device_pool, _, _) = build_pools(4, None, None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?; let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool let immutable_device_block = device_pool
...@@ -961,8 +971,8 @@ mod tests { ...@@ -961,8 +971,8 @@ mod tests {
async fn test_offload_disk() -> Result<()> { async fn test_offload_disk() -> Result<()> {
let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4))?; let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4))?;
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?; let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool let immutable_host_block = host_pool
...@@ -996,8 +1006,8 @@ mod tests { ...@@ -996,8 +1006,8 @@ mod tests {
async fn test_onboard_disk() -> Result<()> { async fn test_onboard_disk() -> Result<()> {
let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4))?; let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
let disk_block = completed_block(disk_pool, [0, 1, 2, 3]).await?; let disk_block = completed_block(disk_pool, [0, 1, 2, 3]).await?;
let immutable_disk_block = disk_pool let immutable_disk_block = disk_pool
...@@ -1032,9 +1042,9 @@ mod tests { ...@@ -1032,9 +1042,9 @@ mod tests {
let (offload_manager, device_pool, host_pool, disk_pool) = let (offload_manager, device_pool, host_pool, disk_pool) =
build_pools(8, Some(8), Some(8))?; build_pools(8, Some(8), Some(8))?;
let disk_pool = disk_pool.as_ref().as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let device_pool = device_pool.as_ref().as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let mut host_blocks = Vec::new(); let mut host_blocks = Vec::new();
...@@ -1076,4 +1086,39 @@ mod tests { ...@@ -1076,4 +1086,39 @@ mod tests {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_transfer_batcher() -> Result<()> {
let (offload_manager, device_pool, _, disk_pool) = build_pools(
2 * MAX_TRANSFER_BATCH_SIZE + 1,
None,
Some(2 * MAX_TRANSFER_BATCH_SIZE + 1),
)?;
let device_pool = device_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap();
let mut disk_blocks = Vec::new();
for i in 0..2 * MAX_TRANSFER_BATCH_SIZE + 1 {
disk_blocks.push(completed_block(disk_pool, [i as u32; 4]).await?);
}
let immutable_disk_blocks = disk_pool.register_blocks(disk_blocks).await?;
let device_blocks = offload_manager
.onboard(immutable_disk_blocks.clone())
.await?;
assert_eq!(device_blocks.len(), 2 * MAX_TRANSFER_BATCH_SIZE + 1);
for device_block in &device_blocks {
let blocks = device_pool
.match_sequence_hashes(vec![device_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], device_block)?;
}
Ok(())
}
} }
...@@ -33,11 +33,12 @@ ...@@ -33,11 +33,12 @@
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers. //! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//! //!
//! ## Workflow //! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::begin_transfer`] //! 1. A transfer request is made by calling [`TransferManager::enqueue_transfer`]
//! 2. [`TransferManager::begin_transfer`] performs the transfer, and enqueues relevant data into a bounded channel. //! 2. [`TransferManager::enqueue_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers. //! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller. //! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::thread::spawn; use std::thread::spawn;
...@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool; ...@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent}; use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{future::join_all, stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult; use super::BlockResult;
...@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta ...@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete. /// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
/// The target pool that will receive the registered block. /// The target pool that will receive the registered block.
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>, target_pool: Arc<BlockPool<Target, Metadata>>,
} }
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...@@ -78,30 +79,34 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -78,30 +79,34 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
sources: Vec<Arc<MutableBlock<Source, Metadata>>>, sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>, targets: Vec<MutableBlock<Target, Metadata>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>, target_pool: Arc<BlockPool<Target, Metadata>>,
) -> Self { ) -> Self {
assert_eq!(sources.len(), targets.len());
Self { Self {
sources, sources,
targets, targets,
completion_indicator, completion_indicator,
target_registration_pool, target_pool,
} }
} }
fn handle_complete(self) -> Result<()> { fn handle_complete(self) -> Result<()> {
let Self { let Self {
targets, sources,
target_registration_pool, mut targets,
target_pool,
completion_indicator, completion_indicator,
.. ..
} = self; } = self;
if let Some(target_registration_pool) = target_registration_pool.as_ref() { for (source, target) in sources.iter().zip(targets.iter_mut()) {
let blocks = target_registration_pool.register_blocks_blocking(targets)?; transfer_metadata(source, target)?;
}
let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator { if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?; completion_indicator.send(Ok(blocks))?;
}
} }
Ok(()) Ok(())
...@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad ...@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
Send + Sync Send + Sync
{ {
/// Begin a transfer. Blocks if the pending queue is full. /// Begin a transfer. Blocks if the pending queue is full.
async fn begin_transfer( async fn enqueue_transfer(
&self, &self,
pending_transfer: PendingTransfer<Source, Target, Metadata>, pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()>; ) -> Result<()>;
...@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block ...@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata> CudaTransferManager<Source, Target, Metadata>
{ {
pub fn new(transfer_ctx: Arc<TransferContext>, max_depth: usize) -> Self { pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
let (tx, mut rx) = let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(
mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(max_depth); max_concurrent_transfers,
);
spawn(move || { spawn(move || {
while let Some((pending_transfer, event)) = rx.blocking_recv() { while let Some((pending_transfer, event)) = rx.blocking_recv() {
// Wait for the event. // Wait for the event.
event.synchronize()?; event.synchronize()?;
// Only finalize the transfer after the event is signaled. // Only finalize the transfer after the event is signaled.
pending_transfer.handle_complete()?; match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
} }
Ok::<(), anyhow::Error>(()) Ok::<(), anyhow::Error>(())
}); });
...@@ -183,18 +196,15 @@ where ...@@ -183,18 +196,15 @@ where
// Check that the target block is writable. // Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>, MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{ {
async fn begin_transfer( async fn enqueue_transfer(
&self, &self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
for (source, target) in pending_transfer pending_transfer.sources.write_to(
.sources &mut pending_transfer.targets,
.iter() None,
.zip(pending_transfer.targets.iter_mut()) self.transfer_ctx.clone(),
{ )?;
transfer_metadata(source, target)?;
source.write_to(target, None, self.transfer_ctx.clone())?;
}
// Use a cuda event to record the completion of the transfers. // Use a cuda event to record the completion of the transfers.
let event = self let event = self
...@@ -218,7 +228,7 @@ pub struct DiskTransferManager { ...@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
} }
impl DiskTransferManager { impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_size: usize) -> Self { pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1); let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move { tokio::spawn(async move {
...@@ -230,7 +240,7 @@ impl DiskTransferManager { ...@@ -230,7 +240,7 @@ impl DiskTransferManager {
tokio::select! { tokio::select! {
Some(future) = futures_rx.recv() => { Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity. // If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_size { while pending_transfers.len() >= max_concurrent_transfers {
pending_transfers.next().await; pending_transfers.next().await;
} }
// Once we have capacity, push the new future onto the queue. // Once we have capacity, push the new future onto the queue.
...@@ -267,26 +277,26 @@ where ...@@ -267,26 +277,26 @@ where
// Check that the target block is writable. // Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>, MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{ {
async fn begin_transfer( async fn enqueue_transfer(
&self, &self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
let futures = pending_transfer let future = pending_transfer.sources.nixl_write_to(
.sources &mut pending_transfer.targets,
.iter() None,
.zip(pending_transfer.targets.iter_mut()) self.transfer_ctx.clone(),
.map(|(source, target)| { )?;
transfer_metadata(source, target).unwrap();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to(target, None, self.transfer_ctx.clone())
.unwrap()
})
.collect::<Vec<_>>();
let completion_future = async move { let completion_future = async move {
let _ = join_all(futures).await; let _ = future.await;
pending_transfer.handle_complete().unwrap(); match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}; };
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`, // Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
...@@ -296,3 +306,112 @@ where ...@@ -296,3 +306,112 @@ where
Ok(()) Ok(())
} }
} }
/// A transfer manager that enforces a max batch size for transfers.
pub struct TransferBatcher<Source, Target, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
transfer_manager: Manager,
max_transfer_batch_size: usize,
_phantom: PhantomData<(Source, Target, Metadata)>,
}
impl<Source, Target, Metadata, Manager> TransferBatcher<Source, Target, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
pub fn new(transfer_manager: Manager, max_transfer_batch_size: usize) -> Self {
Self {
transfer_manager,
max_transfer_batch_size,
_phantom: PhantomData,
}
}
}
#[async_trait]
impl<Source, Target, Metadata, Manager> TransferManager<Source, Target, Metadata>
for TransferBatcher<Source, Target, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
async fn enqueue_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
// If it's smaller than the max batch size, just enqueue it.
if pending_transfer.sources.len() < self.max_transfer_batch_size {
return self
.transfer_manager
.enqueue_transfer(pending_transfer)
.await;
}
// Otherwise, we need to split the transfer into multiple smaller transfers.
let PendingTransfer {
mut sources,
mut targets,
completion_indicator,
target_pool,
} = pending_transfer;
let mut indicators = Vec::new();
while !sources.is_empty() {
let sources = sources
.drain(..std::cmp::min(self.max_transfer_batch_size, sources.len()))
.collect();
let targets = targets
.drain(..std::cmp::min(self.max_transfer_batch_size, targets.len()))
.collect();
// If we have a completion indicator, we need to create a new one for each sub-transfer.
let indicator = if completion_indicator.is_some() {
let (batch_tx, batch_rx) = oneshot::channel();
indicators.push(batch_rx);
Some(batch_tx)
} else {
None
};
let request = PendingTransfer::new(sources, targets, indicator, target_pool.clone());
// Enqueue our reduced transfer. This may block if the queue is full.
self.transfer_manager.enqueue_transfer(request).await?;
}
if let Some(completion_indicator) = completion_indicator {
tokio::spawn(async move {
let mut results = Vec::new();
for indicator in indicators.into_iter() {
// Await each sub-transfer, and append the results to our final results.
let result = match indicator.await.unwrap() {
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
return;
}
};
results.extend(result);
}
// Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results)).unwrap();
});
}
Ok(())
}
}
...@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock}; ...@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError; use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage; use crate::block_manager::storage::Storage;
#[derive(PartialEq, Eq, Ord, PartialOrd)] /// Higher priority offloads are done first.
/// If two offloads have the same priority, the one that was requested first is done first.
#[derive(PartialEq, Eq)]
pub struct OffloadRequestKey { pub struct OffloadRequestKey {
pub priority: u64, pub priority: u64,
pub timestamp: u64, pub timestamp: u64,
} }
impl PartialOrd for OffloadRequestKey {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OffloadRequestKey {
fn cmp(&self, other: &Self) -> Ordering {
other
.priority
.cmp(&self.priority)
.then(self.timestamp.cmp(&other.timestamp))
}
}
/// Data needed to offload a block. /// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it. /// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed. /// This way, we don't prevent the block from being reused if needed.
......
...@@ -518,6 +518,10 @@ pub(crate) mod tests { ...@@ -518,6 +518,10 @@ pub(crate) mod tests {
fn reset_metadata(&mut self) { fn reset_metadata(&mut self) {
self.priority = 0; self.priority = 0;
} }
fn offload_priority(&self) -> Option<u64> {
Some(self.priority as u64)
}
} }
type TestPriorityKey = PriorityKey<TestMetadata>; type TestPriorityKey = PriorityKey<TestMetadata>;
......
...@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self.active.register(mutable)?; let immutable = self.active.register(mutable)?;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if offload { if offload {
immutable.enqueue_offload(0).await.unwrap(); if let Some(priority) = immutable.metadata().offload_priority() {
immutable.enqueue_offload(priority).await.unwrap();
}
} }
immutable_blocks.push(immutable); immutable_blocks.push(immutable);
......
...@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> { ...@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent: Arc<Option<NixlAgent>>, nixl_agent: Arc<Option<NixlAgent>>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>, nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Arc<Option<BlockPool<DiskStorage, Metadata>>>, disk_pool: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>, host_pool: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>, device_pool: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
local_block_set: NixlBlockSet, local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>, remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
...@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout { let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() { if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks."); tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
(Arc::new(None), None) (None, None)
} else { } else {
next_block_set_idx += 1; next_block_set_idx += 1;
tracing::debug!("Constructing disk pool."); tracing::debug!("Constructing disk pool.");
...@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Arc::new(Some(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} }
} else { } else {
tracing::debug!("No disk layout provided; will not allocate disk blocks."); tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(Arc::new(None), None) (None, None)
}; };
// Create the host block pool if a host layout is provided // Create the host block pool if a host layout is provided
...@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Arc::new(Some(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} else { } else {
tracing::debug!("No host layout provided; will not allocate host blocks."); tracing::debug!("No host layout provided; will not allocate host blocks.");
(Arc::new(None), None) (None, None)
}; };
// Create the device block pool if a device layout is provided // Create the device block pool if a device layout is provided
...@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Arc::new(Some(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} else { } else {
tracing::debug!("No device layout provided; will not allocate device blocks."); tracing::debug!("No device layout provided; will not allocate device blocks.");
(Arc::new(None), None) (None, None)
}; };
// Finalize the local block set by adding NIXL metadata // Finalize the local block set by adding NIXL metadata
...@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
} }
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> { pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.disk_pool.as_ref().as_ref() self.disk_pool.as_ref().map(|pool| pool.as_ref())
} }
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> { pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref().as_ref() self.host_pool.as_ref().map(|pool| pool.as_ref())
} }
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> { pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref().as_ref() self.device_pool.as_ref().map(|pool| pool.as_ref())
} }
pub fn worker_id(&self) -> WorkerID { pub fn worker_id(&self) -> WorkerID {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
use super::*; use super::*;
use core::ffi::c_char;
use nix::fcntl::{fallocate, FallocateFlags}; use nix::fcntl::{fallocate, FallocateFlags};
use std::ffi::CString; use std::ffi::CString;
use std::fs::File; use std::fs::File;
...@@ -41,7 +42,7 @@ impl DiskStorage { ...@@ -41,7 +42,7 @@ impl DiskStorage {
let raw_fd = unsafe { let raw_fd = unsafe {
nix::libc::mkostemp( nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut i8, template_bytes.as_mut_ptr() as *mut c_char,
// For maximum performance, GPU DirectStorage requires O_DIRECT. // For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache. // This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned. // It also introduces the restriction that all accesses must be page-aligned.
...@@ -80,6 +81,7 @@ impl DiskStorage { ...@@ -80,6 +81,7 @@ impl DiskStorage {
impl Drop for DiskStorage { impl Drop for DiskStorage {
// TODO: How robust is this actually? // TODO: How robust is this actually?
fn drop(&mut self) { fn drop(&mut self) {
self.handles.release();
std::fs::remove_file(self.file_name.clone()).unwrap(); std::fs::remove_file(self.file_name.clone()).unwrap();
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment