"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "0c0336e6ea3b7d8cd87b382e33a025007e90d686"
Unverified Commit 6d9aac77 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: kvbm offload fixes and tests (#1191)

parent e5845b53
...@@ -38,12 +38,12 @@ use super::{ ...@@ -38,12 +38,12 @@ use super::{
WorkerID, WorkerID,
}; };
use derive_getters::Getters;
use std::{ use std::{
fmt::Debug, fmt::Debug,
ops::{Deref, DerefMut}, ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
}; };
use thiserror::Error; use thiserror::Error;
mod private { mod private {
...@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> { ...@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
self.manager = Some(manager); self.manager = Some(manager);
} }
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> { pub(crate) fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
self.manager.as_ref() self.manager.as_ref()
} }
...@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider { ...@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider {
fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>; fn block_data_mut(&mut self, _: private::PrivateToken) -> &mut BlockData<Self::StorageType>;
} }
#[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd)] #[derive(Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Getters)]
pub struct BasicMetadata { pub struct BasicMetadata {
#[getter(copy)]
priority: u32, priority: u32,
#[getter(copy)]
returned_tick: u64, returned_tick: u64,
#[getter(copy)]
acquired_tick: u64, acquired_tick: u64,
} }
impl BasicMetadata {
pub fn update_priority(&self, priority: u32) -> Self {
BasicMetadata {
priority,
returned_tick: self.returned_tick,
acquired_tick: self.acquired_tick,
}
}
}
impl BlockMetadata for BasicMetadata { impl BlockMetadata for BasicMetadata {
fn on_acquired(&mut self, tick: u64) { fn on_acquired(&mut self, tick: u64) {
self.acquired_tick = tick; self.acquired_tick = tick;
...@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { ...@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
Self { block } Self { block }
} }
pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
// Access the underlying Block's manager field directly through deref
self.manager.as_ref()
}
pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> { pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
&self.block &self.block
} }
...@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> ...@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> { pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
// TODO: Is it ok to silently fail if the block is not managed?
if let Some(manager) = self.manager() { if let Some(manager) = self.manager() {
manager.enqueue_offload_block(self, priority).await?; manager.enqueue_offload_block(self, priority).await?;
} else {
tracing::warn!("Block is not managed. Unable to enqueue offload.");
} }
Ok(()) Ok(())
} }
......
...@@ -28,6 +28,7 @@ use crate::block_manager::storage::{ ...@@ -28,6 +28,7 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use nixl_sys::XferOp::{Read, Write};
use std::future::Future; use std::future::Future;
use std::ops::Range; use std::ops::Range;
...@@ -77,6 +78,21 @@ pub enum TransferError { ...@@ -77,6 +78,21 @@ pub enum TransferError {
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NixlTransfer {
Read,
Write,
}
impl NixlTransfer {
pub fn as_xfer_op(&self) -> nixl_sys::XferOp {
match self {
NixlTransfer::Read => Read,
NixlTransfer::Write => Write,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferStrategy { pub enum TransferStrategy {
Memcpy, Memcpy,
...@@ -85,8 +101,7 @@ pub enum TransferStrategy { ...@@ -85,8 +101,7 @@ pub enum TransferStrategy {
CudaAsyncD2D, CudaAsyncD2D,
CudaBlockingH2D, CudaBlockingH2D,
CudaBlockingD2H, CudaBlockingD2H,
NixlWrite, // aka PUT Nixl(NixlTransfer),
NixlRead, // aka GET
Invalid, Invalid,
} }
...@@ -126,7 +141,7 @@ where ...@@ -126,7 +141,7 @@ where
{ {
#[inline(always)] #[inline(always)]
fn read_from_strategy() -> TransferStrategy { fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -179,8 +194,14 @@ where ...@@ -179,8 +194,14 @@ where
} }
Ok(()) Ok(())
} }
TransferStrategy::NixlWrite => { TransferStrategy::Nixl(transfer_type) => {
std::mem::drop(nixl::write_blocks_to(self, dst, ctx, notify)?); std::mem::drop(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?);
Ok(()) Ok(())
} }
_ => Err(TransferError::IncompatibleTypes(format!( _ => Err(TransferError::IncompatibleTypes(format!(
...@@ -196,8 +217,14 @@ where ...@@ -196,8 +217,14 @@ where
notify: Option<String>, notify: Option<String>,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> { ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() { if let TransferStrategy::Nixl(transfer_type) = RB::write_to_strategy() {
Ok(nixl::write_blocks_to(self, dst, ctx, notify)?) Ok(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?)
} else { } else {
Err(TransferError::IncompatibleTypes(format!( Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}", "Expected NIXL transfer strategy, got: {:?}",
...@@ -626,7 +653,7 @@ mod tests { ...@@ -626,7 +653,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Pinned to ... // Pinned to ...
...@@ -644,7 +671,7 @@ mod tests { ...@@ -644,7 +671,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Device to ... // Device to ...
...@@ -662,7 +689,7 @@ mod tests { ...@@ -662,7 +689,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp}; use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList};
use std::future::{poll_fn, Future}; use std::future::{poll_fn, Future};
use std::task::Poll; use std::task::Poll;
...@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>( ...@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>(
dst: &mut [Destination], dst: &mut [Destination],
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
notify: Option<String>, notify: Option<String>,
transfer_type: NixlTransfer,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
Source: BlockDataProvider, Source: BlockDataProvider,
...@@ -127,8 +128,13 @@ where ...@@ -127,8 +128,13 @@ where
debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?); debug_assert!(!src_dl.has_overlaps()? && !dst_dl.has_overlaps()?);
let xfer_req = let xfer_req = nixl_agent.create_xfer_req(
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)?; transfer_type.as_xfer_op(),
&src_dl,
&dst_dl,
&nixl_agent.name(),
None,
)?;
let mut xfer_args = OptArgs::new()?; let mut xfer_args = OptArgs::new()?;
......
...@@ -21,35 +21,35 @@ use super::*; ...@@ -21,35 +21,35 @@ use super::*;
impl WriteToStrategy<DiskStorage> for DiskStorage { impl WriteToStrategy<DiskStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
impl WriteToStrategy<SystemStorage> for DiskStorage { impl WriteToStrategy<SystemStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<PinnedStorage> for DiskStorage { impl WriteToStrategy<PinnedStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<DeviceStorage> for DiskStorage { impl WriteToStrategy<DeviceStorage> for DiskStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
impl WriteToStrategy<DiskStorage> for SystemStorage { impl WriteToStrategy<DiskStorage> for SystemStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage { ...@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
impl WriteToStrategy<DiskStorage> for PinnedStorage { impl WriteToStrategy<DiskStorage> for PinnedStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage { ...@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
impl WriteToStrategy<DiskStorage> for DeviceStorage { impl WriteToStrategy<DiskStorage> for DeviceStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage { ...@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage {
impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S { impl<S: Storage + Local> WriteToStrategy<NixlStorage> for S {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
} }
} }
...@@ -170,7 +170,7 @@ where ...@@ -170,7 +170,7 @@ where
impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S { impl<S: Storage + Local> ReadFromStrategy<NixlStorage> for S {
#[inline(always)] #[inline(always)]
fn read_from_strategy() -> TransferStrategy { fn read_from_strategy() -> TransferStrategy {
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
} }
} }
...@@ -198,7 +198,7 @@ mod tests { ...@@ -198,7 +198,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <SystemStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Pinned to ... // Pinned to ...
...@@ -216,7 +216,7 @@ mod tests { ...@@ -216,7 +216,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <PinnedStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Device to ... // Device to ...
...@@ -234,7 +234,7 @@ mod tests { ...@@ -234,7 +234,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
<DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(), <DeviceStorage as WriteToStrategy<NixlStorage>>::write_to_strategy(),
TransferStrategy::NixlWrite TransferStrategy::Nixl(NixlTransfer::Write)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
...@@ -276,7 +276,7 @@ mod tests { ...@@ -276,7 +276,7 @@ mod tests {
assert_eq!( assert_eq!(
<SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <SystemStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Pinned to ... // Pinned to ...
...@@ -297,7 +297,7 @@ mod tests { ...@@ -297,7 +297,7 @@ mod tests {
assert_eq!( assert_eq!(
<PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <PinnedStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Device to ... // Device to ...
...@@ -318,7 +318,7 @@ mod tests { ...@@ -318,7 +318,7 @@ mod tests {
assert_eq!( assert_eq!(
<DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(), <DeviceStorage as ReadFromStrategy<NixlStorage>>::read_from_strategy(),
TransferStrategy::NixlRead TransferStrategy::Nixl(NixlTransfer::Read)
); );
// Nixl to ... should fail to compile // Nixl to ... should fail to compile
......
This diff is collapsed.
...@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, ...@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source,
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_offload_request_key_ordering() {
let key1 = OffloadRequestKey {
priority: 1,
timestamp: 1,
};
let key2 = OffloadRequestKey {
priority: 2,
timestamp: 2,
};
assert!(key2 < key1);
let key3 = OffloadRequestKey {
priority: 2,
timestamp: 3,
};
assert!(key2 < key3);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment