"vscode:/vscode.git/clone" did not exist on "a0be38cb2b1050e715149244c53e8fc54bdbd9cd"
Unverified Commit 74221fd7 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Add support for SSD offloading in block manager (#1115)

parent 024422b9
......@@ -42,6 +42,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f"
[[package]]
name = "aligned-vec"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b"
dependencies = [
"equator",
]
[[package]]
name = "android-tzdata"
version = "0.1.1"
......@@ -522,7 +531,7 @@ dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
......@@ -1596,6 +1605,7 @@ name = "dynamo-llm"
version = "0.2.1"
dependencies = [
"akin",
"aligned-vec",
"anyhow",
"async-nats",
"async-openai",
......@@ -1622,10 +1632,12 @@ dependencies = [
"hf-hub",
"insta",
"itertools 0.14.0",
"lazy_static",
"memmap2",
"minijinja",
"minijinja-contrib",
"ndarray",
"nix 0.26.4",
"nixl-sys",
"oneshot",
"prometheus",
......@@ -1711,7 +1723,7 @@ dependencies = [
"local-ip-address",
"log",
"nid",
"nix",
"nix 0.29.0",
"nuid",
"once_cell",
"prometheus",
......@@ -1889,6 +1901,26 @@ dependencies = [
"log",
]
[[package]]
name = "equator"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc"
dependencies = [
"equator-macro",
]
[[package]]
name = "equator-macro"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "equivalent"
version = "1.0.2"
......@@ -3401,7 +3433,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if 1.0.0",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
......@@ -3661,6 +3693,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.1"
......@@ -4067,6 +4108,19 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.0",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]]
name = "nix"
version = "0.29.0"
......@@ -4900,7 +4954,7 @@ dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"memoffset 0.9.1",
"once_cell",
"portable-atomic",
"pyo3-build-config",
......
......@@ -39,6 +39,7 @@ MOUNT_WORKSPACE=
ENVIRONMENT_VARIABLES=
REMAINING_ARGS=
INTERACTIVE=
USE_NIXL_GDS=
get_options() {
while :; do
......@@ -142,6 +143,9 @@ get_options() {
--mount-workspace)
MOUNT_WORKSPACE=TRUE
;;
--use-nixl-gds)
USE_NIXL_GDS=TRUE
;;
--dry-run)
RUN_PREFIX="echo"
echo ""
......@@ -251,6 +255,12 @@ get_options() {
RM_STRING=" --rm "
fi
if [ -n "$USE_NIXL_GDS" ]; then
VOLUME_MOUNTS+=" -v /run/udev:/run/udev:ro "
NIXL_GDS_CAPS="--cap-add=IPC_LOCK"
else
NIXL_GDS_CAPS=""
fi
REMAINING_ARGS=("$@")
}
......@@ -264,6 +274,7 @@ show_help() {
echo " [--dry-run print docker commands without running]"
echo " [--hf-cache directory to volume mount as the hf cache, default is NONE unless mounting workspace]"
echo " [--gpus gpus to enable, default is 'all', 'none' disables gpu support]"
echo " [--use-nixl-gds add volume mounts and capabilities needed for NVIDIA GPUDirect Storage]"
echo " [-v add volume mount]"
echo " [-e add environment variable]"
echo " [--mount-workspace set up for local development]"
......@@ -301,6 +312,7 @@ ${RUN_PREFIX} docker run \
${VOLUME_MOUNTS} \
-w /workspace \
--cap-add CAP_SYS_PTRACE \
${NIXL_GDS_CAPS} \
--ipc host \
${PRIVILEGED_STRING} \
${NAME_STRING} \
......
......@@ -1119,6 +1119,7 @@ dependencies = [
"minijinja",
"minijinja-contrib",
"ndarray",
"nix 0.26.4",
"nixl-sys",
"oneshot",
"prometheus",
......@@ -1190,7 +1191,7 @@ dependencies = [
"local-ip-address",
"log",
"nid",
"nix",
"nix 0.29.0",
"nuid",
"once_cell",
"prometheus",
......@@ -2564,6 +2565,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.1"
......@@ -2756,6 +2766,19 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.0",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]]
name = "nix"
version = "0.29.0"
......@@ -3361,7 +3384,7 @@ dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"memoffset 0.9.1",
"once_cell",
"portable-atomic",
"pyo3-build-config",
......
......@@ -30,7 +30,7 @@ default = []
testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
sentencepiece = ["dep:sentencepiece"]
[dependencies]
......@@ -80,6 +80,7 @@ rayon = "1"
nixl-sys = { version = "0.2.1-rc.3", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
ndarray = { version = "0.16", optional = true }
nix = { version = "0.26", optional = true }
# protocols
unicode-segmentation = "1.12"
......@@ -124,3 +125,5 @@ insta = { version = "1.41", features = [
"redactions",
"filters",
] }
aligned-vec = "0.6.4"
lazy_static = "1.4"
......@@ -36,13 +36,15 @@ pub use block::{
RemoteBlock,
},
transfer::{BlockTransferEngineV1, TransferRequestPut},
BasicMetadata, BlockMetadata, Blocks,
BasicMetadata, BlockMetadata, Blocks, ImmutableBlock,
};
pub use config::*;
pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType};
use offload::request::BlockResult;
pub use pool::BlockPool;
pub use storage::{
nixl::NixlRegisterableStorage, DeviceStorage, PinnedStorage, Storage, StorageAllocator,
nixl::NixlRegisterableStorage, DeviceStorage, DiskStorage, PinnedStorage, Storage,
StorageAllocator,
};
pub use tokio_util::sync::CancellationToken;
......@@ -143,6 +145,11 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
self.state.get_remote_blocks_mutable(bds)
}
/// Get a reference to the disk block pool
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.state.disk()
}
/// Get a reference to the host block pool
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.state.host()
......@@ -157,6 +164,13 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
pub fn worker_id(&self) -> WorkerID {
self.state.worker_id()
}
pub async fn onboard_blocks<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
self.state.onboard_blocks(blocks).await
}
}
impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
......@@ -169,6 +183,8 @@ impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
mod tests {
use super::*;
use crate::block_manager::block::BlockExt;
use crate::tokens::Tokens;
use std::sync::atomic::{AtomicU64, Ordering};
// Atomic Counter for Worker ID
......@@ -180,6 +196,7 @@ mod tests {
.runtime(
KvManagerRuntimeConfig::builder()
.worker_id(worker_id)
.enable_nixl()
.build()
.unwrap(),
)
......@@ -191,6 +208,13 @@ mod tests {
.build()
.unwrap(),
)
.disk_layout(
KvManagerLayoutConfig::builder()
.num_blocks(16)
.allocator(storage::DiskAllocator)
.build()
.unwrap(),
)
.host_layout(
KvManagerLayoutConfig::builder()
.num_blocks(16)
......@@ -296,4 +320,44 @@ mod tests {
// // Execute the transfer request
// transfer_request.execute().unwrap();
}
#[tokio::test]
async fn test_offload() -> Result<()> {
dynamo_runtime::logging::init();
let block_manager = create_reference_block_manager();
let device = block_manager.device().unwrap();
let tokens = Tokens::from(vec![1, 2, 3, 4]);
let token_sequence = tokens.into_sequence(4, Some(0));
let token_block = token_sequence.blocks().first().unwrap();
let mut device_block = device.allocate_blocks(1).await?.into_iter().next().unwrap();
device_block.apply_token_block(token_block.clone())?;
let immutable_device_blocks = device.register_blocks(vec![device_block]).await.unwrap();
assert_eq!(immutable_device_blocks.len(), 1);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// It should now be on host and disk.
let host_blocks = block_manager
.host()
.unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice())
.await
.unwrap();
assert_eq!(host_blocks.len(), 1);
let disk_blocks = block_manager
.disk()
.unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice())
.await
.unwrap();
assert_eq!(disk_blocks.len(), 1);
Ok(())
}
}
......@@ -23,11 +23,12 @@ use super::*;
use crate::block_manager::storage::{
nixl::{NixlRegisterableStorage, NixlStorage},
DeviceStorage, PinnedStorage, SystemStorage,
DeviceStorage, DiskStorage, PinnedStorage, SystemStorage,
};
use cudarc::driver::CudaStream;
use std::future::Future;
use std::ops::Range;
pub use crate::block_manager::state::TransferContext;
......@@ -134,8 +135,18 @@ pub trait WriteTo<Target> {
&self,
dst: &mut Target,
notify: Option<String>,
ctx: &TransferContext,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError>;
/// A write_to implementation that expects a NIXL transfer.
/// If the transfer strategy is not NIXL, this method will return an error.
/// Returns a future that will complete when the transfer is complete.
fn nixl_write_to(
&self,
dst: &mut Target,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
}
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
......@@ -146,7 +157,7 @@ where
&self,
dst: &mut WB,
notify: Option<String>,
ctx: &TransferContext,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError> {
match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
......@@ -155,7 +166,10 @@ where
| TransferStrategy::CudaAsyncD2D => {
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy())
}
TransferStrategy::NixlWrite => Ok(nixl::write_block_to(self, dst, ctx, notify)?),
TransferStrategy::NixlWrite => {
std::mem::drop(nixl::write_block_to(self, dst, ctx, notify)?);
Ok(())
}
_ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}",
RB::write_to_strategy()
......@@ -163,6 +177,22 @@ where
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
fn nixl_write_to(
&self,
dst: &mut WB,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
Ok(nixl::write_block_to(self, dst, ctx, notify)?)
} else {
Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}",
RB::write_to_strategy()
)))?
}
}
}
#[derive(Default)]
......
......@@ -17,15 +17,17 @@ use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::future::{poll_fn, Future};
use std::ops::Range;
use std::task::Poll;
/// Copy a block from a source to a destination using CUDA memcpy
pub fn write_block_to<'a, Source, Destination>(
src: &'a Source,
dst: &'a mut Destination,
ctx: &TransferContext,
ctx: Arc<TransferContext>,
notify: Option<String>,
) -> Result<()>
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
......@@ -34,8 +36,13 @@ where
let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
// Keep the arc to use in the returned future.
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
......@@ -57,8 +64,9 @@ where
)?;
}
let xfer_req =
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?;
let xfer_req = nixl_agent
.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
.unwrap();
let mut xfer_args = OptArgs::new()?;
......@@ -67,18 +75,27 @@ where
xfer_args.set_notification_message(notify.as_bytes())?;
}
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
while status {
status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
// Return a future that completes when the transfer is complete.
// TODO: How efficient is this? Can we do better?
Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
}
});
})))
} else {
assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?;
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)
}
Ok(())
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
......@@ -86,9 +103,9 @@ pub fn write_layers_to<'a, Source, Destination>(
layer_range: Range<usize>,
src: &'a Source,
dst: &'a mut Destination,
ctx: &TransferContext,
ctx: Arc<TransferContext>,
notify: Option<String>,
) -> Result<()>
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
Source: BlockDataProvider,
Destination: BlockDataProviderMut,
......@@ -96,9 +113,13 @@ where
let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken);
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
......@@ -149,13 +170,17 @@ where
Some(&xfer_args),
)?;
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| {
while status {
status = nixl_agent.get_xfer_status(&xfer_req).unwrap();
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
}
});
Ok(())
})))
}
......@@ -18,6 +18,41 @@
use super::*;
impl WriteToStrategy<DiskStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<PinnedStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<DeviceStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<DiskStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
......@@ -39,6 +74,13 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
}
}
impl WriteToStrategy<DiskStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
......@@ -60,6 +102,13 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
}
}
impl WriteToStrategy<DiskStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
......
......@@ -160,7 +160,13 @@ mod nixl {
}
fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id()
self._block_data
.layout
.storage()
.into_iter()
.next()
.unwrap()
.device_id()
}
}
......@@ -184,7 +190,13 @@ mod nixl {
}
fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id()
self._block_data
.layout
.storage()
.into_iter()
.next()
.unwrap()
.device_id()
}
}
......
......@@ -37,6 +37,9 @@ pub struct KvManagerRuntimeConfig {
#[builder(default = "NixlOptions::Enabled")]
pub nixl: NixlOptions,
#[builder(default)]
pub async_runtime: Option<Arc<tokio::runtime::Runtime>>,
}
impl KvManagerRuntimeConfig {
......@@ -163,6 +166,10 @@ pub struct KvBlockManagerConfig {
/// This includes the number of blocks and the layout of the data into the host memory/storage.
#[builder(default, setter(strip_option))]
pub host_layout: Option<KvManagerLayoutConfig<PinnedStorage>>,
// Specific configuration for the disk layout
#[builder(default, setter(strip_option))]
pub disk_layout: Option<KvManagerLayoutConfig<DiskStorage>>,
}
impl KvBlockManagerConfig {
......
......@@ -13,200 +13,290 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, Notify};
use super::block::{
transfer::WriteTo, BlockError, BlockExt, BlockMetadata, BlockState, ImmutableBlock,
MutableBlock,
};
//! # Offload Manager
//! The offload manager is responsible for handling all block transfers between different cache levels.
//!
//! ## Offloading
//! Offloading is the process of moving blocks to a cache level further away from the device.
//! When blocks are registered (via [`BlockPool::register_blocks`]), they are automatically sent to the offload manager.
//! Due to limited bandwidth, the offload manager must prioritize which offloads to perform.
//! This is indicated by the `priority` parameter to [`OffloadManager::offload`].
//! When a offload request is received, the offload manager will enqueue it into a priority queue.
//! This priority queue is keyed by the `priority` parameter, where blocks with lower priority values are processed first.
//! Within the same priority, blocks that were sent to the offload manager earlier are processed first.
//!
//! ## Onboarding
//! Onboarding is the process of moving blocks to a cache level closer to the device.
//! All onboardings are manually triggered through the [`OffloadManager::onboard`] method.
//!
//! ## Transfer Managers
//! The offload manager uses two transfer managers to handle the offloading and onboarding of blocks.
//!
//! The [`CudaTransferManager`] is responsible for transfers between the device and host.
//! The [`DiskTransferManager`] is responsible for transfers from host to disk and disk to device.
//!
//! ## Worker Threads
//! The offload manager uses two kinds of worker threads to handle the offloading and onboarding of blocks.
//!
//! The [`OffloadManager::offload_worker`] is responsible for offloading blocks.
//! The [`OffloadManager::onboard_worker`] is responsible for onboarding blocks.
//!
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
use super::block::{BlockError, BlockMetadata, BlockState, ImmutableBlock};
use super::pool::BlockPoolError;
use super::state::TransferContext;
use super::storage::{Cuda, Storage};
use super::{BlockPool, DeviceStorage, PinnedStorage};
use super::{BlockPool, DeviceStorage, DiskStorage, PinnedStorage};
use nixl_sys::Agent as NixlAgent;
use std::sync::Arc;
use tokio::runtime::Handle;
use tokio::sync::{
mpsc::{self, error::TryRecvError},
Mutex,
};
use anyhow::Result;
use cudarc::driver::sys::CUevent_flags;
use std::any::Any;
use std::collections::BTreeSet;
mod pending;
mod request;
pub mod request;
use pending::{PendingTransfer, TransferManager};
use request::{OffloadRequest, OffloadRequestKey, OnboardRequest};
use pending::{CudaTransferManager, DiskTransferManager, PendingTransfer, TransferManager};
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
// TODO: This should be dynamic
const MAX_OFFLOAD_STREAM_DEPTH: usize = 4;
/// The offload manager handles all block transfers between different cache levels.
pub struct OffloadManager<Metadata: BlockMetadata> {
// Handles to the device and host pools.
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
// Handles to the device, host, and disk pools.
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
/// Priority queue of pending offloads
dtoh_offload_queue: Arc<Mutex<BTreeSet<OffloadRequest<DeviceStorage, Metadata>>>>,
/// Used to notify the offload worker that an item has been added to the priority queue
dtoh_offload_notify: Arc<Notify>,
/// An incrementing counter for offloaded blocks. Within the same priority, blocks with lower tick values are processed first.
tick: Arc<Mutex<u64>>,
/// Queue of offloading requests.
device_offload_tx: mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Metadata>>,
host_offload_tx: mpsc::UnboundedSender<OffloadRequest<PinnedStorage, Metadata>>,
/// Queue of pending onboarding requests.
htod_onboard_tx: mpsc::UnboundedSender<OnboardRequest<PinnedStorage, DeviceStorage, Metadata>>,
host_onboard_tx: mpsc::UnboundedSender<OnboardRequest<PinnedStorage, DeviceStorage, Metadata>>,
disk_onboard_tx: mpsc::UnboundedSender<OnboardRequest<DiskStorage, DeviceStorage, Metadata>>,
/// An incrementing counter for offloaded blocks. Within the same priority, blocks with lower tick values are processed first.
tick: Arc<Mutex<u64>>,
}
impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
pub fn new(
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
disk: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
nixl_agent: Arc<Option<NixlAgent>>,
async_rt_handle: Handle,
) -> Result<Arc<Self>> {
let dtoh_offload_queue = Arc::new(Mutex::new(BTreeSet::new()));
let dtoh_offload_notify = Arc::new(Notify::new());
let (htod_onboard_tx, htod_onboard_rx) = mpsc::unbounded_channel();
let (device_offload_tx, device_offload_rx) = mpsc::unbounded_channel();
let (host_offload_tx, host_offload_rx) = mpsc::unbounded_channel();
let (host_onboard_tx, host_onboard_rx) = mpsc::unbounded_channel();
let (disk_onboard_tx, disk_onboard_rx) = mpsc::unbounded_channel();
let this = Arc::new(Self {
device,
disk,
host,
dtoh_offload_queue,
dtoh_offload_notify,
device,
device_offload_tx,
host_offload_tx,
host_onboard_tx,
disk_onboard_tx,
tick: Arc::new(Mutex::new(0)),
htod_onboard_tx,
});
let this_clone = this.clone();
// The offload and onboard workers must run in separate streams.
// Otherwise, we'd only be doing either an offload or onboard at a time, cutting our effective transfer bandwidth in half.
tokio::spawn(async move { this_clone.offload_worker().await });
let this_clone = this.clone();
tokio::spawn(async move { this_clone.onboard_worker(htod_onboard_rx).await });
let cuda_ctx = Cuda::device_or_create(0)?;
Ok(this)
}
// We want cuda offloads to happen in parallel with host onboards, so we need to use a different stream.
let device_offload_transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(),
cuda_ctx.new_stream()?,
));
async fn update_target_metadata<Source: Storage, Target: Storage>(
source: &Arc<MutableBlock<Source, Metadata>>,
target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() {
// Bring the block back to the 'Reset' state.
target.reset();
// Transfer metadata.
target.update_metadata(source.metadata().clone());
// Copy tokens
target.apply_token_block(reg_handle.token_block().clone())?;
} else {
Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
)))?;
}
// Device -> Host offload
let device_clone = this.device.clone();
let host_clone = this.host.clone();
async_rt_handle.spawn(async move {
OffloadManager::offload_worker(
device_clone,
host_clone,
device_offload_rx,
Arc::new(CudaTransferManager::new(
device_offload_transfer_ctx,
MAX_OFFLOAD_STREAM_DEPTH,
)),
)
.await
.unwrap()
});
Ok(())
}
let transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(),
cuda_ctx.new_stream()?,
));
async fn offload_worker(&self) -> Result<()> {
// Since cuda memcpys in streams are async, this gets a bit tricky.
// We can't just consume the queue normally, otherwise the stream would become very backlogged.
// From the point when the a transfer is put into the stream until the transfer corresponding to the block is complete, we need to hold a strong reference to the block.
// If we don't do this, the block may be evicted and overwritten before the transfer is complete.
// To do this, we use a queue to track blocks currently being offloaded. Once the offload is complete (as indicated by a CudaEvent), the reference to the block is dropped.
// Host -> Disk offload
let host_clone = this.host.clone();
let disk_clone = this.disk.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::offload_worker(
host_clone,
disk_clone,
host_offload_rx,
Arc::new(DiskTransferManager::new(
transfer_ctx_clone,
MAX_OFFLOAD_STREAM_DEPTH,
)),
)
.await
.unwrap()
});
if self.device.is_none() || self.host.is_none() {
return Ok(());
}
// Host -> Device onboarding
let host_clone = this.host.clone();
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::onboard_worker(
host_clone,
device_clone,
host_onboard_rx,
Arc::new(CudaTransferManager::new(transfer_ctx_clone, 16384)),
)
.await
.unwrap()
});
let cuda_ctx = Cuda::device_or_create(0)?;
// Disk -> Device onboarding
let disk_clone = this.disk.clone();
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
OffloadManager::onboard_worker(
disk_clone,
device_clone,
disk_onboard_rx,
Arc::new(DiskTransferManager::new(transfer_ctx_clone, 16384)),
)
.await
.unwrap()
});
let transfer_ctx = TransferContext::new(None, cuda_ctx.new_stream()?);
Ok(this_clone)
}
let device = self.device.as_ref().as_ref().unwrap();
let host = self.host.as_ref().as_ref().unwrap();
async fn offload_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>,
mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> {
if source_pool_arc.is_none() || target_pool_arc.is_none() {
return Ok(());
}
let source_pool = source_pool_arc.as_ref().as_ref().unwrap();
let target_pool = target_pool_arc.as_ref().as_ref().unwrap();
// We don't want to hold too many strong references to blocks in the device pool, since it would limit our effective KV Cache capacity.
// In this case, we limit it to just enough to ensure that a transfer is always occurring.
let dtoh_pending_offload_manager = TransferManager::new(MAX_OFFLOAD_STREAM_DEPTH);
let mut queue = BTreeSet::new();
loop {
// Try to check the offload queue.
let request = self.dtoh_offload_queue.lock().await.pop_first();
loop {
match offload_rx.try_recv() {
Ok(request) => {
queue.insert(request);
}
Err(TryRecvError::Empty) => {
break;
}
Err(_) => return Ok(()),
}
}
// If there is a request, process it.
if let Some(request) = request {
if let Some(request) = queue.pop_first() {
// Try to upgrade the block to a strong reference.
let block = match request.block.upgrade() {
Some(block) => Some(block),
// If unable to upgrade, the block may have been moved to the inactive pool.
None => device
None => source_pool
.match_sequence_hashes(vec![request.sequence_hash].as_slice())
.await?
.pop()
.map(|block| block.mutable_block().clone()),
};
// If we've found the block, offload it to the host.
// If we've found the block, offload it.
if let Some(block) = block {
// If the block is already in the target, don't offload it.
if let Ok(blocks) = target_pool
.match_sequence_hashes_blocking(vec![request.sequence_hash].as_slice())
{
if !blocks.is_empty() {
continue;
}
}
// Allocate a block from the host pool.
// TODO: The most likely error here is that the host pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let host_blocks = match host.allocate_blocks(1).await {
let target_blocks = match target_pool.allocate_blocks(1).await {
Ok(blocks) => blocks,
Err(_) => {
continue;
}
};
if let Some(mut host_block) = host_blocks.into_iter().next() {
// Enqueue the offload into the stream.
block.write_to(&mut host_block, None, &transfer_ctx)?;
// Record an event after the transfer is complete. Use the BLOCKING_SYNC flag to ensure the event is recorded synchronously on the host.
let event = transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
// Update block metadata and register with host pool.
OffloadManager::update_target_metadata(&block, &mut host_block).await?;
// Record the pending offload. This may block if too many offloads are already pending.
dtoh_pending_offload_manager
.handle_pending_transfer(PendingTransfer::new(
if let Some(target_block) = target_blocks.into_iter().next() {
transfer_manager
.begin_transfer(PendingTransfer::new(
vec![block],
vec![host_block],
event,
vec![target_block],
None,
self.host.clone(),
target_pool_arc.clone(),
))
.await?;
}
}
} else {
// If the queue is empty, wait to be notified.
self.dtoh_offload_notify.notified().await;
// Await the next request.
if let Some(request) = offload_rx.recv().await {
queue.insert(request);
}
}
}
}
async fn onboard_worker(
&self,
mut htod_onboard_rx: mpsc::UnboundedReceiver<
OnboardRequest<PinnedStorage, DeviceStorage, Metadata>,
>,
async fn onboard_worker<Source: Storage, Target: Storage>(
source_pool_arc: Arc<Option<BlockPool<Source, Metadata>>>,
target_pool_arc: Arc<Option<BlockPool<Target, Metadata>>>,
mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
) -> Result<()> {
if self.device.is_none() || self.host.is_none() {
if source_pool_arc.is_none() || target_pool_arc.is_none() {
return Ok(());
}
let cuda_ctx = Cuda::device_or_create(0)?;
let transfer_ctx = TransferContext::new(None, cuda_ctx.new_stream()?);
// For the onboarding manager, we can get away with a much bigger queue, since any onboardings would get triggered by an upcoming prefill.
let htod_pending_onboard_manager = TransferManager::new(16384);
let device = self.device.as_ref().as_ref().unwrap();
let target_pool = target_pool_arc.as_ref().as_ref().unwrap();
while let Some(request) = htod_onboard_rx.recv().await {
let mut device_blocks = match device.allocate_blocks(request.blocks.len()).await {
// Loop on incoming requests
while let Some(request) = onboard_rx.recv().await {
// Try to allocate blocks on the device.
let target_blocks = match target_pool.allocate_blocks(request.blocks.len()).await {
Ok(blocks) => blocks,
Err(err) => {
request.response_tx.send(Err(err))?;
......@@ -214,30 +304,18 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
};
for (host_block, device_block) in request.blocks.iter().zip(device_blocks.iter_mut()) {
host_block.write_to(device_block, None, &transfer_ctx)?;
OffloadManager::update_target_metadata(host_block.mutable_block(), device_block)
.await?;
}
// Record an event after all transfers are complete. See use of CU_EVENT_BLOCKING_SYNC in offload_worker.
let event = transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
let sources = request
.blocks
.iter()
.map(|b| b.mutable_block().clone())
.collect();
htod_pending_onboard_manager
.handle_pending_transfer(PendingTransfer::new(
transfer_manager
.begin_transfer(PendingTransfer::new(
sources,
device_blocks,
event,
target_blocks,
Some(request.response_tx),
self.device.clone(),
target_pool_arc.clone(),
))
.await?;
}
......@@ -257,23 +335,28 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
)));
}
}
let mut tick = self.tick.lock().await;
let key = OffloadRequestKey {
priority,
timestamp: *tick,
};
// Increment a counter for each block. Within the same priority, blocks with lower counter values are processed first.
*tick += 1;
drop(tick);
// This can get called by all pools, regardless of whether or not they have a place to offload to.
// Because of this, we need to check the block type here.
let any_block = block as &dyn Any;
// For now, only consider offloads from G1 (device) to G2 (host).
// TODO: What's the performance penalty of this runtime type-checking?
if let Some(device_block) =
any_block.downcast_ref::<ImmutableBlock<DeviceStorage, Metadata>>()
{
let mut tick = self.tick.lock().await;
let key = OffloadRequestKey {
priority,
timestamp: *tick,
};
// Increment a counter for each block. Within the same priority, blocks with lower counter values are processed first.
*tick += 1;
drop(tick);
// The host pool doesn't exist, so we can't offload to it.
if self.device_offload_tx.is_closed() {
return Ok(());
}
let request = OffloadRequest {
block: Arc::downgrade(device_block.mutable_block()),
......@@ -281,17 +364,31 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
key,
};
self.dtoh_offload_queue.lock().await.insert(request);
self.dtoh_offload_notify.notify_one();
self.device_offload_tx.send(request).unwrap();
} else if let Some(host_block) =
any_block.downcast_ref::<ImmutableBlock<PinnedStorage, Metadata>>()
{
// The disk pool doesn't exist, so we can't offload to it.
if self.host_offload_tx.is_closed() {
return Ok(());
}
let request = OffloadRequest {
block: Arc::downgrade(host_block.mutable_block()),
sequence_hash: host_block.sequence_hash()?,
key,
};
self.host_offload_tx.send(request).unwrap();
}
Ok(())
}
pub async fn onboard(
pub async fn onboard<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
for block in &blocks {
match block.state() {
BlockState::Registered(_) => {}
......@@ -303,11 +400,51 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
if blocks.is_empty() {
return Ok(vec![]);
}
let (tx, rx) = oneshot::channel();
self.htod_onboard_tx
.send(OnboardRequest::new(blocks, tx))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
let any_block = blocks.first().unwrap() as &dyn Any;
// TODO: This is really ugly.
if any_block
.downcast_ref::<ImmutableBlock<PinnedStorage, Metadata>>()
.is_some()
{
let host_blocks = blocks
.iter()
.map(|b| {
(b as &dyn Any)
.downcast_ref::<ImmutableBlock<PinnedStorage, Metadata>>()
.unwrap()
.clone()
})
.collect();
self.host_onboard_tx
.send(OnboardRequest::new(host_blocks, tx))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
} else if any_block
.downcast_ref::<ImmutableBlock<DiskStorage, Metadata>>()
.is_some()
{
let disk_blocks = blocks
.iter()
.map(|b| {
(b as &dyn Any)
.downcast_ref::<ImmutableBlock<DiskStorage, Metadata>>()
.unwrap()
.clone()
})
.collect();
self.disk_onboard_tx
.send(OnboardRequest::new(disk_blocks, tx))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
}
match rx.await {
Ok(res) => res,
Err(_) => Err(BlockPoolError::ProgressEngineShutdown),
......@@ -321,27 +458,51 @@ mod tests {
use crate::block_manager::block::test_utils::get_private_token;
use crate::block_manager::{
block::{BasicMetadata, BlockDataExt, BlockDataProvider, Blocks},
layout::FullyContiguous,
block::{BasicMetadata, BlockDataExt, BlockDataProvider, BlockExt, Blocks, MutableBlock},
layout::{nixl::NixlLayout, FullyContiguous},
pool::BlockPool,
storage::{
cuda::CudaAccessible, DeviceAllocator, DeviceStorage, PinnedAllocator, PinnedStorage,
cuda::CudaAccessible, DeviceAllocator, DeviceStorage, DiskAllocator, DiskStorage,
PinnedAllocator, PinnedStorage, StorageType,
},
DType, LayoutConfig,
};
use nixl_sys::NixlDescriptor;
use nixl_sys::{MemoryRegion, NixlDescriptor};
use aligned_vec::avec;
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind, cudaMemset};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::mem::ManuallyDrop;
use std::os::unix::io::FromRawFd;
const BLOCK_SIZE: usize = 4;
type DevicePool = Arc<Option<BlockPool<DeviceStorage, BasicMetadata>>>;
type HostPool = Arc<Option<BlockPool<PinnedStorage, BasicMetadata>>>;
type DiskPool = Arc<Option<BlockPool<DiskStorage, BasicMetadata>>>;
lazy_static::lazy_static! {
static ref NIXL_AGENT: Arc<Option<NixlAgent>> = {
let agent = NixlAgent::new("offload-manager").unwrap();
let (_, ucx_params) = agent.get_plugin_params("UCX").unwrap();
let (_, gds_params) = agent.get_plugin_params("GDS").unwrap();
agent.create_backend("UCX", &ucx_params).unwrap();
agent.create_backend("GDS", &gds_params).unwrap();
Arc::new(Some(agent))
};
}
fn build_pools(
device_blocks: usize,
host_blocks: Option<usize>,
) -> Result<(Arc<OffloadManager<BasicMetadata>>, DevicePool, HostPool)> {
disk_blocks: Option<usize>,
) -> Result<(
Arc<OffloadManager<BasicMetadata>>,
DevicePool,
HostPool,
DiskPool,
)> {
let mut config = LayoutConfig {
num_blocks: device_blocks,
num_layers: 8,
......@@ -351,22 +512,47 @@ mod tests {
dtype: DType::FP16,
};
let device = FullyContiguous::allocate(config.clone(), &DeviceAllocator::default())?;
let agent_arc = NIXL_AGENT.clone();
let agent = agent_arc.as_ref().as_ref().unwrap();
let mut device = FullyContiguous::allocate(config.clone(), &DeviceAllocator::default())?;
device.nixl_register(agent, None)?;
let device_blocks = Blocks::<_, BasicMetadata>::new(device, 42, 0)?.into_blocks()?;
let device_pool = Arc::new(Some(BlockPool::builder().blocks(device_blocks).build()?));
let host_pool = if let Some(host_blocks) = host_blocks {
config.num_blocks = host_blocks;
let host = FullyContiguous::allocate(config, &PinnedAllocator::default())?;
let mut host = FullyContiguous::allocate(config.clone(), &PinnedAllocator::default())?;
host.nixl_register(agent, None)?;
let host_blocks = Blocks::<_, BasicMetadata>::new(host, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(host_blocks).build()?))
} else {
Arc::new(None)
};
let manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?;
let disk_pool = if let Some(disk_blocks) = disk_blocks {
config.num_blocks = disk_blocks;
let mut disk = FullyContiguous::allocate(config, &DiskAllocator)?;
disk.nixl_register(agent, None)?;
let disk_blocks = Blocks::<_, BasicMetadata>::new(disk, 42, 0)?.into_blocks()?;
Arc::new(Some(BlockPool::builder().blocks(disk_blocks).build()?))
} else {
Arc::new(None)
};
Ok((manager, device_pool, host_pool))
let async_rt_handle = Handle::current();
let manager = OffloadManager::new(
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
agent_arc,
async_rt_handle,
)?;
Ok((manager, device_pool, host_pool, disk_pool))
}
/// Create a block in the 'RESET' state.
......@@ -423,40 +609,61 @@ mod tests {
Ok(())
}
/// Compare the contents of a device block and a host block.
async fn compare_block_contents(
device_block: &impl BlockDataProvider<StorageType = DeviceStorage>,
host_block: &impl BlockDataProvider<StorageType = PinnedStorage>,
) -> Result<()> {
let host_data = host_block.block_data(get_private_token()).block_view()?;
let device_data = device_block.block_data(get_private_token()).block_view()?;
let size = host_data.size();
assert_eq!(size, device_data.size());
let mut host_buffer = vec![0u8; size];
let host_slice;
unsafe {
cudaMemcpy(
host_buffer.as_mut_ptr() as *mut std::ffi::c_void,
device_data.as_ptr() as *const std::ffi::c_void,
size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
.result()?;
host_slice = std::slice::from_raw_parts(host_buffer.as_ptr(), size);
fn get_block_contents<S: Storage + NixlDescriptor>(
block: &impl BlockDataProvider<StorageType = S>,
) -> Result<Vec<u8>> {
let block_data = block.block_data(get_private_token());
let block_view = block_data.block_view()?;
let size = block_view.size();
let mut contents: Vec<u8> = vec![0; size];
match block_data.storage_type() {
StorageType::Device(_) => unsafe {
cudaMemcpy(
contents.as_mut_ptr() as *mut std::ffi::c_void,
block_view.as_ptr() as *const std::ffi::c_void,
size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
)
.result()?;
},
StorageType::Pinned => unsafe {
contents = std::slice::from_raw_parts(block_view.as_ptr(), size).to_vec();
},
StorageType::Disk => {
let nixl_desc = block_view.as_nixl_descriptor();
let mut file: ManuallyDrop<File>;
let mut aligned = avec![[4096] | 0; size];
unsafe {
file = ManuallyDrop::new(File::from_raw_fd(nixl_desc.device_id() as i32));
file.seek(SeekFrom::Start(nixl_desc.as_ptr() as u64))?;
}
file.read_exact(&mut aligned)?;
contents = aligned.to_vec();
}
_ => {
panic!();
}
}
assert_eq!(host_buffer, host_slice);
Ok(contents.to_vec())
}
/// Compare the contents of a device block and a host block.
fn compare_block_contents(
block1: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
block2: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
) -> Result<()> {
assert_eq!(get_block_contents(block1)?, get_block_contents(block2)?);
Ok(())
}
#[tokio::test]
async fn test_offload_invalid_blocks() -> Result<()> {
let (offload_manager, device_pool, _) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
......@@ -489,7 +696,7 @@ mod tests {
#[tokio::test]
async fn test_offload_registered_blocks() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
......@@ -525,14 +732,14 @@ mod tests {
immutable_device_block.sequence_hash()?
);
compare_block_contents(&immutable_device_block, &host_blocks[0]).await?;
compare_block_contents(&immutable_device_block, &host_blocks[0])?;
Ok(())
}
#[tokio::test]
async fn test_no_host_blocks_available() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
......@@ -580,7 +787,7 @@ mod tests {
#[tokio::test]
async fn test_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
......@@ -613,7 +820,7 @@ mod tests {
BlockState::Registered(_)
));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;
compare_block_contents(&onboarded_blocks[0], &immutable_host_block)?;
// Wait for the new value to show up in the device pool.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
......@@ -627,14 +834,14 @@ mod tests {
);
// Check that this is the same block.
compare_block_contents(&device_blocks[0], &immutable_host_block).await?;
compare_block_contents(&device_blocks[0], &immutable_host_block)?;
Ok(())
}
#[tokio::test]
async fn test_offload_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
......@@ -662,7 +869,7 @@ mod tests {
.next()
.unwrap();
compare_block_contents(&immutable_device_block, &immutable_host_block).await?;
compare_block_contents(&immutable_device_block, &immutable_host_block)?;
// Remove the device block from the pool by dropping it and allocating more blocks.
drop(immutable_device_block);
......@@ -696,14 +903,14 @@ mod tests {
BlockState::Registered(_)
));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;
compare_block_contents(&onboarded_blocks[0], &immutable_host_block)?;
Ok(())
}
#[tokio::test]
async fn test_onboard_err_handling() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
......@@ -732,7 +939,7 @@ mod tests {
#[tokio::test]
async fn test_offload_onboard_no_host_blocks() -> Result<()> {
let (offload_manager, device_pool, _) = build_pools(4, None)?;
let (offload_manager, device_pool, _, _) = build_pools(4, None, None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
......@@ -748,4 +955,124 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_offload_disk() -> Result<()> {
let (offload_manager, _, host_pool, disk_pool) = build_pools(4, Some(4), Some(4))?;
let host_pool = host_pool.as_ref().as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool
.register_blocks(vec![host_block])
.await?
.into_iter()
.next()
.unwrap();
populate_cuda_block(&immutable_host_block, 42)?;
offload_manager.offload(&immutable_host_block, 0).await?;
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let disk_blocks = disk_pool
.match_sequence_hashes(vec![immutable_host_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(disk_blocks.len(), 1);
assert_eq!(
disk_blocks[0].sequence_hash()?,
immutable_host_block.sequence_hash()?
);
compare_block_contents(&disk_blocks[0], &immutable_host_block)?;
Ok(())
}
#[tokio::test]
async fn test_onboard_disk() -> Result<()> {
let (offload_manager, device_pool, _, disk_pool) = build_pools(4, None, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let disk_block = completed_block(disk_pool, [0, 1, 2, 3]).await?;
let immutable_disk_block = disk_pool
.register_blocks(vec![disk_block])
.await?
.into_iter()
.next()
.unwrap();
let device_block = offload_manager
.onboard(vec![immutable_disk_block.clone()])
.await?;
assert_eq!(device_block.len(), 1);
assert_eq!(
device_block[0].sequence_hash()?,
immutable_disk_block.sequence_hash()?
);
assert_eq!(
device_pool
.match_sequence_hashes(vec![immutable_disk_block.sequence_hash()?].as_slice())
.await?
.len(),
1
);
Ok(())
}
#[tokio::test]
async fn test_bulk_transfer_disk() -> Result<()> {
let (offload_manager, device_pool, host_pool, disk_pool) =
build_pools(8, Some(8), Some(8))?;
let disk_pool = disk_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_pool = device_pool.as_ref().as_ref().unwrap();
let mut host_blocks = Vec::new();
for i in 0..8 {
let block = completed_block(host_pool, [i; 4]).await?;
populate_cuda_block(&block, i as i32)?;
host_blocks.push(block);
}
let immutable_host_blocks = host_pool.register_blocks(host_blocks).await?;
for block in &immutable_host_blocks {
offload_manager.offload(block, 0).await?;
}
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let mut disk_blocks = Vec::new();
for host_block in &immutable_host_blocks {
let blocks = disk_pool
.match_sequence_hashes(vec![host_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], host_block)?;
disk_blocks.push(blocks[0].clone());
}
let device_blocks = offload_manager.onboard(disk_blocks.clone()).await?;
assert_eq!(device_blocks.len(), disk_blocks.len());
for disk_block in &disk_blocks {
let blocks = device_pool
.match_sequence_hashes(vec![disk_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(blocks.len(), 1);
compare_block_contents(&blocks[0], disk_block)?;
}
Ok(())
}
}
......@@ -13,32 +13,62 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Transfer Managers
//!
//! Transfer managers are responsible for multiple things:
//! - Before the transfer:
//! - Rate-limiting the number of transfers that can be initiated concurrently. This is implemented through bounded channels.
//! - Due to the nature of the [`super::OffloadManager`], we only apply this rate-limiting to offloads.
//! - During the transfer:
//! - Initiating the transfer
//! - Holding strong references to blocks being transfered.
//! - After the transfer:
//! - Dropping these references once the transfer is complete.
//! - Registering the blocks with the target pool.
//! - Returning the registered blocks to the caller.
//!
//! This is implemented through the [`TransferManager`] trait, which takes a single [`PendingTransfer`]
//! and initiates the transfer.
//!
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::begin_transfer`]
//! 2. [`TransferManager::begin_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use std::pin::Pin;
use std::sync::Arc;
use std::thread::spawn;
use tokio::sync::mpsc;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy},
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, WritableBlock,
};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
use crate::block_manager::state::TransferContext;
use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool;
use anyhow::Result;
use cudarc::driver::CudaEvent;
use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{future::join_all, stream::FuturesUnordered, StreamExt};
type OnboardResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
use super::BlockResult;
/// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The block being copied from.
_sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
/// The block being copied to.
targets: Vec<MutableBlock<Target, Metadata>>,
/// The Cuda event that indicates the completion of the transfer.
event: CudaEvent,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
/// The target pool that will receive the registered block.
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
......@@ -47,65 +77,221 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>,
event: CudaEvent,
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
) -> Self {
Self {
_sources: sources,
sources,
targets,
completion_indicator,
target_registration_pool,
}
}
fn handle_complete(self) -> Result<()> {
let Self {
targets,
event,
target_registration_pool,
completion_indicator,
target_pool,
..
} = self;
if let Some(target_registration_pool) = target_registration_pool.as_ref() {
let blocks = target_registration_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?;
}
}
Ok(())
}
}
fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
source: &Arc<MutableBlock<Source, Metadata>>,
target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() {
// Bring the block back to the 'Reset' state.
target.reset();
// Transfer metadata.
target.update_metadata(source.metadata().clone());
// Copy tokens
target.apply_token_block(reg_handle.token_block().clone())?;
} else {
Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
)))?;
}
Ok(())
}
#[async_trait]
pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata>:
Send + Sync
{
/// Begin a transfer. Blocks if the pending queue is full.
async fn begin_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()>;
}
pub struct TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<PendingTransfer<Source, Target, Metadata>>,
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>,
transfer_ctx: Arc<TransferContext>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
TransferManager<Source, Target, Metadata>
CudaTransferManager<Source, Target, Metadata>
{
pub fn new(max_depth: usize) -> Self {
let (tx, mut rx) = mpsc::channel::<PendingTransfer<Source, Target, Metadata>>(max_depth);
pub fn new(transfer_ctx: Arc<TransferContext>, max_depth: usize) -> Self {
let (tx, mut rx) =
mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(max_depth);
spawn(move || {
while let Some(pending_transfer) = rx.blocking_recv() {
while let Some((pending_transfer, event)) = rx.blocking_recv() {
// Wait for the event.
pending_transfer.event.synchronize()?;
let PendingTransfer {
targets,
target_pool,
..
} = pending_transfer;
if let Some(target_pool) = target_pool.as_ref() {
// Register the blocks in the new pool only AFTER the transfers have been completed.
// This way, we maintain the invariant that blocks that are registered in a pool
// are always available in that pool.
let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = pending_transfer.completion_indicator {
completion_indicator.send(Ok(blocks))?;
}
}
event.synchronize()?;
// Only finalize the transfer after the event is signaled.
pending_transfer.handle_complete()?;
}
Ok::<(), anyhow::Error>(())
});
Self {
pending_transfer_q: tx,
transfer_ctx,
}
}
}
pub async fn handle_pending_transfer(
#[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata>
for CudaTransferManager<Source, Target, Metadata>
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
self.pending_transfer_q.send(pending_transfer).await?;
for (source, target) in pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
{
transfer_metadata(source, target)?;
source.write_to(target, None, self.transfer_ctx.clone())?;
}
// Use a cuda event to record the completion of the transfers.
let event = self
.transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q
.send((pending_transfer, event))
.await?;
Ok(())
}
}
pub struct DiskTransferManager {
futures_tx: mpsc::Sender<Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync>>>,
transfer_ctx: Arc<TransferContext>,
}
impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_size: usize) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move {
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new();
loop {
tokio::select! {
Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_size {
pending_transfers.next().await;
}
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future);
}
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
}
else => {
// Both branches are pending, wait for one to become ready
tokio::task::yield_now().await;
}
}
}
});
Self {
futures_tx,
transfer_ctx,
}
}
}
#[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata> for DiskTransferManager
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
let futures = pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
.map(|(source, target)| {
transfer_metadata(source, target).unwrap();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to(target, None, self.transfer_ctx.clone())
.unwrap()
})
.collect::<Vec<_>>();
let completion_future = async move {
let _ = join_all(futures).await;
pending_transfer.handle_complete().unwrap();
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
// this call will block until the worker has processed the prior future.
self.futures_tx.send(Box::pin(completion_future)).await?;
Ok(())
}
......
......@@ -57,6 +57,11 @@ impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> {
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {}
pub type BlockResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// Data needed for onboarding.
/// Unlike offloading, we need a means to return the resulting blocks to the caller.
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> {
pub blocks: Vec<ImmutableBlock<Source, M>>,
pub response_tx:
......
......@@ -19,23 +19,23 @@ use super::offload::OffloadManager;
use super::{
block::{Block, ImmutableBlock},
config::NixlOptions,
pool::BlockPoolError,
};
use cudarc::driver::CudaStream;
use std::sync::Arc;
use tokio::runtime::Handle;
pub struct TransferContext {
nixl_agent: Option<NixlAgent>,
nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>,
}
impl TransferContext {
pub fn new(nixl_agent: Option<NixlAgent>, stream: Arc<CudaStream>) -> Self {
pub fn new(nixl_agent: Arc<Option<NixlAgent>>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream }
}
pub fn nixl_agent(&self) -> Option<&NixlAgent> {
self.nixl_agent.as_ref()
pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.clone()
}
pub fn stream(&self) -> &Arc<CudaStream> {
......@@ -48,9 +48,10 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID,
cancellation_token: CancellationToken,
nixl_agent: Option<NixlAgent>,
nixl_agent: Arc<Option<NixlAgent>>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
......@@ -77,21 +78,34 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = match config.runtime.nixl {
let nixl_agent = Arc::new(match config.runtime.nixl {
NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?;
tracing::debug!("Creating NIXL backends");
let (_ucx_mem_list1, ucx_params) = agent.get_plugin_params("UCX")?;
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
if let Ok((_, ucx_params)) = agent.get_plugin_params("UCX") {
let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend));
} else {
tracing::warn!("No UCX plugin found; will not create UCX backend");
}
if config.disk_layout.is_some() {
if let Ok((_, gds_params)) = agent.get_plugin_params("GDS") {
let backend = agent.create_backend("GDS", &gds_params)?;
nixl_backends.insert("GDS".to_string(), Arc::new(backend));
} else {
tracing::warn!("No GDS plugin found; will not create GDS backend");
}
}
Some(agent)
}
NixlOptions::EnabledWithAgent(agent) => Some(agent),
NixlOptions::Disabled => None,
};
});
// Initialize model-specific layout config. The layout_builder is incomplete at this point.
// We will clone this builder and apply the storage-specific configs to each clone in the
......@@ -108,11 +122,35 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
(Arc::new(None), None)
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Arc::new(Some(pool)), Some(blocks))
}
} else {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(Arc::new(None), None)
};
// Create the host block pool if a host layout is provided
let (host_pool, host_blocks) = if let Some(config) = config.host_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing host pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
......@@ -130,7 +168,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let (device_pool, device_blocks) = if let Some(config) = config.device_layout {
next_block_set_idx += 1;
tracing::debug!("Constructing device pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?;
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
......@@ -145,18 +184,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
};
// Finalize the local block set by adding NIXL metadata
if let Some(nixl_agent) = &nixl_agent {
if let Some(nixl_agent) = nixl_agent.as_ref() {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
}
let offload_manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?;
let offload_async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let offload_manager = OffloadManager::new(
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
nixl_agent.clone(),
offload_async_rt_handle,
)?;
let state = Arc::new(Self {
worker_id,
cancellation_token,
nixl_agent,
nixl_backends,
disk_pool,
host_pool,
device_pool,
local_block_set,
......@@ -164,6 +218,19 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
offload_manager,
});
if let Some(mut blocks) = disk_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.disk_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
......@@ -230,6 +297,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let agent = self
.nixl_agent
.as_ref()
.as_ref()
.ok_or_else(|| anyhow::anyhow!("NIXL agent not initialized"))?;
let mut remote_block_sets = self.remote_block_sets.write().unwrap();
......@@ -344,6 +412,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(blocks)
}
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.disk_pool.as_ref().as_ref()
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref().as_ref()
}
......@@ -366,10 +438,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(())
}
pub async fn onboard_blocks(
pub async fn onboard_blocks<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
self.offload_manager.onboard(blocks).await
}
}
......
......@@ -78,9 +78,11 @@
//! - [`StorageAllocator`] - Factory for creating storage instances
pub mod cuda;
pub mod disk;
pub mod nixl;
pub use cuda::*;
pub use disk::*;
use std::{
alloc::{alloc_zeroed, dealloc, Layout},
......@@ -107,6 +109,9 @@ pub enum StorageType {
/// CUDA page-locked host memory
Pinned,
/// Disk memory
Disk,
/// Remote memory accessible through NIXL
Nixl,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use nix::fcntl::{fallocate, FallocateFlags};
use std::ffi::CString;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd};
#[derive(Debug)]
pub struct DiskStorage {
file: File,
file_name: String,
size: usize,
handles: RegistrationHandles,
}
impl Local for DiskStorage {}
impl SystemAccessible for DiskStorage {}
impl DiskStorage {
pub fn new(size: usize) -> Result<Self, StorageError> {
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let template = CString::new("/tmp/dynamo-kvbm-disk-cache-XXXXXX").unwrap();
let mut template_bytes = template.into_bytes_with_nul();
let raw_fd = unsafe {
nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut i8,
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned.
nix::libc::O_RDWR | nix::libc::O_DIRECT,
)
};
let file = unsafe { File::from_raw_fd(raw_fd) };
let file_name = String::from_utf8_lossy(&template_bytes)
.trim_end_matches("\0")
.to_string();
file.set_len(size as u64).map_err(|_| {
StorageError::AllocationFailed("Failed to set temp file size".to_string())
})?;
// File::set_len() only updates the metadata of the file, it does not allocate the underlying storage.
// We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(file.as_raw_fd(), FallocateFlags::empty(), 0, size as i64).map_err(|_| {
StorageError::AllocationFailed("Failed to allocate temp file".to_string())
})?;
Ok(Self {
file,
file_name,
size,
handles: RegistrationHandles::new(),
})
}
pub fn fd(&self) -> u64 {
self.file.as_raw_fd() as u64
}
}
impl Drop for DiskStorage {
// TODO: How robust is this actually?
fn drop(&mut self) {
std::fs::remove_file(self.file_name.clone()).unwrap();
}
}
impl Storage for DiskStorage {
fn storage_type(&self) -> StorageType {
StorageType::Disk
}
fn addr(&self) -> u64 {
0
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
impl RegisterableStorage for DiskStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
#[derive(Default)]
pub struct DiskAllocator;
impl StorageAllocator<DiskStorage> for DiskAllocator {
fn allocate(&self, size: usize) -> Result<DiskStorage, StorageError> {
DiskStorage::new(size)
}
}
......@@ -82,8 +82,8 @@ use derive_getters::Getters;
use serde::{Deserialize, Serialize};
use super::{
CudaContextProivder, DeviceStorage, PinnedStorage, RegistationHandle, RegisterableStorage,
Remote, Storage, StorageError, StorageType, SystemStorage,
CudaContextProivder, DeviceStorage, DiskStorage, PinnedStorage, RegistationHandle,
RegisterableStorage, Remote, Storage, StorageError, StorageType, SystemStorage,
};
/// Marker trait for storage types that can be accessed by NIXL.
......@@ -104,17 +104,7 @@ impl StorageType {
StorageType::Device(_) => MemType::Vram,
StorageType::Nixl => MemType::Unknown,
StorageType::Null => MemType::Unknown,
}
}
/// Get the NIXL device ID for a given storage type.
pub fn nixl_device_id(&self) -> u64 {
match self {
StorageType::System => 0,
StorageType::Pinned => 0,
StorageType::Device(id) => *id as u64,
StorageType::Nixl => 0,
StorageType::Null => 0,
StorageType::Disk => MemType::File,
}
}
}
......@@ -311,3 +301,27 @@ impl NixlDescriptor for DeviceStorage {
CudaContextProivder::cuda_context(self).cu_device() as u64
}
}
impl NixlAccessible for DiskStorage {}
impl NixlRegisterableStorage for DiskStorage {}
impl MemoryRegion for DiskStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DiskStorage {
fn mem_type(&self) -> MemType {
MemType::File
}
/// Nixl treats the file descriptor as the device ID.
fn device_id(&self) -> u64 {
self.fd()
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment