Unverified commit 6d9aac77, authored by jthomson04 and committed by GitHub
Browse files

feat: kvbm offload fixes and tests (#1191)

parent e5845b53
...@@ -38,12 +38,12 @@ use super::{ ...@@ -38,12 +38,12 @@ use super::{
WorkerID, WorkerID,
}; };
use derive_getters::Getters;
use std::{ use std::{
fmt::Debug, fmt::Debug,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
}; };
use thiserror::Error; use thiserror::Error;
mod private { mod private {
...@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> { ...@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
self.manager = Some(manager); self.manager = Some(manager);
} }
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> { pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
self.manager.as_ref() self.manager.as_ref()
} }
...@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider { ...@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>; fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>;
} }
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)] #[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)]
pub struct BasicMetadata { pub struct BasicMetadata {
#[getter(copy)]
priority: u32, priority: u32,
#[getter(copy)]
returned_tick: u64, returned_tick: u64,
#[getter(copy)]
acquired_tick: u64, acquired_tick: u64,
} }
impl BasicMetadata {
pub fn update_priority(&self, priority: u32) -> Self {
BasicMetadata {
priority,
returned_tick: self.returned_tick,
acquired_tick: self.acquired_tick,
}
}
}
impl BlockMetadata for BasicMetadata { impl BlockMetadata for BasicMetadata {
fn on_acquired(&mut self, tick: u64) { fn on_acquired(&mut self, tick: u64) {
self.acquired_tick = tick; self.acquired_tick = tick;
...@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { ...@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
Self { block } Self { block }
} }
/// Returns the block manager attached to the underlying block, or `None`
/// when no manager is set.
pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
    // `manager` is reached on the wrapped block through deref rather than
    // being a field of this wrapper.
    Option::as_ref(&self.manager)
}
pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> { pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
&self.block &self.block
} }
...@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> ...@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> { pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
// TODO: Is it ok to silently fail if the block is not managed?
if let Some(manager) = self.manager() { if let Some(manager) = self.manager() {
manager.enqueue_offload_block(self, priority).await?; manager.enqueue_offload_block(self, priority).await?;
} else {
tracing::warn!("Block is not managed. Unable to enqueue offload.");
} }
Ok(()) Ok(())
} }
......
...@@ -28,6 +28,7 @@ use crate::block_manager::storage::{ ...@@ -28,6 +28,7 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use nixl_sys::XferOp::{Read, Write};
use std::future::Future; use std::future::Future;
use std::ops::Range; use std::ops::Range;
...@@ -77,6 +78,21 @@ pub enum TransferError { ...@@ -77,6 +78,21 @@ pub enum TransferError {
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
/// Direction of a NIXL transfer carried by `TransferStrategy::Nixl`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NixlTransfer {
    /// Maps to the NIXL `Read` operation.
    Read,
    /// Maps to the NIXL `Write` operation.
    Write,
}

impl NixlTransfer {
    /// Converts this direction into the matching `nixl_sys::XferOp` variant.
    pub fn as_xfer_op(&self) -> nixl_sys::XferOp {
        match *self {
            Self::Read => Read,
            Self::Write => Write,
        }
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy { pub enum TransferStrategy {
Memcpy, Memcpy,
...@@ -85,8 +101,7 @@ pub enum TransferStrategy { ...@@ -85,8 +101,7 @@ pub enum TransferStrategy {
CudaAsyncD2D, CudaAsyncD2D,
CudaBlockingH2D, CudaBlockingH2D,
CudaBlockingD2H, CudaBlockingD2H,
NixlWrite, // aka PUT Nixl(NixlTransfer),
NixlRead, // aka GET
Invalid, Invalid,
} }
...@@ -126,7 +141,7 @@ where ...@@ -126,7 +141,7 @@ where
{ {
#[inline(always)] #[inline(always)]
fn read_from_strategy() -> TransferStrategy { fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -179,8 +194,14 @@ where ...@@ -179,8 +194,14 @@ where
} }
Ok(()) Ok(())
} }
TransferStrategy::NixlWrite => { TransferStrategy::Nixl(transfer_type) => {
std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?); std::mem::drop(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?);
Ok(()) Ok(())
} }
_ => Err(TransferError::IncompatibleTypes(format!( _ => Err(TransferError::IncompatibleTypes(format!(
...@@ -196,8 +217,14 @@ where ...@@ -196,8 +217,14 @@ where
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> { ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() { if let TransferStrategy::Nixl(transfer_type) = RB::write_to_strategy() {
Ok(nixl::write_blocks_to(self, dst, ctx, notify)?) Ok(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?)
} else { } else {
Err(TransferError::IncompatibleTypes(format!( Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}", "Expected NIXL transfer strategy, got: {:?}",
...@@ -626,7 +653,7 @@ mod tests { ...@@ -626,7 +653,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Pinned to ... // Pinned to ...
...@@ -644,7 +671,7 @@ mod tests { ...@@ -644,7 +671,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Device to ... // Device to ...
...@@ -662,7 +689,7 @@ mod tests { ...@@ -662,7 +689,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp}; use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList};
use std::future::{poll_fn, Future}; use std::future::{poll_fn, Future};
use std::task::Poll; use std::task::Poll;
...@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>( ...@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>(
dst: &mut [Destination], dst: &mut [Destination],
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
notify: Option<String>, notify: Option<String>,
transfer_type: NixlTransfer,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
Source: BlockDataProvider, Source: BlockDataProvider,
...@@ -127,8 +128,13 @@ where ...@@ -127,8 +128,13 @@ where
debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?); debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);
let xfer_req = let xfer_req = nixl_agent.create_xfer_req(
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?; transfer_type.as_xfer_op(),
&src_dl,
&dst_dl,
&nixl_agent.name(),
None,
)?;
let mut xfer_args = OptArgs::new()?; let mut xfer_args = OptArgs::new()?;
......
...@@ -21,35 +21,35 @@ use super::*; ...@@ -21,35 +21,35 @@ use super::*;
impl WriteToStrategy<DiskStorage> for DiskStorage { impl WriteToStrategy<DiskStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
impl WriteToStrategy<SystemStorage> for DiskStorage { impl WriteToStrategy<SystemStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<PinnedStorage> for DiskStorage { impl WriteToStrategy<PinnedStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<DeviceStorage> for DiskStorage { impl WriteToStrategy<DeviceStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<DiskStorage> for SystemStorage { impl WriteToStrategy<DiskStorage> for SystemStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage { ...@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
impl WriteToStrategy<DiskStorage> for PinnedStorage { impl WriteToStrategy<DiskStorage> for PinnedStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage { ...@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
impl WriteToStrategy<DiskStorage> for DeviceStorage { impl WriteToStrategy<DiskStorage> for DeviceStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage { ...@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage {
impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S { impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -170,7 +170,7 @@ where ...@@ -170,7 +170,7 @@ where
impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S { impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S {
#[inline(always)] #[inline(always)]
fn read_from_strategy() -> TransferStrategy { fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -198,7 +198,7 @@ mod tests { ...@@ -198,7 +198,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Pinned to ... // Pinned to ...
...@@ -216,7 +216,7 @@ mod tests { ...@@ -216,7 +216,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Device to ... // Device to ...
...@@ -234,7 +234,7 @@ mod tests { ...@@ -234,7 +234,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
...@@ -276,7 +276,7 @@ mod tests { ...@@ -276,7 +276,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Pinned to ... // Pinned to ...
...@@ -297,7 +297,7 @@ mod tests { ...@@ -297,7 +297,7 @@ mod tests {
assert_eq!( assert_eq!(
<PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Device to ... // Device to ...
...@@ -318,7 +318,7 @@ mod tests { ...@@ -318,7 +318,7 @@ mod tests {
assert_eq!( assert_eq!(
<DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
......
...@@ -259,12 +259,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -259,12 +259,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
} }
} }
// Allocate a block from the host pool.
// TODO: The most likely error here is that the target pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let target_blocks = match target_pool.allocate_blocks(1).await { let target_blocks = match target_pool.allocate_blocks(1).await {
Ok(blocks) => blocks, Ok(blocks) => blocks,
Err(_) => { Err(_) => {
tracing::warn!("Target pool full. Skipping offload. This should only ever happen with very small pool sizes.");
continue; continue;
} }
}; };
...@@ -451,6 +449,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -451,6 +449,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
self.disk_onboard_tx self.disk_onboard_tx
.send(OnboardRequest::new(disk_blocks, tx)) .send(OnboardRequest::new(disk_blocks, tx))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?; .map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
} else {
return Err(BlockPoolError::BlockError(BlockError::Other(
anyhow::anyhow!("Block type not supported for onboarding."),
)));
} }
match rx.await { match rx.await {
...@@ -466,12 +468,15 @@ mod tests { ...@@ -466,12 +468,15 @@ mod tests {
use crate::block_manager::block::test_utils::get_private_token; use crate::block_manager::block::test_utils::get_private_token;
use crate::block_manager::{ use crate::block_manager::{
block::{BasicMetadata, BlockDataExt, BlockDataProvider, BlockExt, Blocks, MutableBlock}, block::{
nixl::BlockHandleInfo, BasicMetadata, BlockDataExt, BlockDataProvider, BlockExt,
Blocks, MutableBlock,
},
layout::{nixl::NixlLayout, FullyContiguous}, layout::{nixl::NixlLayout, FullyContiguous},
pool::BlockPool, pool::BlockPool,
storage::{ storage::{
cuda::CudaAccessible, DeviceAllocator, DeviceStorage, DiskAllocator, DiskStorage, DeviceAllocator, DeviceStorage, DiskAllocator, DiskStorage, PinnedAllocator,
PinnedAllocator, PinnedStorage, StorageType, PinnedStorage, StorageType,
}, },
DType, LayoutConfig, DType, LayoutConfig,
}; };
...@@ -480,11 +485,12 @@ mod tests { ...@@ -480,11 +485,12 @@ mod tests {
use aligned_vec::avec; use aligned_vec::avec;
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind, cudaMemset}; use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind, cudaMemset};
use std::fs::File; use std::fs::File;
use std::io::{Read, Seek, SeekFrom}; use std::io::{Read, Seek, SeekFrom, Write};
use std::mem::ManuallyDrop; use std::mem::ManuallyDrop;
use std::os::unix::io::FromRawFd; use std::os::unix::io::FromRawFd;
const BLOCK_SIZE: usize = 4; const BLOCK_SIZE: usize = 4;
const NUM_LAYERS: usize = 8;
type DevicePool = Option<Arc<BlockPool<DeviceStorage, BasicMetadata>>>; type DevicePool = Option<Arc<BlockPool<DeviceStorage, BasicMetadata>>>;
type HostPool = Option<Arc<BlockPool<PinnedStorage, BasicMetadata>>>; type HostPool = Option<Arc<BlockPool<PinnedStorage, BasicMetadata>>>;
...@@ -505,6 +511,7 @@ mod tests { ...@@ -505,6 +511,7 @@ mod tests {
device_blocks: usize, device_blocks: usize,
host_blocks: Option<usize>, host_blocks: Option<usize>,
disk_blocks: Option<usize>, disk_blocks: Option<usize>,
inner_dim: Option<usize>,
) -> Result<( ) -> Result<(
Arc<OffloadManager<BasicMetadata>>, Arc<OffloadManager<BasicMetadata>>,
DevicePool, DevicePool,
...@@ -513,10 +520,10 @@ mod tests { ...@@ -513,10 +520,10 @@ mod tests {
)> { )> {
let mut config = LayoutConfig { let mut config = LayoutConfig {
num_blocks: device_blocks, num_blocks: device_blocks,
num_layers: 8, num_layers: NUM_LAYERS,
outer_dim: 1, outer_dim: 1,
page_size: BLOCK_SIZE, page_size: BLOCK_SIZE,
inner_dim: 1024, inner_dim: inner_dim.unwrap_or(1024),
alignment: 1, alignment: 1,
dtype: DType::FP16, dtype: DType::FP16,
}; };
...@@ -602,21 +609,39 @@ mod tests { ...@@ -602,21 +609,39 @@ mod tests {
Ok(block) Ok(block)
} }
fn populate_cuda_block<S: Storage + CudaAccessible + NixlDescriptor>( fn populate_block<S: Storage + NixlDescriptor>(
block: &impl BlockDataProvider<StorageType = S>, block: &impl BlockDataProvider<StorageType = S>,
value: i32, value: u8,
) -> Result<()> { ) -> Result<()> {
let block_data = block.block_data(get_private_token()).block_view()?; let block_data = block.block_data(get_private_token());
let block_size = block_data.size(); let block_view = block_data.block_view()?;
let block_size = block_view.size();
unsafe {
cudaMemset( match block_data.storage_type() {
block_data.as_ptr() as *mut std::ffi::c_void, StorageType::Device(_) | StorageType::Pinned => unsafe {
value, cudaMemset(
block_size, block_view.as_ptr() as *mut std::ffi::c_void,
) value as i32,
.result()?; block_size,
)
.result()?;
},
StorageType::Disk => {
let nixl_desc = block_view.as_nixl_descriptor();
let mut file: ManuallyDrop<File>;
let data = avec![[4096] | value; block_size];
unsafe {
file = ManuallyDrop::new(File::from_raw_fd(nixl_desc.device_id() as i32));
file.seek(SeekFrom::Start(nixl_desc.as_ptr() as u64))?;
}
file.write_all(&data)?;
file.sync_all()?;
file.flush()?;
}
_ => panic!(),
} }
Ok(()) Ok(())
} }
...@@ -654,27 +679,31 @@ mod tests { ...@@ -654,27 +679,31 @@ mod tests {
file.read_exact(&mut aligned)?; file.read_exact(&mut aligned)?;
contents = aligned.to_vec(); contents = aligned.to_vec();
} }
_ => { _ => anyhow::bail!("Unsupported storage type."),
panic!();
}
} }
Ok(contents.to_vec()) Ok(contents.to_vec())
} }
/// Compare the contents of a device block and a host block. fn check_block_contents(
fn compare_block_contents(
block1: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>, block1: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
block2: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>, block2: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
value: u8,
) -> Result<()> { ) -> Result<()> {
assert_eq!(get_block_contents(block1)?, get_block_contents(block2)?); let contents1 = get_block_contents(block1)?;
let contents2 = get_block_contents(block2)?;
for (c1_value, c2_value) in contents1.iter().zip(contents2.iter()) {
if *c1_value != *c2_value || *c1_value != value {
panic!("{} != {} != {}", c1_value, c2_value, value);
}
}
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_offload_invalid_blocks() -> Result<()> { async fn test_offload_invalid_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
...@@ -707,7 +736,7 @@ mod tests { ...@@ -707,7 +736,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_offload_registered_blocks() -> Result<()> { async fn test_offload_registered_blocks() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -722,7 +751,7 @@ mod tests { ...@@ -722,7 +751,7 @@ mod tests {
.next() .next()
.ok_or(anyhow::anyhow!("Failed to register block"))?; .ok_or(anyhow::anyhow!("Failed to register block"))?;
populate_cuda_block(&immutable_device_block, 42)?; populate_block(&immutable_device_block, 42)?;
// Offloads should only go to G2 (for now) // Offloads should only go to G2 (for now)
offload_manager.offload(&immutable_device_block, 0).await?; offload_manager.offload(&immutable_device_block, 0).await?;
...@@ -743,14 +772,14 @@ mod tests { ...@@ -743,14 +772,14 @@ mod tests {
immutable_device_block.sequence_hash()? immutable_device_block.sequence_hash()?
); );
compare_block_contents(&immutable_device_block, &host_blocks[0])?; check_block_contents(&immutable_device_block, &host_blocks[0], 42)?;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_no_host_blocks_available() -> Result<()> { async fn test_no_host_blocks_available() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -798,7 +827,7 @@ mod tests { ...@@ -798,7 +827,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_onboard() -> Result<()> { async fn test_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -812,7 +841,7 @@ mod tests { ...@@ -812,7 +841,7 @@ mod tests {
.next() .next()
.unwrap(); .unwrap();
populate_cuda_block(&immutable_host_block, 42)?; populate_block(&immutable_host_block, 42)?;
// Onboard the block. // Onboard the block.
let onboarded_blocks = offload_manager let onboarded_blocks = offload_manager
...@@ -831,7 +860,7 @@ mod tests { ...@@ -831,7 +860,7 @@ mod tests {
BlockState::Registered(_) BlockState::Registered(_)
)); ));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block)?; check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
// Wait for the new value to show up in the device pool. // Wait for the new value to show up in the device pool.
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
...@@ -845,14 +874,14 @@ mod tests { ...@@ -845,14 +874,14 @@ mod tests {
); );
// Check that this is the same block. // Check that this is the same block.
compare_block_contents(&device_blocks[0], &immutable_host_block)?; check_block_contents(&immutable_host_block, &device_blocks[0], 42)?;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_offload_onboard() -> Result<()> { async fn test_offload_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -865,7 +894,7 @@ mod tests { ...@@ -865,7 +894,7 @@ mod tests {
.next() .next()
.unwrap(); .unwrap();
populate_cuda_block(&immutable_device_block, 42)?; populate_block(&immutable_device_block, 42)?;
// Offload the block to the host. // Offload the block to the host.
offload_manager.offload(&immutable_device_block, 0).await?; offload_manager.offload(&immutable_device_block, 0).await?;
...@@ -880,7 +909,7 @@ mod tests { ...@@ -880,7 +909,7 @@ mod tests {
.next() .next()
.unwrap(); .unwrap();
compare_block_contents(&immutable_device_block, &immutable_host_block)?; check_block_contents(&immutable_device_block, &immutable_host_block, 42)?;
// Remove the device block from the pool by dropping it and allocating more blocks. // Remove the device block from the pool by dropping it and allocating more blocks.
drop(immutable_device_block); drop(immutable_device_block);
...@@ -914,14 +943,14 @@ mod tests { ...@@ -914,14 +943,14 @@ mod tests {
BlockState::Registered(_) BlockState::Registered(_)
)); ));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block)?; check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_onboard_err_handling() -> Result<()> { async fn test_onboard_err_handling() -> Result<()> {
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?; let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -950,7 +979,7 @@ mod tests { ...@@ -950,7 +979,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_offload_onboard_no_host_blocks() -> Result<()> { async fn test_offload_onboard_no_host_blocks() -> Result<()> {
let (offload_manager, device_pool, _, _) = build_pools(4, None, None)?; let (offload_manager, device_pool, _, _) = build_pools(4, None, None, None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
...@@ -969,7 +998,7 @@ mod tests { ...@@ -969,7 +998,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_offload_disk() -> Result<()> { async fn test_offload_disk() -> Result<()> {
let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4))?; let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4), None)?;
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
...@@ -982,7 +1011,7 @@ mod tests { ...@@ -982,7 +1011,7 @@ mod tests {
.next() .next()
.unwrap(); .unwrap();
populate_cuda_block(&immutable_host_block, 42)?; populate_block(&immutable_host_block, 42)?;
offload_manager.offload(&immutable_host_block, 0).await?; offload_manager.offload(&immutable_host_block, 0).await?;
...@@ -997,14 +1026,14 @@ mod tests { ...@@ -997,14 +1026,14 @@ mod tests {
immutable_host_block.sequence_hash()? immutable_host_block.sequence_hash()?
); );
compare_block_contents(&disk_blocks[0], &immutable_host_block)?; check_block_contents(&immutable_host_block, &disk_blocks[0], 42)?;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn test_onboard_disk() -> Result<()> { async fn test_onboard_disk() -> Result<()> {
let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4))?; let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4), None)?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
let disk_pool = disk_pool.as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
...@@ -1017,10 +1046,14 @@ mod tests { ...@@ -1017,10 +1046,14 @@ mod tests {
.next() .next()
.unwrap(); .unwrap();
populate_block(&immutable_disk_block, 42)?;
let device_block = offload_manager let device_block = offload_manager
.onboard(vec![immutable_disk_block.clone()]) .onboard(vec![immutable_disk_block.clone()])
.await?; .await?;
check_block_contents(&immutable_disk_block, &device_block[0], 42)?;
assert_eq!(device_block.len(), 1); assert_eq!(device_block.len(), 1);
assert_eq!( assert_eq!(
device_block[0].sequence_hash()?, device_block[0].sequence_hash()?,
...@@ -1040,7 +1073,7 @@ mod tests { ...@@ -1040,7 +1073,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_bulk_transfer_disk() -> Result<()> { async fn test_bulk_transfer_disk() -> Result<()> {
let (offload_manager, device_pool, host_pool, disk_pool) = let (offload_manager, device_pool, host_pool, disk_pool) =
build_pools(8, Some(8), Some(8))?; build_pools(8, Some(8), Some(8), None)?;
let disk_pool = disk_pool.as_ref().unwrap(); let disk_pool = disk_pool.as_ref().unwrap();
let host_pool = host_pool.as_ref().unwrap(); let host_pool = host_pool.as_ref().unwrap();
...@@ -1050,7 +1083,7 @@ mod tests { ...@@ -1050,7 +1083,7 @@ mod tests {
for i in 0..8 { for i in 0..8 {
let block = completed_block(host_pool, [i; 4]).await?; let block = completed_block(host_pool, [i; 4]).await?;
populate_cuda_block(&block, i as i32)?; populate_block(&block, i as u8)?;
host_blocks.push(block); host_blocks.push(block);
} }
...@@ -1064,24 +1097,24 @@ mod tests { ...@@ -1064,24 +1097,24 @@ mod tests {
let mut disk_blocks = Vec::new(); let mut disk_blocks = Vec::new();
for host_block in &immutable_host_blocks { for (i, host_block) in immutable_host_blocks.iter().enumerate() {
let blocks = disk_pool let blocks = disk_pool
.match_sequence_hashes(vec![host_block.sequence_hash()?].as_slice()) .match_sequence_hashes(vec![host_block.sequence_hash()?].as_slice())
.await?; .await?;
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], host_block)?; check_block_contents(host_block, &blocks[0], i as u8)?;
disk_blocks.push(blocks[0].clone()); disk_blocks.push(blocks[0].clone());
} }
let device_blocks = offload_manager.onboard(disk_blocks.clone()).await?; let device_blocks = offload_manager.onboard(disk_blocks.clone()).await?;
assert_eq!(device_blocks.len(), disk_blocks.len()); assert_eq!(device_blocks.len(), disk_blocks.len());
for disk_block in &disk_blocks { for (i, disk_block) in disk_blocks.iter().enumerate() {
let blocks = device_pool let blocks = device_pool
.match_sequence_hashes(vec![disk_block.sequence_hash()?].as_slice()) .match_sequence_hashes(vec![disk_block.sequence_hash()?].as_slice())
.await?; .await?;
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], disk_block)?; check_block_contents(disk_block, &blocks[0], i as u8)?;
} }
Ok(()) Ok(())
...@@ -1093,6 +1126,7 @@ mod tests { ...@@ -1093,6 +1126,7 @@ mod tests {
2 * MAX_TRANSFER_BATCH_SIZE + 1, 2 * MAX_TRANSFER_BATCH_SIZE + 1,
None, None,
Some(2 * MAX_TRANSFER_BATCH_SIZE + 1), Some(2 * MAX_TRANSFER_BATCH_SIZE + 1),
None,
)?; )?;
let device_pool = device_pool.as_ref().unwrap(); let device_pool = device_pool.as_ref().unwrap();
...@@ -1101,7 +1135,9 @@ mod tests { ...@@ -1101,7 +1135,9 @@ mod tests {
let mut disk_blocks = Vec::new(); let mut disk_blocks = Vec::new();
for i in 0..2 * MAX_TRANSFER_BATCH_SIZE + 1 { for i in 0..2 * MAX_TRANSFER_BATCH_SIZE + 1 {
disk_blocks.push(completed_block(disk_pool, [i as u32; 4]).await?); let disk_block = completed_block(disk_pool, [i as u32; 4]).await?;
populate_block(&disk_block, i as u8)?;
disk_blocks.push(disk_block);
} }
let immutable_disk_blocks = disk_pool.register_blocks(disk_blocks).await?; let immutable_disk_blocks = disk_pool.register_blocks(disk_blocks).await?;
...@@ -1111,14 +1147,169 @@ mod tests { ...@@ -1111,14 +1147,169 @@ mod tests {
.await?; .await?;
assert_eq!(device_blocks.len(), 2 * MAX_TRANSFER_BATCH_SIZE + 1); assert_eq!(device_blocks.len(), 2 * MAX_TRANSFER_BATCH_SIZE + 1);
for device_block in &device_blocks { for (i, device_block) in device_blocks.iter().enumerate() {
let blocks = device_pool let blocks = device_pool
.match_sequence_hashes(vec![device_block.sequence_hash()?].as_slice()) .match_sequence_hashes(vec![device_block.sequence_hash()?].as_slice())
.await?; .await?;
check_block_contents(device_block, &blocks[0], i as u8)?;
assert_eq!(blocks.len(), 1); assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], device_block)?;
} }
Ok(()) Ok(())
} }
/// Registers one block in the pool bound as `device_pool`, hands it to
/// `OffloadManager::onboard`, and expects the call to be rejected with
/// `BlockPoolError::BlockError(BlockError::Other(_))`.
#[tokio::test]
async fn test_onboard_unsupported_block_type() -> Result<()> {
    let (manager, device_pool, _, _) = build_pools(1, None, None, None)?;
    let pool = device_pool.as_ref().unwrap();

    // Complete and register a single block so there is something to onboard.
    let pending = completed_block(pool, [0; 4]).await?;
    let registered = pool
        .register_blocks(vec![pending])
        .await?
        .into_iter()
        .next()
        .unwrap();

    // The result is inspected rather than propagated with `?`: an error is the expected outcome.
    let outcome = manager.onboard(vec![registered]).await;
    assert!(matches!(
        outcome,
        Err(BlockPoolError::BlockError(BlockError::Other(_)))
    ));

    Ok(())
}
/// Sets priority 1 on a block before registering it, offloads it, and checks that the
/// matching block found in the host pool reports the same priority and the same contents.
#[tokio::test]
async fn test_offload_transfer_metadata() -> Result<()> {
    let (manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
    let device = device_pool.as_ref().unwrap();
    let host = host_pool.as_ref().unwrap();

    // Populate the block and replace its metadata while it is still mutable.
    let mut block = completed_block(device, [0; 4]).await?;
    populate_block(&block, 42)?;
    let prioritized = block.metadata().update_priority(1);
    block.update_metadata(prioritized);

    let registered = device
        .register_blocks(vec![block])
        .await?
        .into_iter()
        .next()
        .unwrap();

    manager.offload(&registered, 0).await?;

    // Fixed pause before querying the host pool; presumably lets the offload complete.
    let settle = std::time::Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    let hash = registered.sequence_hash()?;
    let host_matches = host.match_sequence_hashes(&[hash]).await?;
    assert_eq!(host_matches.len(), 1);

    let offloaded = &host_matches[0];
    check_block_contents(&registered, offloaded, 42)?;
    assert_eq!(offloaded.metadata().priority(), 1);

    Ok(())
}
/// Offloads a registered block, onboards the resulting host copy again, and checks that the
/// onboarded block has the same `block_idx` as the block that was originally registered.
#[tokio::test]
async fn test_onboard_duplicate() -> Result<()> {
    let (manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
    let device = device_pool.as_ref().unwrap();
    let host = host_pool.as_ref().unwrap();

    let pending = completed_block(device, [0; 4]).await?;
    let original = device
        .register_blocks(vec![pending])
        .await?
        .into_iter()
        .next()
        .unwrap();

    // Contents are written after registration in this test.
    populate_block(&original, 42)?;

    manager.offload(&original, 0).await?;

    // Fixed pause before querying the host pool; presumably lets the offload complete.
    let settle = std::time::Duration::from_millis(100);
    tokio::time::sleep(settle).await;

    let hash = original.sequence_hash()?;
    let host_matches = host.match_sequence_hashes(&[hash]).await?;
    assert_eq!(host_matches.len(), 1);

    let onboarded = manager.onboard(vec![host_matches[0].clone()]).await?;
    assert_eq!(onboarded.len(), 1);
    check_block_contents(&host_matches[0], &onboarded[0], 42)?;

    // The block handed back should be the one already registered on the device;
    // the freshly copied duplicate is expected to be discarded by the block pool.
    assert_eq!(onboarded[0].block_idx(), original.block_idx());

    Ok(())
}
/// Moves a single large block through offload to the host pool, offload to the disk pool,
/// and onboard, checking the contents against the fill value 42 after every hop.
#[tokio::test]
async fn test_transfer_big_blocks() -> Result<()> {
    // Size the inner dimension from a 32 MB (32 * 2^20) budget split over layers and block size.
    let inner_dim = 32_usize * 1024 * 1024 / NUM_LAYERS / BLOCK_SIZE;

    let (manager, device_pool, host_pool, disk_pool) =
        build_pools(2, Some(2), Some(2), Some(inner_dim))?;
    let device = device_pool.as_ref().unwrap();
    let host = host_pool.as_ref().unwrap();
    let disk = disk_pool.as_ref().unwrap();

    let pending = completed_block(device, [0; 4]).await?;
    populate_block(&pending, 42)?;
    let on_device = device
        .register_blocks(vec![pending])
        .await?
        .into_iter()
        .next()
        .unwrap();

    // Hop 1: offload, pause 100 ms, then look the block up in the host pool.
    manager.offload(&on_device, 0).await?;
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    let host_key = on_device.sequence_hash()?;
    let on_host = host.match_sequence_hashes(&[host_key]).await?;
    assert_eq!(on_host.len(), 1);
    check_block_contents(&on_device, &on_host[0], 42)?;

    // Hop 2: offload the host copy, pause 500 ms, then look it up in the disk pool.
    manager.offload(&on_host[0], 0).await?;
    tokio::time::sleep(std::time::Duration::from_millis(500)).await;

    let disk_key = on_device.sequence_hash()?;
    let on_disk = disk.match_sequence_hashes(&[disk_key]).await?;
    assert_eq!(on_disk.len(), 1);
    check_block_contents(&on_host[0], &on_disk[0], 42)?;

    // Hop 3: onboard the disk copy and compare it with what came back.
    let onboarded = manager.onboard(on_disk.clone()).await?;
    assert_eq!(onboarded.len(), 1);
    check_block_contents(&on_disk[0], &onboarded[0], 42)?;

    Ok(())
}
} }
...@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, ...@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source,
} }
} }
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the ordering of `OffloadRequestKey` with two comparisons:
    /// the key with priority 2 / timestamp 2 sorts before the key with priority 1 / timestamp 1,
    /// and before the key with priority 2 / timestamp 3.
    #[test]
    fn test_offload_request_key_ordering() {
        // Small constructor so each case reads as (priority, timestamp).
        let key = |priority, timestamp| OffloadRequestKey {
            priority,
            timestamp,
        };

        let low_priority = key(1, 1);
        let high_priority = key(2, 2);
        assert!(high_priority < low_priority);

        let high_priority_later = key(2, 3);
        assert!(high_priority < high_priority_later);
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment