Unverified Commit 5d5080ba authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Various KVBM improvements (#1134)

parent d3b0cae1
......@@ -192,11 +192,21 @@ mod tests {
fn create_reference_block_manager() -> ReferenceBlockManager {
let worker_id = WORKER_ID.fetch_add(1, Ordering::SeqCst);
// Check if we're already in a Tokio runtime context
let async_runtime = if tokio::runtime::Handle::try_current().is_ok() {
None // If we're already in a runtime, don't create a new one
} else {
// Only create a new runtime if not already in one
Some(Arc::new(tokio::runtime::Runtime::new().unwrap()))
};
let config = KvBlockManagerConfig::builder()
.runtime(
KvManagerRuntimeConfig::builder()
.worker_id(worker_id)
.enable_nixl()
.async_runtime(async_runtime)
.build()
.unwrap(),
)
......
......@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
/// Resets the metadata to the default value
/// If called, the [BlockMetadata::is_reset()] should return true
fn reset_metadata(&mut self);
/// The offload priority of the block. Higher priority blocks are offloaded first.
/// If the block should not be offloaded, return None.
fn offload_priority(&self) -> Option<u64>;
}
/// Marker trait for types that are mutable blocks
......@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
fn reset_metadata(&mut self) {
self.priority = 0;
}
/// Reports the offload priority of this block.
///
/// Always returns `Some`, carrying the stored `priority` field cast to `u64`;
/// this implementation never opts a block out of offloading by returning `None`.
// NOTE(review): `as` casts do not check for range or sign; confirm the type of
// `priority` makes this conversion lossless.
fn offload_priority(&self) -> Option<u64> {
Some(self.priority as u64)
}
}
/// Collection that holds shared storage and layout
#[derive(Debug)]
......
......@@ -133,7 +133,7 @@ where
pub trait WriteTo<Target> {
fn write_to(
&self,
dst: &mut Target,
dst: &mut Vec<Target>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError>;
......@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
/// Returns a future that will complete when the transfer is complete.
fn nixl_write_to(
&self,
dst: &mut Target,
dst: &mut Vec<Target>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
}
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>>
where
RB: WriteToStrategy<WB> + Local,
{
fn write_to(
&self,
dst: &mut WB,
dst: &mut Vec<WB>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError> {
match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
match RB::write_to_strategy() {
TransferStrategy::Memcpy => {
for (src, dst) in self.iter().zip(dst.iter_mut()) {
memcpy::copy_block(src.as_ref(), dst)?;
}
Ok(())
}
TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H
| TransferStrategy::CudaAsyncD2D => {
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy())
for (src, dst) in self.iter().zip(dst.iter_mut()) {
cuda::copy_block(
src.as_ref(),
dst,
ctx.stream().as_ref(),
RB::write_to_strategy(),
)?;
}
Ok(())
}
TransferStrategy::NixlWrite => {
std::mem::drop(nixl::write_block_to(self, dst, ctx, notify)?);
std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?);
Ok(())
}
_ => Err(TransferError::IncompatibleTypes(format!(
......@@ -175,17 +188,16 @@ where
RB::write_to_strategy()
))),
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
fn nixl_write_to(
&self,
dst: &mut WB,
dst: &mut Vec<WB>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
Ok(nixl::write_block_to(self, dst, ctx, notify)?)
Ok(nixl::write_blocks_to(self, dst, ctx, notify)?)
} else {
Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}",
......
......@@ -18,16 +18,14 @@ use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::future::{poll_fn, Future};
use std::ops::Range;
use std::task::Poll;
/// Copy a block from a source to a destination using CUDA memcpy
pub fn write_block_to<'a, Source, Destination>(
src: &'a Source,
dst: &'a mut Destination,
ctx: Arc<TransferContext>,
notify: Option<String>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
fn append_xfer_request<Source, Destination>(
src: &Arc<Source>,
dst: &mut Destination,
src_dl: &mut XferDescList,
dst_dl: &mut XferDescList,
) -> Result<()>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
......@@ -36,17 +34,6 @@ where
let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
// Keep the arc to use in the returned future.
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
let src_desc = src_data.block_view()?.as_nixl_descriptor();
let dst_desc = dst_data.block_view_mut()?.as_nixl_descriptor_mut();
......@@ -64,74 +51,10 @@ where
)?;
}
let xfer_req = nixl_agent
.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
.unwrap();
let mut xfer_args = OptArgs::new()?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
// Return a future that completes when the transfer is complete.
// TODO: How efficient is this? Can we do better?
Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
}
})))
Ok(())
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)
}
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub fn write_layers_to<'a, Source, Destination>(
layer_range: Range<usize>,
src: &'a Source,
dst: &'a mut Destination,
ctx: Arc<TransferContext>,
notify: Option<String>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken);
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
// #[cfg(debug_assertions)]
// {
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
// Destination::StorageType,
// >>::write_to_strategy();
// assert_eq!(strategy, expected_strategy);
// }
for layer_idx in layer_range {
for layer_idx in 0..src_data.num_layers() {
for outer_idx in 0..src_data.num_outer_dims() {
let src_view = src_data.layer_view(layer_idx, outer_idx)?;
let mut dst_view = dst_data.layer_view_mut(layer_idx, outer_idx)?;
......@@ -156,6 +79,56 @@ where
}
}
}
Ok(())
}
}
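/// Write a batch of blocks from `src` to the corresponding entries of `dst` using a single NIXL write transfer.
/// Returns a future that completes once the NIXL agent reports the transfer as finished.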
pub fn write_blocks_to<Source, Destination>(
src: &[Arc<Source>],
dst: &mut [Destination],
ctx: Arc<TransferContext>,
notify: Option<String>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
{
if src.is_empty() || dst.is_empty() {
return Ok(Box::new(std::future::ready(())));
}
assert_eq!(src.len(), dst.len());
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let src_mem_type = src
.first()
.unwrap()
.block_data(private::PrivateToken)
.storage_type()
.nixl_mem_type();
let dst_mem_type = dst
.first()
.unwrap()
.block_data(private::PrivateToken)
.storage_type()
.nixl_mem_type();
let mut src_dl = XferDescList::new(src_mem_type)?;
let mut dst_dl = XferDescList::new(dst_mem_type)?;
for (src, dst) in src.iter().zip(dst.iter_mut()) {
append_xfer_request(src, dst, &mut src_dl, &mut dst_dl)?;
}
debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?;
let mut xfer_args = OptArgs::new()?;
......@@ -164,14 +137,6 @@ where
xfer_args.set_notification_message(notify.as_bytes())?;
}
let xfer_req = nixl_agent.create_xfer_req(
XferOp::Write,
&src_dl,
&dst_dl,
&remote_worker_id,
Some(&xfer_args),
)?;
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
Ok(Box::new(poll_fn(move |_cx| {
......@@ -179,6 +144,8 @@ where
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
......
......@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
mod pending;
pub mod request;
use pending::{CudaTransferManager, DiskTransferManager, PendingTransfer, TransferManager};
use pending::{
CudaTransferManager, DiskTransferManager, PendingTransfer, TransferBatcher, TransferManager,
};
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
// TODO: This should be dynamic
const MAX_OFFLOAD_STREAM_DEPTH: usize = 4;
const MAX_CONCURRENT_TRANSFERS: usize = 4;
const MAX_TRANSFER_BATCH_SIZE: usize = 16;
/// The offload manager handles all block transfers between different cache levels.
pub struct OffloadManager<Metadata: BlockMetadata> {
// Handles to the device, host, and disk pools.
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
disk: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
/// Queue of offloading requests.
device_offload_tx: mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Metadata>>,
......@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
pub fn new(
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
disk: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
nixl_agent: Arc<Option<NixlAgent>>,
async_rt_handle: Handle,
) -> Result<Arc<Self>> {
......@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone();
let host_clone = this.host.clone();
async_rt_handle.spawn(async move {
OffloadManager::offload_worker(
let res = OffloadManager::offload_worker(
device_clone,
host_clone,
device_offload_rx,
Arc::new(CudaTransferManager::new(
device_offload_transfer_ctx,
MAX_OFFLOAD_STREAM_DEPTH,
Arc::new(TransferBatcher::new(
CudaTransferManager::new(device_offload_transfer_ctx, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await
.unwrap()
.await;
tracing::warn!("Offload worker terminated: {:?}", res);
});
let transfer_ctx = Arc::new(TransferContext::new(
......@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let disk_clone = this.disk.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::offload_worker(
let res = OffloadManager::offload_worker(
host_clone,
disk_clone,
host_offload_rx,
Arc::new(DiskTransferManager::new(
transfer_ctx_clone,
MAX_OFFLOAD_STREAM_DEPTH,
Arc::new(TransferBatcher::new(
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await
.unwrap()
.await;
tracing::warn!("Offload worker terminated: {:?}", res);
});
// Host -> Device onboarding
......@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::onboard_worker(
let res = OffloadManager::onboard_worker(
host_clone,
device_clone,
host_onboard_rx,
Arc::new(CudaTransferManager::new(transfer_ctx_clone, 16384)),
Arc::new(TransferBatcher::new(
CudaTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await
.unwrap()
.await;
tracing::warn!("Onboard worker terminated: {:?}", res);
});
// Disk -> Device onboarding
......@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::onboard_worker(
let res = OffloadManager::onboard_worker(
disk_clone,
device_clone,
disk_onboard_rx,
Arc::new(DiskTransferManager::new(transfer_ctx_clone, 16384)),
Arc::new(TransferBatcher::new(
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await
.unwrap()
.await;
tracing::warn!("Onboard worker terminated: {:?}", res);
});
Ok(this_clone)
}
async fn offload_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>,
source_pool: Option<Arc<BlockPool<Source, Metadata>>>,
target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> {
if source_pool_arc.is_none() || target_pool_arc.is_none() {
if source_pool.is_none() || target_pool.is_none() {
return Ok(());
}
let source_pool = source_pool_arc.as_ref().as_ref().unwrap();
let target_pool = target_pool_arc.as_ref().as_ref().unwrap();
let source_pool = source_pool.as_ref().unwrap();
let target_pool = target_pool.as_ref().unwrap();
let mut queue = BTreeSet::new();
......@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
// Allocate a block from the target pool.
// TODO: The most likely error here is that the host pool is full.
// TODO: The most likely error here is that the target pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let target_blocks = match target_pool.allocate_blocks(1).await {
Ok(blocks) => blocks,
......@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
if let Some(target_block) = target_blocks.into_iter().next() {
transfer_manager
.begin_transfer(PendingTransfer::new(
.enqueue_transfer(PendingTransfer::new(
vec![block],
vec![target_block],
None,
target_pool_arc.clone(),
target_pool.clone(),
))
.await?;
}
......@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
async fn onboard_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>,
source_pool: Option<Arc<BlockPool<Source, Metadata>>>,
target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> {
if source_pool_arc.is_none() || target_pool_arc.is_none() {
if source_pool.is_none() || target_pool.is_none() {
return Ok(());
}
let target_pool = target_pool_arc.as_ref().as_ref().unwrap();
let target_pool = target_pool.as_ref().unwrap();
// Loop on incoming requests
while let Some(request) = onboard_rx.recv().await {
......@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
.collect();
transfer_manager
.begin_transfer(PendingTransfer::new(
.enqueue_transfer(PendingTransfer::new(
sources,
target_blocks,
Some(request.response_tx),
target_pool_arc.clone(),
target_pool.clone(),
))
.await?;
}
......@@ -478,9 +486,9 @@ mod tests {
const BLOCK_SIZE: usize = 4;
type DevicePool = Arc<Option<BlockPool<DeviceStorage, BasicMetadata>>>;
type HostPool = Arc<Option<BlockPool<PinnedStorage, BasicMetadata>>>;
type DiskPool = Arc<Option<BlockPool<DiskStorage, BasicMetadata>>>;
type DevicePool = Option<Arc<BlockPool<DeviceStorage, BasicMetadata>>>;
type HostPool = Option<Arc<BlockPool<PinnedStorage, BasicMetadata>>>;
type DiskPool = Option<Arc<BlockPool<DiskStorage, BasicMetadata>>>;
lazy_static::lazy_static! {
static ref NIXL_AGENT: Arc<Option<NixlAgent>> = {
......@@ -521,16 +529,18 @@ mod tests {
device.nixl_register(agent, None)?;
let device_blocks = Blocks::<_, BasicMetadata>::new(device, 42, 0)?.into_blocks()?;
let device_pool = Arc::new(Some(BlockPool::builder().blocks(device_blocks).build()?));
let device_pool = Some(Arc::new(
BlockPool::builder().blocks(device_blocks).build()?,
));
let host_pool = if let Some(host_blocks) = host_blocks {
config.num_blocks = host_blocks;
let mut host = FullyContiguous::allocate(config.clone(), &PinnedAllocator::default())?;
host.nixl_register(agent, None)?;
let host_blocks = Blocks::<_, BasicMetadata>::new(host, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(host_blocks).build()?))
Some(Arc::new(BlockPool::builder().blocks(host_blocks).build()?))
} else {
Arc::new(None)
None
};
let disk_pool = if let Some(disk_blocks) = disk_blocks {
......@@ -538,9 +548,9 @@ mod tests {
let mut disk = FullyContiguous::allocate(config, &DiskAllocator)?;
disk.nixl_register(agent, None)?;
let disk_blocks = Blocks::<_, BasicMetadata>::new(disk, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(disk_blocks).build()?))
Some(Arc::new(BlockPool::builder().blocks(disk_blocks).build()?))
} else {
Arc::new(None)
None
};
let async_rt_handle = Handle::current();
......@@ -558,7 +568,7 @@ mod tests {
/// Create a block in the 'RESET' state.
async fn get_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>,
pool: &Arc<BlockPool<S, Metadata>>,
) -> Result<MutableBlock<S, Metadata>> {
pool.allocate_blocks(1)
.await?
......@@ -569,7 +579,7 @@ mod tests {
/// Create a block in the 'PARTIAL' state.
async fn partial_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>,
pool: &Arc<BlockPool<S, Metadata>>,
token: u32,
) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?;
......@@ -580,7 +590,7 @@ mod tests {
/// Create a block in the 'COMPLETED' state.
async fn completed_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>,
pool: &Arc<BlockPool<S, Metadata>>,
tokens: [u32; BLOCK_SIZE],
) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?;
......@@ -666,7 +676,7 @@ mod tests {
async fn test_offload_invalid_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
// Check blocks in the 'RESET' state.
let immutable_block = ImmutableBlock::new(Arc::new(get_block(device_pool).await?));
......@@ -699,8 +709,8 @@ mod tests {
async fn test_offload_registered_blocks() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
// Create a block and register it with the offload manager
let block = completed_block(device_pool, [0, 1, 2, 3]).await?;
......@@ -742,8 +752,8 @@ mod tests {
async fn test_no_host_blocks_available() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let host_blocks = host_pool.allocate_blocks(4).await?;
assert_eq!(host_blocks.len(), 4);
......@@ -790,8 +800,8 @@ mod tests {
async fn test_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
// Allocate and fill a block on the host.
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
......@@ -844,8 +854,8 @@ mod tests {
async fn test_offload_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
......@@ -913,8 +923,8 @@ mod tests {
async fn test_onboard_err_handling() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool
......@@ -942,7 +952,7 @@ mod tests {
async fn test_offload_onboard_no_host_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, None, None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
......@@ -961,8 +971,8 @@ mod tests {
async fn test_offload_disk() -> Result<()> {
let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4))?;
let host_pool = host_pool.as_ref().as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool
......@@ -996,8 +1006,8 @@ mod tests {
async fn test_onboard_disk() -> Result<()> {
let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap();
let disk_block = completed_block(disk_pool, [0, 1, 2, 3]).await?;
let immutable_disk_block = disk_pool
......@@ -1032,9 +1042,9 @@ mod tests {
let (offload_manager, device_pool, host_pool, disk_pool) =
build_pools(8, Some(8), Some(8))?;
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap();
let device_pool = device_pool.as_ref().unwrap();
let mut host_blocks = Vec::new();
......@@ -1076,4 +1086,39 @@ mod tests {
Ok(())
}
/// Onboards more blocks than fit in two transfer batches in a single request and
/// checks that every block arrives in the device pool with matching contents.
#[tokio::test]
async fn test_transfer_batcher() -> Result<()> {
    // Two full batches plus one extra block.
    let num_blocks = 2 * MAX_TRANSFER_BATCH_SIZE + 1;

    let (offload_manager, device_pool, _, disk_pool) =
        build_pools(num_blocks, None, Some(num_blocks))?;
    let device_pool = device_pool.as_ref().unwrap();
    let disk_pool = disk_pool.as_ref().unwrap();

    // Fill each disk block with a distinct token value.
    let mut filled = Vec::new();
    for token in 0..num_blocks {
        let block = completed_block(disk_pool, [token as u32; 4]).await?;
        filled.push(block);
    }

    let immutable_disk_blocks = disk_pool.register_blocks(filled).await?;

    let onboarded = offload_manager
        .onboard(immutable_disk_blocks.clone())
        .await?;
    assert_eq!(onboarded.len(), num_blocks);

    // Each onboarded block must be discoverable in the device pool by its sequence hash.
    for onboarded_block in &onboarded {
        let matched = device_pool
            .match_sequence_hashes(vec![onboarded_block.sequence_hash()?].as_slice())
            .await?;
        assert_eq!(matched.len(), 1);
        compare_block_contents(&matched[0], onboarded_block)?;
    }

    Ok(())
}
}
......@@ -33,11 +33,12 @@
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::begin_transfer`]
//! 2. [`TransferManager::begin_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 1. A transfer request is made by calling [`TransferManager::enqueue_transfer`]
//! 2. [`TransferManager::enqueue_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::thread::spawn;
......@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
use anyhow::Result;
use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{future::join_all, stream::FuturesUnordered, StreamExt};
use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult;
......@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
/// The target pool that will receive the registered block.
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
target_pool: Arc<BlockPool<Target, Metadata>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
......@@ -78,31 +79,35 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
target_pool: Arc<BlockPool<Target, Metadata>>,
) -> Self {
assert_eq!(sources.len(), targets.len());
Self {
sources,
targets,
completion_indicator,
target_registration_pool,
target_pool,
}
}
fn handle_complete(self) -> Result<()> {
let Self {
targets,
target_registration_pool,
sources,
mut targets,
target_pool,
completion_indicator,
..
} = self;
if let Some(target_registration_pool) = target_registration_pool.as_ref() {
let blocks = target_registration_pool.register_blocks_blocking(targets)?;
for (source, target) in sources.iter().zip(targets.iter_mut()) {
transfer_metadata(source, target)?;
}
let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?;
}
}
Ok(())
}
......@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
Send + Sync
{
/// Begin a transfer. Blocks if the pending queue is full.
async fn begin_transfer(
async fn enqueue_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()>;
......@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata>
{
pub fn new(transfer_ctx: Arc<TransferContext>, max_depth: usize) -> Self {
let (tx, mut rx) =
mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(max_depth);
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(
max_concurrent_transfers,
);
spawn(move || {
while let Some((pending_transfer, event)) = rx.blocking_recv() {
// Wait for the event.
event.synchronize()?;
// Only finalize the transfer after the event is signaled.
pending_transfer.handle_complete()?;
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine has been shut down.
// This is not a problem, so we only log a warning and continue.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}
Ok::<(), anyhow::Error>(())
});
......@@ -183,18 +196,15 @@ where
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
async fn enqueue_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
for (source, target) in pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
{
transfer_metadata(source, target)?;
source.write_to(target, None, self.transfer_ctx.clone())?;
}
pending_transfer.sources.write_to(
&mut pending_transfer.targets,
None,
self.transfer_ctx.clone(),
)?;
// Use a cuda event to record the completion of the transfers.
let event = self
......@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
}
impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_size: usize) -> Self {
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move {
......@@ -230,7 +240,7 @@ impl DiskTransferManager {
tokio::select! {
Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_size {
while pending_transfers.len() >= max_concurrent_transfers {
pending_transfers.next().await;
}
// Once we have capacity, push the new future onto the queue.
......@@ -267,26 +277,26 @@ where
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
async fn enqueue_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
let futures = pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
.map(|(source, target)| {
transfer_metadata(source, target).unwrap();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to(target, None, self.transfer_ctx.clone())
.unwrap()
})
.collect::<Vec<_>>();
let future = pending_transfer.sources.nixl_write_to(
&mut pending_transfer.targets,
None,
self.transfer_ctx.clone(),
)?;
let completion_future = async move {
let _ = join_all(futures).await;
pending_transfer.handle_complete().unwrap();
let _ = future.await;
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is being shut down.
// This is not a problem, so we only log a warning and continue.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
......@@ -296,3 +306,112 @@ where
Ok(())
}
}
/// A transfer manager that enforces a max batch size for transfers.
///
/// It wraps another [`TransferManager`] and forwards work to it: transfers
/// smaller than the batch size are passed through unchanged, larger ones are
/// split into several smaller transfers before being enqueued.
pub struct TransferBatcher<Source, Target, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
/// The wrapped manager that every (sub-)transfer is ultimately enqueued on.
transfer_manager: Manager,
/// Maximum number of source blocks placed into a single forwarded transfer.
max_transfer_batch_size: usize,
/// Ties the otherwise unused `Source`, `Target` and `Metadata` parameters to the type.
_phantom: PhantomData<(Source, Target, Metadata)>,
}
impl<
        Source: Storage,
        Target: Storage,
        Metadata: BlockMetadata,
        Manager: TransferManager<Source, Target, Metadata>,
    > TransferBatcher<Source, Target, Metadata, Manager>
{
    /// Wraps `transfer_manager` so that no transfer forwarded to it carries
    /// more than `max_transfer_batch_size` source blocks.
    pub fn new(transfer_manager: Manager, max_transfer_batch_size: usize) -> Self {
        let _phantom = PhantomData;
        Self {
            transfer_manager,
            max_transfer_batch_size,
            _phantom,
        }
    }
}
#[async_trait]
impl<Source, Target, Metadata, Manager> TransferManager<Source, Target, Metadata>
for TransferBatcher<Source, Target, Metadata, Manager>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
async fn enqueue_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
// If it's smaller than the max batch size, just enqueue it.
if pending_transfer.sources.len() < self.max_transfer_batch_size {
return self
.transfer_manager
.enqueue_transfer(pending_transfer)
.await;
}
// Otherwise, we need to split the transfer into multiple smaller transfers.
let PendingTransfer {
mut sources,
mut targets,
completion_indicator,
target_pool,
} = pending_transfer;
let mut indicators = Vec::new();
while !sources.is_empty() {
let sources = sources
.drain(..std::cmp::min(self.max_transfer_batch_size, sources.len()))
.collect();
let targets = targets
.drain(..std::cmp::min(self.max_transfer_batch_size, targets.len()))
.collect();
// If we have a completion indicator, we need to create a new one for each sub-transfer.
let indicator = if completion_indicator.is_some() {
let (batch_tx, batch_rx) = oneshot::channel();
indicators.push(batch_rx);
Some(batch_tx)
} else {
None
};
let request = PendingTransfer::new(sources, targets, indicator, target_pool.clone());
// Enqueue our reduced transfer. This may block if the queue is full.
self.transfer_manager.enqueue_transfer(request).await?;
}
if let Some(completion_indicator) = completion_indicator {
tokio::spawn(async move {
let mut results = Vec::new();
for indicator in indicators.into_iter() {
// Await each sub-transfer, and append the results to our final results.
let result = match indicator.await.unwrap() {
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
return;
}
};
results.extend(result);
}
// Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results)).unwrap();
});
}
Ok(())
}
}
......@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
/// Sort key for offload requests.
///
/// Higher priority offloads are done first.
/// If two offloads have the same priority, the one that was requested first is done first.
///
/// The ordering is implemented by hand (rather than derived) so that the
/// *smallest* key is the request to service next: priorities compare in
/// reverse, timestamps compare naturally.
#[derive(Debug, PartialEq, Eq)]
pub struct OffloadRequestKey {
    /// Priority of the request; larger values sort earlier.
    pub priority: u64,
    /// Tie-breaker among equal priorities; smaller values sort earlier.
    pub timestamp: u64,
}

impl PartialOrd for OffloadRequestKey {
    /// Delegates to [`Ord::cmp`] so the partial and total orders always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OffloadRequestKey {
    /// Orders by descending priority, then by ascending timestamp.
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped for `priority` to reverse its order.
        other
            .priority
            .cmp(&self.priority)
            .then_with(|| self.timestamp.cmp(&other.timestamp))
    }
}
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
......
......@@ -518,6 +518,10 @@ pub(crate) mod tests {
/// Resets the metadata by setting the priority back to zero.
fn reset_metadata(&mut self) {
self.priority = 0;
}
/// Reports the stored priority as the offload priority; never returns `None`.
fn offload_priority(&self) -> Option<u64> {
    let priority = self.priority as u64;
    Some(priority)
}
}
type TestPriorityKey = PriorityKey<TestMetadata>;
......
......@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self.active.register(mutable)?;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if offload {
immutable.enqueue_offload(0).await.unwrap();
if let Some(priority) = immutable.metadata().offload_priority() {
immutable.enqueue_offload(priority).await.unwrap();
}
}
immutable_blocks.push(immutable);
......
......@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent: Arc<Option<NixlAgent>>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
disk_pool: Option<Arc<BlockPool<DiskStorage, Metadata>>>,
host_pool: Option<Arc<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
......@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
(Arc::new(None), None)
(None, None)
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
......@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(),
worker_id,
)?;
(Arc::new(Some(pool)), Some(blocks))
(Some(Arc::new(pool)), Some(blocks))
}
} else {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(Arc::new(None), None)
(None, None)
};
// Create the host block pool if a host layout is provided
......@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(),
worker_id,
)?;
(Arc::new(Some(pool)), Some(blocks))
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No host layout provided; will not allocate host blocks.");
(Arc::new(None), None)
(None, None)
};
// Create the device block pool if a device layout is provided
......@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(),
worker_id,
)?;
(Arc::new(Some(pool)), Some(blocks))
(Some(Arc::new(pool)), Some(blocks))
} else {
tracing::debug!("No device layout provided; will not allocate device blocks.");
(Arc::new(None), None)
(None, None)
};
// Finalize the local block set by adding NIXL metadata
......@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
}
/// Returns the disk block pool, or `None` if no disk pool was created.
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
    // `disk_pool` is an `Option<Arc<_>>`; borrow the pool through the `Arc`.
    self.disk_pool.as_deref()
}
/// Returns the host block pool, or `None` if no host pool was created.
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
    // `host_pool` is an `Option<Arc<_>>`; borrow the pool through the `Arc`.
    self.host_pool.as_deref()
}
/// Returns the device block pool, or `None` if no device pool was created.
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
    // `device_pool` is an `Option<Arc<_>>`; borrow the pool through the `Arc`.
    self.device_pool.as_deref()
}
pub fn worker_id(&self) -> WorkerID {
......
......@@ -15,6 +15,7 @@
use super::*;
use core::ffi::c_char;
use nix::fcntl::{fallocate, FallocateFlags};
use std::ffi::CString;
use std::fs::File;
......@@ -41,7 +42,7 @@ impl DiskStorage {
let raw_fd = unsafe {
nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut i8,
template_bytes.as_mut_ptr() as *mut c_char,
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned.
......@@ -80,6 +81,7 @@ impl DiskStorage {
impl Drop for DiskStorage {
// TODO: How robust is this actually?
fn drop(&mut self) {
self.handles.release();
std::fs::remove_file(self.file_name.clone()).unwrap();
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment