Unverified Commit 6d9aac77 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: kvbm offload fixes and tests (#1191)

parent e5845b53
......@@ -38,12 +38,12 @@ use super::{
WorkerID,
};
use derive_getters::Getters;
use std::{
fmt::Debug,
ops::{Deref, DerefMut},
sync::Arc,
};
use thiserror::Error;
mod private {
......@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
self.manager = Some(manager);
}
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
/// Returns a reference to the manager attached to this block, or `None`
/// when no manager has been set on it.
pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
self.manager.as_ref()
}
......@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>;
}
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)]
pub struct BasicMetadata {
#[getter(copy)]
priority: u32,
#[getter(copy)]
returned_tick: u64,
#[getter(copy)]
acquired_tick: u64,
}
impl BasicMetadata {
pub fn update_priority(&self, priority: u32) -> Self {
BasicMetadata {
priority,
returned_tick: self.returned_tick,
acquired_tick: self.acquired_tick,
}
}
}
impl BlockMetadata for BasicMetadata {
fn on_acquired(&mut self, tick: u64) {
self.acquired_tick = tick;
......@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
Self { block }
}
pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
// Access the underlying Block's manager field directly through deref
self.manager.as_ref()
}
/// Returns the shared (`Arc`) handle to the `MutableBlock` wrapped by this
/// immutable block.
pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
&self.block
}
......@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
/// Asks the block's manager to enqueue this block for offload.
///
/// `priority` is passed through to the manager's `enqueue_offload_block`,
/// and any error from that call is propagated to the caller. If the block
/// has no manager, a warning is logged and `Ok(())` is still returned.
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
    // TODO: Is it ok to silently fail if the block is not managed?
    match self.manager() {
        Some(manager) => {
            manager.enqueue_offload_block(self, priority).await?;
        }
        None => {
            tracing::warn!("Block is not managed. Unable to enqueue offload.");
        }
    }
    Ok(())
}
......
......@@ -28,6 +28,7 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream;
use nixl_sys::XferOp::{Read, Write};
use std::future::Future;
use std::ops::Range;
......@@ -77,6 +78,21 @@ pub enum TransferError {
Other(#[from] anyhow::Error),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NixlTransfer {
Read,
Write,
}
impl NixlTransfer {
pub fn as_xfer_op(&self) -> nixl_sys::XferOp {
match self {
NixlTransfer::Read => Read,
NixlTransfer::Write => Write,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy {
Memcpy,
......@@ -85,8 +101,7 @@ pub enum TransferStrategy {
CudaAsyncD2D,
CudaBlockingH2D,
CudaBlockingD2H,
NixlWrite, // aka PUT
NixlRead, // aka GET
Nixl(NixlTransfer),
Invalid,
}
......@@ -126,7 +141,7 @@ where
{
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
......@@ -179,8 +194,14 @@ where
}
Ok(())
}
TransferStrategy::NixlWrite => {
std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?);
TransferStrategy::Nixl(transfer_type) => {
std::mem::drop(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?);
Ok(())
}
_ => Err(TransferError::IncompatibleTypes(format!(
......@@ -196,8 +217,14 @@ where
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
Ok(nixl::write_blocks_to(self, dst, ctx, notify)?)
if let TransferStrategy::Nixl(transfer_type) = RB::write_to_strategy() {
Ok(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?)
} else {
Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}",
......@@ -626,7 +653,7 @@ mod tests {
assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Pinned to ...
......@@ -644,7 +671,7 @@ mod tests {
);
assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Device to ...
......@@ -662,7 +689,7 @@ mod tests {
);
assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Nixl to ... should fail to compile
......
......@@ -16,7 +16,7 @@
use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList};
use std::future::{poll_fn, Future};
use std::task::Poll;
......@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>(
dst: &mut [Destination],
ctx: Arc<TransferContext>,
notify: Option<String>,
transfer_type: NixlTransfer,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
Source: BlockDataProvider,
......@@ -127,8 +128,13 @@ where
debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?;
let xfer_req = nixl_agent.create_xfer_req(
transfer_type.as_xfer_op(),
&src_dl,
&dst_dl,
&nixl_agent.name(),
None,
)?;
let mut xfer_args = OptArgs::new()?;
......
......@@ -21,35 +21,35 @@ use super::*;
impl WriteToStrategy<DiskStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
}
}
impl WriteToStrategy<SystemStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
impl WriteToStrategy<PinnedStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
impl WriteToStrategy<DeviceStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
impl WriteToStrategy<DiskStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
}
}
......@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
impl WriteToStrategy<DiskStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
}
}
......@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
impl WriteToStrategy<DiskStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
......@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage {
impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
}
}
......@@ -170,7 +170,7 @@ where
impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S {
#[inline(always)]
fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead
TransferStrategy::Nixl(NixlTransfer::Read)
}
}
......@@ -198,7 +198,7 @@ mod tests {
assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Pinned to ...
......@@ -216,7 +216,7 @@ mod tests {
);
assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Device to ...
......@@ -234,7 +234,7 @@ mod tests {
);
assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite
TransferStrategy::Nixl(NixlTransfer::Write)
);
// Nixl to ... should fail to compile
......@@ -276,7 +276,7 @@ mod tests {
assert_eq!(
<SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
TransferStrategy::Nixl(NixlTransfer::Read)
);
// Pinned to ...
......@@ -297,7 +297,7 @@ mod tests {
assert_eq!(
<PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
TransferStrategy::Nixl(NixlTransfer::Read)
);
// Device to ...
......@@ -318,7 +318,7 @@ mod tests {
assert_eq!(
<DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead
TransferStrategy::Nixl(NixlTransfer::Read)
);
// Nixl to ... should fail to compile
......
This diff is collapsed.
......@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the ordering of `OffloadRequestKey`.
    ///
    /// The assertions require that a key with a numerically larger
    /// `priority` compares as *less than* one with a smaller `priority`,
    /// and that for equal `priority` the key with the smaller `timestamp`
    /// compares as less.
    #[test]
    fn test_offload_request_key_ordering() {
        let low_priority = OffloadRequestKey {
            priority: 1,
            timestamp: 1,
        };
        let high_priority = OffloadRequestKey {
            priority: 2,
            timestamp: 2,
        };

        // Larger priority value sorts first.
        assert!(high_priority < low_priority);

        let high_priority_later = OffloadRequestKey {
            priority: 2,
            timestamp: 3,
        };

        // Same priority: the earlier timestamp sorts first.
        assert!(high_priority < high_priority_later);
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment