"vscode:/vscode.git/clone" did not exist on "57648c196130c7281de48412c5b8f71f7005ef1a"
Unverified Commit 74221fd7 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Add support for SSD offloading in block manager (#1115)

parent 024422b9
...@@ -42,6 +42,15 @@ version = "0.4.0" ...@@ -42,6 +42,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f" checksum = "1763692fc1416554cf051efc56a3de5595eca47299d731cc5c2b583adf8b4d2f"
[[package]]
name = "aligned-vec"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc890384c8602f339876ded803c97ad529f3842aba97f6392b3dba0dd171769b"
dependencies = [
"equator",
]
[[package]] [[package]]
name = "android-tzdata" name = "android-tzdata"
version = "0.1.1" version = "0.1.1"
...@@ -522,7 +531,7 @@ dependencies = [ ...@@ -522,7 +531,7 @@ dependencies = [
"bitflags 2.9.0", "bitflags 2.9.0",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"itertools 0.10.5", "itertools 0.13.0",
"log", "log",
"prettyplease", "prettyplease",
"proc-macro2", "proc-macro2",
...@@ -1596,6 +1605,7 @@ name = "dynamo-llm" ...@@ -1596,6 +1605,7 @@ name = "dynamo-llm"
version = "0.2.1" version = "0.2.1"
dependencies = [ dependencies = [
"akin", "akin",
"aligned-vec",
"anyhow", "anyhow",
"async-nats", "async-nats",
"async-openai", "async-openai",
...@@ -1622,10 +1632,12 @@ dependencies = [ ...@@ -1622,10 +1632,12 @@ dependencies = [
"hf-hub", "hf-hub",
"insta", "insta",
"itertools 0.14.0", "itertools 0.14.0",
"lazy_static",
"memmap2", "memmap2",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"ndarray", "ndarray",
"nix 0.26.4",
"nixl-sys", "nixl-sys",
"oneshot", "oneshot",
"prometheus", "prometheus",
...@@ -1711,7 +1723,7 @@ dependencies = [ ...@@ -1711,7 +1723,7 @@ dependencies = [
"local-ip-address", "local-ip-address",
"log", "log",
"nid", "nid",
"nix", "nix 0.29.0",
"nuid", "nuid",
"once_cell", "once_cell",
"prometheus", "prometheus",
...@@ -1889,6 +1901,26 @@ dependencies = [ ...@@ -1889,6 +1901,26 @@ dependencies = [
"log", "log",
] ]
[[package]]
name = "equator"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4711b213838dfee0117e3be6ac926007d7f433d7bbe33595975d4190cb07e6fc"
dependencies = [
"equator-macro",
]
[[package]]
name = "equator-macro"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.2" version = "1.0.2"
...@@ -3401,7 +3433,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3401,7 +3433,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"windows-targets 0.48.5", "windows-targets 0.52.6",
] ]
[[package]] [[package]]
...@@ -3661,6 +3693,15 @@ version = "0.3.3" ...@@ -3661,6 +3693,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.9.1" version = "0.9.1"
...@@ -4067,6 +4108,19 @@ dependencies = [ ...@@ -4067,6 +4108,19 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.0",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.29.0" version = "0.29.0"
...@@ -4900,7 +4954,7 @@ dependencies = [ ...@@ -4900,7 +4954,7 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"indoc", "indoc",
"libc", "libc",
"memoffset", "memoffset 0.9.1",
"once_cell", "once_cell",
"portable-atomic", "portable-atomic",
"pyo3-build-config", "pyo3-build-config",
......
...@@ -39,6 +39,7 @@ MOUNT_WORKSPACE= ...@@ -39,6 +39,7 @@ MOUNT_WORKSPACE=
ENVIRONMENT_VARIABLES= ENVIRONMENT_VARIABLES=
REMAINING_ARGS= REMAINING_ARGS=
INTERACTIVE= INTERACTIVE=
USE_NIXL_GDS=
get_options() { get_options() {
while :; do while :; do
...@@ -142,6 +143,9 @@ get_options() { ...@@ -142,6 +143,9 @@ get_options() {
--mount-workspace) --mount-workspace)
MOUNT_WORKSPACE=TRUE MOUNT_WORKSPACE=TRUE
;; ;;
--use-nixl-gds)
USE_NIXL_GDS=TRUE
;;
--dry-run) --dry-run)
RUN_PREFIX="echo" RUN_PREFIX="echo"
echo "" echo ""
...@@ -251,6 +255,12 @@ get_options() { ...@@ -251,6 +255,12 @@ get_options() {
RM_STRING=" --rm " RM_STRING=" --rm "
fi fi
if [ -n "$USE_NIXL_GDS" ]; then
VOLUME_MOUNTS+=" -v /run/udev:/run/udev:ro "
NIXL_GDS_CAPS="--cap-add=IPC_LOCK"
else
NIXL_GDS_CAPS=""
fi
REMAINING_ARGS=("$@") REMAINING_ARGS=("$@")
} }
...@@ -264,6 +274,7 @@ show_help() { ...@@ -264,6 +274,7 @@ show_help() {
echo " [--dry-run print docker commands without running]" echo " [--dry-run print docker commands without running]"
echo " [--hf-cache directory to volume mount as the hf cache, default is NONE unless mounting workspace]" echo " [--hf-cache directory to volume mount as the hf cache, default is NONE unless mounting workspace]"
echo " [--gpus gpus to enable, default is 'all', 'none' disables gpu support]" echo " [--gpus gpus to enable, default is 'all', 'none' disables gpu support]"
echo " [--use-nixl-gds add volume mounts and capabilities needed for NVIDIA GPUDirect Storage]"
echo " [-v add volume mount]" echo " [-v add volume mount]"
echo " [-e add environment variable]" echo " [-e add environment variable]"
echo " [--mount-workspace set up for local development]" echo " [--mount-workspace set up for local development]"
...@@ -301,6 +312,7 @@ ${RUN_PREFIX} docker run \ ...@@ -301,6 +312,7 @@ ${RUN_PREFIX} docker run \
${VOLUME_MOUNTS} \ ${VOLUME_MOUNTS} \
-w /workspace \ -w /workspace \
--cap-add CAP_SYS_PTRACE \ --cap-add CAP_SYS_PTRACE \
${NIXL_GDS_CAPS} \
--ipc host \ --ipc host \
${PRIVILEGED_STRING} \ ${PRIVILEGED_STRING} \
${NAME_STRING} \ ${NAME_STRING} \
......
...@@ -1119,6 +1119,7 @@ dependencies = [ ...@@ -1119,6 +1119,7 @@ dependencies = [
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"ndarray", "ndarray",
"nix 0.26.4",
"nixl-sys", "nixl-sys",
"oneshot", "oneshot",
"prometheus", "prometheus",
...@@ -1190,7 +1191,7 @@ dependencies = [ ...@@ -1190,7 +1191,7 @@ dependencies = [
"local-ip-address", "local-ip-address",
"log", "log",
"nid", "nid",
"nix", "nix 0.29.0",
"nuid", "nuid",
"once_cell", "once_cell",
"prometheus", "prometheus",
...@@ -2564,6 +2565,15 @@ version = "0.3.3" ...@@ -2564,6 +2565,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "memoffset" name = "memoffset"
version = "0.9.1" version = "0.9.1"
...@@ -2756,6 +2766,19 @@ dependencies = [ ...@@ -2756,6 +2766,19 @@ dependencies = [
"thiserror 1.0.69", "thiserror 1.0.69",
] ]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if 1.0.0",
"libc",
"memoffset 0.7.1",
"pin-utils",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.29.0" version = "0.29.0"
...@@ -3361,7 +3384,7 @@ dependencies = [ ...@@ -3361,7 +3384,7 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
"indoc", "indoc",
"libc", "libc",
"memoffset", "memoffset 0.9.1",
"once_cell", "once_cell",
"portable-atomic", "portable-atomic",
"pyo3-build-config", "pyo3-build-config",
......
...@@ -30,7 +30,7 @@ default = [] ...@@ -30,7 +30,7 @@ default = []
testing-full = ["testing-cuda", "testing-nixl"] testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc"] testing-cuda = ["dep:cudarc"]
testing-nixl = ["dep:nixl-sys"] testing-nixl = ["dep:nixl-sys"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray"] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:ndarray", "dep:nix"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
[dependencies] [dependencies]
...@@ -80,6 +80,7 @@ rayon = "1" ...@@ -80,6 +80,7 @@ rayon = "1"
nixl-sys = { version = "0.2.1-rc.3", optional = true } nixl-sys = { version = "0.2.1-rc.3", optional = true }
cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true } cudarc = { version = "0.16.2", features = ["cuda-12020"], optional = true }
ndarray = { version = "0.16", optional = true } ndarray = { version = "0.16", optional = true }
nix = { version = "0.26", optional = true }
# protocols # protocols
unicode-segmentation = "1.12" unicode-segmentation = "1.12"
...@@ -124,3 +125,5 @@ insta = { version = "1.41", features = [ ...@@ -124,3 +125,5 @@ insta = { version = "1.41", features = [
"redactions", "redactions",
"filters", "filters",
] } ] }
aligned-vec = "0.6.4"
lazy_static = "1.4"
...@@ -36,13 +36,15 @@ pub use block::{ ...@@ -36,13 +36,15 @@ pub use block::{
RemoteBlock, RemoteBlock,
}, },
transfer::{BlockTransferEngineV1, TransferRequestPut}, transfer::{BlockTransferEngineV1, TransferRequestPut},
BasicMetadata, BlockMetadata, Blocks, BasicMetadata, BlockMetadata, Blocks, ImmutableBlock,
}; };
pub use config::*; pub use config::*;
pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType}; pub use layout::{nixl::NixlLayout, LayoutConfig, LayoutConfigBuilder, LayoutError, LayoutType};
use offload::request::BlockResult;
pub use pool::BlockPool; pub use pool::BlockPool;
pub use storage::{ pub use storage::{
nixl::NixlRegisterableStorage, DeviceStorage, PinnedStorage, Storage, StorageAllocator, nixl::NixlRegisterableStorage, DeviceStorage, DiskStorage, PinnedStorage, Storage,
StorageAllocator,
}; };
pub use tokio_util::sync::CancellationToken; pub use tokio_util::sync::CancellationToken;
...@@ -143,6 +145,11 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> { ...@@ -143,6 +145,11 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
self.state.get_remote_blocks_mutable(bds) self.state.get_remote_blocks_mutable(bds)
} }
/// Get a reference to the disk block pool
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.state.disk()
}
/// Get a reference to the host block pool /// Get a reference to the host block pool
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> { pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.state.host() self.state.host()
...@@ -157,6 +164,13 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> { ...@@ -157,6 +164,13 @@ impl<Metadata: BlockMetadata> KvBlockManager<Metadata> {
pub fn worker_id(&self) -> WorkerID { pub fn worker_id(&self) -> WorkerID {
self.state.worker_id() self.state.worker_id()
} }
pub async fn onboard_blocks<S: Storage>(
&self,
blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> BlockResult<DeviceStorage, Metadata> {
self.state.onboard_blocks(blocks).await
}
} }
impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> { impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
...@@ -169,6 +183,8 @@ impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> { ...@@ -169,6 +183,8 @@ impl<Metadata: BlockMetadata> Drop for KvBlockManager<Metadata> {
mod tests { mod tests {
use super::*; use super::*;
use crate::block_manager::block::BlockExt;
use crate::tokens::Tokens;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
// Atomic Counter for Worker ID // Atomic Counter for Worker ID
...@@ -180,6 +196,7 @@ mod tests { ...@@ -180,6 +196,7 @@ mod tests {
.runtime( .runtime(
KvManagerRuntimeConfig::builder() KvManagerRuntimeConfig::builder()
.worker_id(worker_id) .worker_id(worker_id)
.enable_nixl()
.build() .build()
.unwrap(), .unwrap(),
) )
...@@ -191,6 +208,13 @@ mod tests { ...@@ -191,6 +208,13 @@ mod tests {
.build() .build()
.unwrap(), .unwrap(),
) )
.disk_layout(
KvManagerLayoutConfig::builder()
.num_blocks(16)
.allocator(storage::DiskAllocator)
.build()
.unwrap(),
)
.host_layout( .host_layout(
KvManagerLayoutConfig::builder() KvManagerLayoutConfig::builder()
.num_blocks(16) .num_blocks(16)
...@@ -296,4 +320,44 @@ mod tests { ...@@ -296,4 +320,44 @@ mod tests {
// // Execute the transfer request // // Execute the transfer request
// transfer_request.execute().unwrap(); // transfer_request.execute().unwrap();
} }
#[tokio::test]
async fn test_offload() -> Result<()> {
dynamo_runtime::logging::init();
let block_manager = create_reference_block_manager();
let device = block_manager.device().unwrap();
let tokens = Tokens::from(vec![1, 2, 3, 4]);
let token_sequence = tokens.into_sequence(4, Some(0));
let token_block = token_sequence.blocks().first().unwrap();
let mut device_block = device.allocate_blocks(1).await?.into_iter().next().unwrap();
device_block.apply_token_block(token_block.clone())?;
let immutable_device_blocks = device.register_blocks(vec![device_block]).await.unwrap();
assert_eq!(immutable_device_blocks.len(), 1);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// It should now be on host and disk.
let host_blocks = block_manager
.host()
.unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice())
.await
.unwrap();
assert_eq!(host_blocks.len(), 1);
let disk_blocks = block_manager
.disk()
.unwrap()
.match_sequence_hashes(vec![immutable_device_blocks[0].sequence_hash()?].as_slice())
.await
.unwrap();
assert_eq!(disk_blocks.len(), 1);
Ok(())
}
} }
...@@ -23,11 +23,12 @@ use super::*; ...@@ -23,11 +23,12 @@ use super::*;
use crate::block_manager::storage::{ use crate::block_manager::storage::{
nixl::{NixlRegisterableStorage, NixlStorage}, nixl::{NixlRegisterableStorage, NixlStorage},
DeviceStorage, PinnedStorage, SystemStorage, DeviceStorage, DiskStorage, PinnedStorage, SystemStorage,
}; };
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use std::future::Future;
use std::ops::Range; use std::ops::Range;
pub use crate::block_manager::state::TransferContext; pub use crate::block_manager::state::TransferContext;
...@@ -134,8 +135,18 @@ pub trait WriteTo<Target> { ...@@ -134,8 +135,18 @@ pub trait WriteTo<Target> {
&self, &self,
dst: &mut Target, dst: &mut Target,
notify: Option<String>, notify: Option<String>,
ctx: &TransferContext, ctx: Arc<TransferContext>,
) -> Result<(), TransferError>; ) -> Result<(), TransferError>;
/// A write_to implementation that expects a NIXL transfer.
/// If the transfer strategy is not NIXL, this method will return an error.
/// Returns a future that will complete when the transfer is complete.
fn nixl_write_to(
&self,
dst: &mut Target,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
...@@ -146,7 +157,7 @@ where ...@@ -146,7 +157,7 @@ where
&self, &self,
dst: &mut WB, dst: &mut WB,
notify: Option<String>, notify: Option<String>,
ctx: &TransferContext, ctx: Arc<TransferContext>,
) -> Result<(), TransferError> { ) -> Result<(), TransferError> {
match Self::write_to_strategy() { match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst), TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
...@@ -155,7 +166,10 @@ where ...@@ -155,7 +166,10 @@ where
| TransferStrategy::CudaAsyncD2D => { | TransferStrategy::CudaAsyncD2D => {
cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy()) cuda::copy_block(self, dst, ctx.stream().as_ref(), RB::write_to_strategy())
} }
TransferStrategy::NixlWrite => Ok(nixl::write_block_to(self, dst, ctx, notify)?), TransferStrategy::NixlWrite => {
std::mem::drop(nixl::write_block_to(self, dst, ctx, notify)?);
Ok(())
}
_ => Err(TransferError::IncompatibleTypes(format!( _ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}", "Unsupported copy strategy: {:?}",
RB::write_to_strategy() RB::write_to_strategy()
...@@ -163,6 +177,22 @@ where ...@@ -163,6 +177,22 @@ where
} }
// dispatch_copy_to(self, dst, self.transfer_context()) // dispatch_copy_to(self, dst, self.transfer_context())
} }
fn nixl_write_to(
&self,
dst: &mut WB,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::NixlWrite = RB::write_to_strategy() {
Ok(nixl::write_block_to(self, dst, ctx, notify)?)
} else {
Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}",
RB::write_to_strategy()
)))?
}
}
} }
#[derive(Default)] #[derive(Default)]
......
...@@ -17,15 +17,17 @@ use super::*; ...@@ -17,15 +17,17 @@ use super::*;
use anyhow::Result; use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp}; use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList, XferOp};
use std::future::{poll_fn, Future};
use std::ops::Range; use std::ops::Range;
use std::task::Poll;
/// Copy a block from a source to a destination using CUDA memcpy /// Copy a block from a source to a destination using CUDA memcpy
pub fn write_block_to<'a, Source, Destination>( pub fn write_block_to<'a, Source, Destination>(
src: &'a Source, src: &'a Source,
dst: &'a mut Destination, dst: &'a mut Destination,
ctx: &TransferContext, ctx: Arc<TransferContext>,
notify: Option<String>, notify: Option<String>,
) -> Result<()> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
...@@ -34,8 +36,13 @@ where ...@@ -34,8 +36,13 @@ where
let dst_data = dst.block_data_mut(private::PrivateToken); let dst_data = dst.block_data_mut(private::PrivateToken);
if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() { if src_data.is_fully_contiguous() && dst_data.is_fully_contiguous() {
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found"); // Keep the arc to use in the returned future.
let remote_worker_id = dst_data.worker_id.to_string(); let nixl_agent_arc = ctx.as_ref().nixl_agent();
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?; let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?; let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
...@@ -57,8 +64,9 @@ where ...@@ -57,8 +64,9 @@ where
)?; )?;
} }
let xfer_req = let xfer_req = nixl_agent
nixl_agent.create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &remote_worker_id, None)?; .create_xfer_req(XferOp::Write, &src_dl, &dst_dl, &nixl_agent.name(), None)
.unwrap();
let mut xfer_args = OptArgs::new()?; let mut xfer_args = OptArgs::new()?;
...@@ -67,18 +75,27 @@ where ...@@ -67,18 +75,27 @@ where
xfer_args.set_notification_message(notify.as_bytes())?; xfer_args.set_notification_message(notify.as_bytes())?;
} }
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?; let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| { // Return a future that completes when the transfer is complete.
while status { // TODO: How efficient is this? Can we do better?
status = nixl_agent.get_xfer_status(&xfer_req).unwrap(); Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
} }
}); })))
} else { } else {
assert_eq!(src_data.num_layers(), dst_data.num_layers()); assert_eq!(src_data.num_layers(), dst_data.num_layers());
write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)?; write_layers_to(0..src_data.num_layers(), src, dst, ctx, notify)
} }
Ok(())
} }
/// Copy a range of layers from a source to a destination using CUDA memcpy /// Copy a range of layers from a source to a destination using CUDA memcpy
...@@ -86,9 +103,9 @@ pub fn write_layers_to<'a, Source, Destination>( ...@@ -86,9 +103,9 @@ pub fn write_layers_to<'a, Source, Destination>(
layer_range: Range<usize>, layer_range: Range<usize>,
src: &'a Source, src: &'a Source,
dst: &'a mut Destination, dst: &'a mut Destination,
ctx: &TransferContext, ctx: Arc<TransferContext>,
notify: Option<String>, notify: Option<String>,
) -> Result<()> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
Source: BlockDataProvider, Source: BlockDataProvider,
Destination: BlockDataProviderMut, Destination: BlockDataProviderMut,
...@@ -96,9 +113,13 @@ where ...@@ -96,9 +113,13 @@ where
let src_data = src.block_data(private::PrivateToken); let src_data = src.block_data(private::PrivateToken);
let dst_data = dst.block_data_mut(private::PrivateToken); let dst_data = dst.block_data_mut(private::PrivateToken);
let nixl_agent = ctx.nixl_agent().expect("NIXL agent not found"); let nixl_agent_arc = ctx.as_ref().nixl_agent();
let remote_worker_id = dst_data.worker_id.to_string(); let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
let remote_worker_id = dst_data.worker_id.to_string();
let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?; let mut src_dl = XferDescList::new(src_data.storage_type().nixl_mem_type())?;
let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?; let mut dst_dl = XferDescList::new(dst_data.storage_type().nixl_mem_type())?;
...@@ -149,13 +170,17 @@ where ...@@ -149,13 +170,17 @@ where
Some(&xfer_args), Some(&xfer_args),
)?; )?;
let mut status = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?; let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
tracing::span!(tracing::Level::DEBUG, "Waiting for transfer to complete").in_scope(|| { Ok(Box::new(poll_fn(move |_cx| {
while status { let nixl_agent = nixl_agent_arc
status = nixl_agent.get_xfer_status(&xfer_req).unwrap(); .as_ref()
.as_ref()
.expect("NIXL agent not found");
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
} }
}); })))
Ok(())
} }
...@@ -18,6 +18,41 @@ ...@@ -18,6 +18,41 @@
use super::*; use super::*;
impl WriteToStrategy<DiskStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<PinnedStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<DeviceStorage> for DiskStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<DiskStorage> for SystemStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for SystemStorage { impl WriteToStrategy<SystemStorage> for SystemStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
...@@ -39,6 +74,13 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage { ...@@ -39,6 +74,13 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
} }
} }
impl WriteToStrategy<DiskStorage> for PinnedStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for PinnedStorage { impl WriteToStrategy<SystemStorage> for PinnedStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
...@@ -60,6 +102,13 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage { ...@@ -60,6 +102,13 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
} }
} }
impl WriteToStrategy<DiskStorage> for DeviceStorage {
#[inline(always)]
fn write_to_strategy() -> TransferStrategy {
TransferStrategy::NixlWrite
}
}
impl WriteToStrategy<SystemStorage> for DeviceStorage { impl WriteToStrategy<SystemStorage> for DeviceStorage {
#[inline(always)] #[inline(always)]
fn write_to_strategy() -> TransferStrategy { fn write_to_strategy() -> TransferStrategy {
......
...@@ -160,7 +160,13 @@ mod nixl { ...@@ -160,7 +160,13 @@ mod nixl {
} }
fn device_id(&self) -> u64 { fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id() self._block_data
.layout
.storage()
.into_iter()
.next()
.unwrap()
.device_id()
} }
} }
...@@ -184,7 +190,13 @@ mod nixl { ...@@ -184,7 +190,13 @@ mod nixl {
} }
fn device_id(&self) -> u64 { fn device_id(&self) -> u64 {
self._block_data.layout.storage_type().nixl_device_id() self._block_data
.layout
.storage()
.into_iter()
.next()
.unwrap()
.device_id()
} }
} }
......
...@@ -37,6 +37,9 @@ pub struct KvManagerRuntimeConfig { ...@@ -37,6 +37,9 @@ pub struct KvManagerRuntimeConfig {
#[builder(default = "NixlOptions::Enabled")] #[builder(default = "NixlOptions::Enabled")]
pub nixl: NixlOptions, pub nixl: NixlOptions,
#[builder(default)]
pub async_runtime: Option<Arc<tokio::runtime::Runtime>>,
} }
impl KvManagerRuntimeConfig { impl KvManagerRuntimeConfig {
...@@ -163,6 +166,10 @@ pub struct KvBlockManagerConfig { ...@@ -163,6 +166,10 @@ pub struct KvBlockManagerConfig {
/// This includes the number of blocks and the layout of the data into the host memory/storage. /// This includes the number of blocks and the layout of the data into the host memory/storage.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub host_layout: Option<KvManagerLayoutConfig<PinnedStorage>>, pub host_layout: Option<KvManagerLayoutConfig<PinnedStorage>>,
// Specific configuration for the disk layout
#[builder(default, setter(strip_option))]
pub disk_layout: Option<KvManagerLayoutConfig<DiskStorage>>,
} }
impl KvBlockManagerConfig { impl KvBlockManagerConfig {
......
This diff is collapsed.
...@@ -13,32 +13,62 @@ ...@@ -13,32 +13,62 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! # Transfer Managers
//!
//! Transfer managers are responsible for multiple things:
//! - Before the transfer:
//! - Rate-limiting the number of transfers that can be initiated concurrently. This is implemented through bounded channels.
//! - Due to the nature of the [`super::OffloadManager`], we only apply this rate-limiting to offloads.
//! - During the transfer:
//! - Initiating the transfer
//! - Holding strong references to blocks being transfered.
//! - After the transfer:
//! - Dropping these references once the transfer is complete.
//! - Registering the blocks with the target pool.
//! - Returning the registered blocks to the caller.
//!
//! This is implemented through the [`TransferManager`] trait, which takes a single [`PendingTransfer`]
//! and initiates the transfer.
//!
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::begin_transfer`]
//! 2. [`TransferManager::begin_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::thread::spawn; use std::thread::spawn;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock}; use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy},
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, WritableBlock,
};
use crate::block_manager::pool::BlockPoolError; use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage; use crate::block_manager::state::TransferContext;
use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool; use crate::block_manager::BlockPool;
use anyhow::Result; use anyhow::Result;
use cudarc::driver::CudaEvent; use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{future::join_all, stream::FuturesUnordered, StreamExt};
type OnboardResult<Target, Metadata> = use super::BlockResult;
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// Manage a set of pending transfers. /// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> { pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The block being copied from. /// The block being copied from.
_sources: Vec<Arc<MutableBlock<Source, Metadata>>>, sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
/// The block being copied to. /// The block being copied to.
targets: Vec<MutableBlock<Target, Metadata>>, targets: Vec<MutableBlock<Target, Metadata>>,
/// The Cuda event that indicates the completion of the transfer.
event: CudaEvent,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete. /// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
/// The target pool that will receive the registered block. /// The target pool that will receive the registered block.
target_pool: Arc<Option<BlockPool<Target, Metadata>>>, target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
} }
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...@@ -47,65 +77,221 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -47,65 +77,221 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub fn new( pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>, sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>, targets: Vec<MutableBlock<Target, Metadata>>,
event: CudaEvent, completion_indicator: Option<oneshot::Sender<BlockResult<Target, Metadata>>>,
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>, target_registration_pool: Arc<Option<BlockPool<Target, Metadata>>>,
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
) -> Self { ) -> Self {
Self { Self {
_sources: sources, sources,
targets,
completion_indicator,
target_registration_pool,
}
}
fn handle_complete(self) -> Result<()> {
let Self {
targets, targets,
event, target_registration_pool,
completion_indicator, completion_indicator,
target_pool, ..
} = self;
if let Some(target_registration_pool) = target_registration_pool.as_ref() {
let blocks = target_registration_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?;
}
}
Ok(())
} }
}
fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
source: &Arc<MutableBlock<Source, Metadata>>,
target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() {
// Bring the block back to the 'Reset' state.
target.reset();
// Transfer metadata.
target.update_metadata(source.metadata().clone());
// Copy tokens
target.apply_token_block(reg_handle.token_block().clone())?;
} else {
Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
)))?;
} }
Ok(())
} }
pub struct TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> { #[async_trait]
pending_transfer_q: mpsc::Sender<PendingTransfer<Source, Target, Metadata>>, pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata>:
Send + Sync
{
/// Begin a transfer. Blocks if the pending queue is full.
async fn begin_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()>;
}
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>,
transfer_ctx: Arc<TransferContext>,
} }
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
TransferManager<Source, Target, Metadata> CudaTransferManager<Source, Target, Metadata>
{ {
pub fn new(max_depth: usize) -> Self { pub fn new(transfer_ctx: Arc<TransferContext>, max_depth: usize) -> Self {
let (tx, mut rx) = mpsc::channel::<PendingTransfer<Source, Target, Metadata>>(max_depth); let (tx, mut rx) =
mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(max_depth);
spawn(move || { spawn(move || {
while let Some(pending_transfer) = rx.blocking_recv() { while let Some((pending_transfer, event)) = rx.blocking_recv() {
// Wait for the event. // Wait for the event.
pending_transfer.event.synchronize()?; event.synchronize()?;
// Only finalize the transfer after the event is signaled.
pending_transfer.handle_complete()?;
}
Ok::<(), anyhow::Error>(())
});
let PendingTransfer { Self {
targets, pending_transfer_q: tx,
target_pool, transfer_ctx,
.. }
} = pending_transfer; }
}
if let Some(target_pool) = target_pool.as_ref() { #[async_trait]
// Register the blocks in the new pool only AFTER the transfers have been completed. impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata>
// This way, we maintain the invariant that blocks that are registered in a pool for CudaTransferManager<Source, Target, Metadata>
// are always available in that pool. where
let blocks = target_pool.register_blocks_blocking(targets)?; Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
for (source, target) in pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
{
transfer_metadata(source, target)?;
source.write_to(target, None, self.transfer_ctx.clone())?;
}
if let Some(completion_indicator) = pending_transfer.completion_indicator { // Use a cuda event to record the completion of the transfers.
completion_indicator.send(Ok(blocks))?; let event = self
.transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q
.send((pending_transfer, event))
.await?;
Ok(())
}
}
pub struct DiskTransferManager {
futures_tx: mpsc::Sender<Pin<Box<dyn std::future::Future<Output = ()> + Send + Sync>>>,
transfer_ctx: Arc<TransferContext>,
}
impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_size: usize) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move {
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new();
loop {
tokio::select! {
Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_size {
pending_transfers.next().await;
}
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future);
}
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
}
else => {
// Both branches are pending, wait for one to become ready
tokio::task::yield_now().await;
} }
} }
} }
Ok::<(), anyhow::Error>(())
}); });
Self { Self {
pending_transfer_q: tx, futures_tx,
transfer_ctx,
} }
} }
}
pub async fn handle_pending_transfer( #[async_trait]
impl<Source, Target, Metadata> TransferManager<Source, Target, Metadata> for DiskTransferManager
where
Source: Storage,
Target: Storage,
Metadata: BlockMetadata,
// Check that the source block is readable, local, and writable to the target block.
MutableBlock<Source, Metadata>: ReadableBlock<StorageType = Source>
+ Local
+ WriteToStrategy<MutableBlock<Target, Metadata>>,
// Check that the target block is writable.
MutableBlock<Target, Metadata>: WritableBlock<StorageType = Target>,
{
async fn begin_transfer(
&self, &self,
pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
self.pending_transfer_q.send(pending_transfer).await?; let futures = pending_transfer
.sources
.iter()
.zip(pending_transfer.targets.iter_mut())
.map(|(source, target)| {
transfer_metadata(source, target).unwrap();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to(target, None, self.transfer_ctx.clone())
.unwrap()
})
.collect::<Vec<_>>();
let completion_future = async move {
let _ = join_all(futures).await;
pending_transfer.handle_complete().unwrap();
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
// this call will block until the worker has processed the prior future.
self.futures_tx.send(Box::pin(completion_future)).await?;
Ok(()) Ok(())
} }
......
...@@ -57,6 +57,11 @@ impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> { ...@@ -57,6 +57,11 @@ impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> {
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {} impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {}
pub type BlockResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// Data needed for onboarding.
/// Unlike offloading, we need a means to return the resulting blocks to the caller.
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> { pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> {
pub blocks: Vec<ImmutableBlock<Source, M>>, pub blocks: Vec<ImmutableBlock<Source, M>>,
pub response_tx: pub response_tx:
......
...@@ -19,23 +19,23 @@ use super::offload::OffloadManager; ...@@ -19,23 +19,23 @@ use super::offload::OffloadManager;
use super::{ use super::{
block::{Block, ImmutableBlock}, block::{Block, ImmutableBlock},
config::NixlOptions, config::NixlOptions,
pool::BlockPoolError,
}; };
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use std::sync::Arc; use std::sync::Arc;
use tokio::runtime::Handle;
pub struct TransferContext { pub struct TransferContext {
nixl_agent: Option<NixlAgent>, nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>, stream: Arc<CudaStream>,
} }
impl TransferContext { impl TransferContext {
pub fn new(nixl_agent: Option<NixlAgent>, stream: Arc<CudaStream>) -> Self { pub fn new(nixl_agent: Arc<Option<NixlAgent>>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream } Self { nixl_agent, stream }
} }
pub fn nixl_agent(&self) -> Option<&NixlAgent> { pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.as_ref() self.nixl_agent.clone()
} }
pub fn stream(&self) -> &Arc<CudaStream> { pub fn stream(&self) -> &Arc<CudaStream> {
...@@ -48,9 +48,10 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> { ...@@ -48,9 +48,10 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID, worker_id: WorkerID,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
nixl_agent: Option<NixlAgent>, nixl_agent: Arc<Option<NixlAgent>>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>, nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
disk_pool: Arc<Option<BlockPool<DiskStorage, Metadata>>>,
host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>, host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>, device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
...@@ -77,21 +78,34 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -77,21 +78,34 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
// Create a NIXL agent if NIXL is enabled and instantiate requested backends // Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets // TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = match config.runtime.nixl { let nixl_agent = Arc::new(match config.runtime.nixl {
NixlOptions::Enabled => { NixlOptions::Enabled => {
tracing::debug!("Creating NIXL agent"); tracing::debug!("Creating NIXL agent");
let agent = NixlAgent::new(&worker_id.to_string())?; let agent = NixlAgent::new(&worker_id.to_string())?;
tracing::debug!("Creating NIXL backends"); tracing::debug!("Creating NIXL backends");
let (_ucx_mem_list1, ucx_params) = agent.get_plugin_params("UCX")?;
if let Ok((_, ucx_params)) = agent.get_plugin_params("UCX") {
let backend = agent.create_backend("UCX", &ucx_params)?; let backend = agent.create_backend("UCX", &ucx_params)?;
nixl_backends.insert("UCX".to_string(), Arc::new(backend)); nixl_backends.insert("UCX".to_string(), Arc::new(backend));
} else {
tracing::warn!("No UCX plugin found; will not create UCX backend");
}
if config.disk_layout.is_some() {
if let Ok((_, gds_params)) = agent.get_plugin_params("GDS") {
let backend = agent.create_backend("GDS", &gds_params)?;
nixl_backends.insert("GDS".to_string(), Arc::new(backend));
} else {
tracing::warn!("No GDS plugin found; will not create GDS backend");
}
}
Some(agent) Some(agent)
} }
NixlOptions::EnabledWithAgent(agent) => Some(agent), NixlOptions::EnabledWithAgent(agent) => Some(agent),
NixlOptions::Disabled => None, NixlOptions::Disabled => None,
}; });
// Initialize model-specific layout config. The layout_builder is incomplete at this point. // Initialize model-specific layout config. The layout_builder is incomplete at this point.
// We will clone this builder and apply the storage-specific configs to each clone in the // We will clone this builder and apply the storage-specific configs to each clone in the
...@@ -108,11 +122,35 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -108,11 +122,35 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let mut next_block_set_idx = 0; let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id); let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
(Arc::new(None), None)
} else {
next_block_set_idx += 1;
tracing::debug!("Constructing disk pool.");
let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>(
layout,
next_block_set_idx,
cancellation_token.clone(),
worker_id,
)?;
(Arc::new(Some(pool)), Some(blocks))
}
} else {
tracing::debug!("No disk layout provided; will not allocate disk blocks.");
(Arc::new(None), None)
};
// Create the host block pool if a host layout is provided // Create the host block pool if a host layout is provided
let (host_pool, host_blocks) = if let Some(config) = config.host_layout { let (host_pool, host_blocks) = if let Some(config) = config.host_layout {
next_block_set_idx += 1; next_block_set_idx += 1;
tracing::debug!("Constructing host pool."); tracing::debug!("Constructing host pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?; let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?); local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>( let (pool, blocks) = create_block_pool::<_, Metadata>(
layout, layout,
...@@ -130,7 +168,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -130,7 +168,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let (device_pool, device_blocks) = if let Some(config) = config.device_layout { let (device_pool, device_blocks) = if let Some(config) = config.device_layout {
next_block_set_idx += 1; next_block_set_idx += 1;
tracing::debug!("Constructing device pool."); tracing::debug!("Constructing device pool.");
let layout = create_layout(layout_builder.clone(), config, nixl_agent.as_ref())?; let layout =
create_layout(layout_builder.clone(), config, nixl_agent.as_ref().as_ref())?;
local_block_set.add_block_set(next_block_set_idx, layout.serialize()?); local_block_set.add_block_set(next_block_set_idx, layout.serialize()?);
let (pool, blocks) = create_block_pool::<_, Metadata>( let (pool, blocks) = create_block_pool::<_, Metadata>(
layout, layout,
...@@ -145,18 +184,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -145,18 +184,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
}; };
// Finalize the local block set by adding NIXL metadata // Finalize the local block set by adding NIXL metadata
if let Some(nixl_agent) = &nixl_agent { if let Some(nixl_agent) = nixl_agent.as_ref() {
tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata."); tracing::debug!("Finalize NixlBlockSet: adding NIXL metadata.");
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
} }
let offload_manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?; let offload_async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let offload_manager = OffloadManager::new(
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
nixl_agent.clone(),
offload_async_rt_handle,
)?;
let state = Arc::new(Self { let state = Arc::new(Self {
worker_id, worker_id,
cancellation_token, cancellation_token,
nixl_agent, nixl_agent,
nixl_backends, nixl_backends,
disk_pool,
host_pool, host_pool,
device_pool, device_pool,
local_block_set, local_block_set,
...@@ -164,6 +218,19 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -164,6 +218,19 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
offload_manager, offload_manager,
}); });
if let Some(mut blocks) = disk_blocks {
blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone());
});
state
.disk_pool
.as_ref()
.as_ref()
.unwrap()
.add_blocks_blocking(blocks)?;
}
if let Some(mut blocks) = host_blocks { if let Some(mut blocks) = host_blocks {
blocks.iter_mut().for_each(|block| { blocks.iter_mut().for_each(|block| {
block.set_manager(state.clone()); block.set_manager(state.clone());
...@@ -230,6 +297,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -230,6 +297,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let agent = self let agent = self
.nixl_agent .nixl_agent
.as_ref() .as_ref()
.as_ref()
.ok_or_else(|| anyhow::anyhow!("NIXL agent not initialized"))?; .ok_or_else(|| anyhow::anyhow!("NIXL agent not initialized"))?;
let mut remote_block_sets = self.remote_block_sets.write().unwrap(); let mut remote_block_sets = self.remote_block_sets.write().unwrap();
...@@ -344,6 +412,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -344,6 +412,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(blocks) Ok(blocks)
} }
pub fn disk(&self) -> Option<&BlockPool<DiskStorage, Metadata>> {
self.disk_pool.as_ref().as_ref()
}
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> { pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref().as_ref() self.host_pool.as_ref().as_ref()
} }
...@@ -366,10 +438,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -366,10 +438,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
Ok(()) Ok(())
} }
pub async fn onboard_blocks( pub async fn onboard_blocks<S: Storage>(
&self, &self,
blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>, blocks: Vec<ImmutableBlock<S, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> { ) -> BlockResult<DeviceStorage, Metadata> {
self.offload_manager.onboard(blocks).await self.offload_manager.onboard(blocks).await
} }
} }
......
...@@ -78,9 +78,11 @@ ...@@ -78,9 +78,11 @@
//! - [`StorageAllocator`] - Factory for creating storage instances //! - [`StorageAllocator`] - Factory for creating storage instances
pub mod cuda; pub mod cuda;
pub mod disk;
pub mod nixl; pub mod nixl;
pub use cuda::*; pub use cuda::*;
pub use disk::*;
use std::{ use std::{
alloc::{alloc_zeroed, dealloc, Layout}, alloc::{alloc_zeroed, dealloc, Layout},
...@@ -107,6 +109,9 @@ pub enum StorageType { ...@@ -107,6 +109,9 @@ pub enum StorageType {
/// CUDA page-locked host memory /// CUDA page-locked host memory
Pinned, Pinned,
/// Disk memory
Disk,
/// Remote memory accessible through NIXL /// Remote memory accessible through NIXL
Nixl, Nixl,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use nix::fcntl::{fallocate, FallocateFlags};
use std::ffi::CString;
use std::fs::File;
use std::os::unix::io::{AsRawFd, FromRawFd};
#[derive(Debug)]
pub struct DiskStorage {
file: File,
file_name: String,
size: usize,
handles: RegistrationHandles,
}
impl Local for DiskStorage {}
impl SystemAccessible for DiskStorage {}
impl DiskStorage {
pub fn new(size: usize) -> Result<Self, StorageError> {
// We need to open our file with some special flags that aren't supported by the tempfile crate.
// Instead, we'll use the mkostemp function to create a temporary file with the correct flags.
let template = CString::new("/tmp/dynamo-kvbm-disk-cache-XXXXXX").unwrap();
let mut template_bytes = template.into_bytes_with_nul();
let raw_fd = unsafe {
nix::libc::mkostemp(
template_bytes.as_mut_ptr() as *mut i8,
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned.
nix::libc::O_RDWR | nix::libc::O_DIRECT,
)
};
let file = unsafe { File::from_raw_fd(raw_fd) };
let file_name = String::from_utf8_lossy(&template_bytes)
.trim_end_matches("\0")
.to_string();
file.set_len(size as u64).map_err(|_| {
StorageError::AllocationFailed("Failed to set temp file size".to_string())
})?;
// File::set_len() only updates the metadata of the file, it does not allocate the underlying storage.
// We need to use fallocate to actually allocate the storage and create the blocks on disk.
fallocate(file.as_raw_fd(), FallocateFlags::empty(), 0, size as i64).map_err(|_| {
StorageError::AllocationFailed("Failed to allocate temp file".to_string())
})?;
Ok(Self {
file,
file_name,
size,
handles: RegistrationHandles::new(),
})
}
pub fn fd(&self) -> u64 {
self.file.as_raw_fd() as u64
}
}
impl Drop for DiskStorage {
// TODO: How robust is this actually?
fn drop(&mut self) {
std::fs::remove_file(self.file_name.clone()).unwrap();
}
}
impl Storage for DiskStorage {
fn storage_type(&self) -> StorageType {
StorageType::Disk
}
fn addr(&self) -> u64 {
0
}
fn size(&self) -> usize {
self.size
}
unsafe fn as_ptr(&self) -> *const u8 {
std::ptr::null()
}
unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
std::ptr::null_mut()
}
}
impl RegisterableStorage for DiskStorage {
fn register(
&mut self,
key: &str,
handle: Box<dyn RegistationHandle>,
) -> Result<(), StorageError> {
self.handles.register(key, handle)
}
fn is_registered(&self, key: &str) -> bool {
self.handles.is_registered(key)
}
fn registration_handle(&self, key: &str) -> Option<&dyn RegistationHandle> {
self.handles.registration_handle(key)
}
}
#[derive(Default)]
pub struct DiskAllocator;
impl StorageAllocator<DiskStorage> for DiskAllocator {
fn allocate(&self, size: usize) -> Result<DiskStorage, StorageError> {
DiskStorage::new(size)
}
}
...@@ -82,8 +82,8 @@ use derive_getters::Getters; ...@@ -82,8 +82,8 @@ use derive_getters::Getters;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{ use super::{
CudaContextProivder, DeviceStorage, PinnedStorage, RegistationHandle, RegisterableStorage, CudaContextProivder, DeviceStorage, DiskStorage, PinnedStorage, RegistationHandle,
Remote, Storage, StorageError, StorageType, SystemStorage, RegisterableStorage, Remote, Storage, StorageError, StorageType, SystemStorage,
}; };
/// Marker trait for storage types that can be accessed by NIXL. /// Marker trait for storage types that can be accessed by NIXL.
...@@ -104,17 +104,7 @@ impl StorageType { ...@@ -104,17 +104,7 @@ impl StorageType {
StorageType::Device(_) => MemType::Vram, StorageType::Device(_) => MemType::Vram,
StorageType::Nixl => MemType::Unknown, StorageType::Nixl => MemType::Unknown,
StorageType::Null => MemType::Unknown, StorageType::Null => MemType::Unknown,
} StorageType::Disk => MemType::File,
}
/// Get the NIXL device ID for a given storage type.
pub fn nixl_device_id(&self) -> u64 {
match self {
StorageType::System => 0,
StorageType::Pinned => 0,
StorageType::Device(id) => *id as u64,
StorageType::Nixl => 0,
StorageType::Null => 0,
} }
} }
} }
...@@ -311,3 +301,27 @@ impl NixlDescriptor for DeviceStorage { ...@@ -311,3 +301,27 @@ impl NixlDescriptor for DeviceStorage {
CudaContextProivder::cuda_context(self).cu_device() as u64 CudaContextProivder::cuda_context(self).cu_device() as u64
} }
} }
impl NixlAccessible for DiskStorage {}
impl NixlRegisterableStorage for DiskStorage {}
impl MemoryRegion for DiskStorage {
unsafe fn as_ptr(&self) -> *const u8 {
Storage::as_ptr(self)
}
fn size(&self) -> usize {
Storage::size(self)
}
}
impl NixlDescriptor for DiskStorage {
fn mem_type(&self) -> MemType {
MemType::File
}
/// Nixl treats the file descriptor as the device ID.
fn device_id(&self) -> u64 {
self.fd()
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment