Unverified Commit 312ee8e2 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Restructure the KVBM WriteTo trait (#1363)

parent 3d499705
...@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor; ...@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor;
pub use registry::{GlobalRegistry, RegistrationHandle}; pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid}; pub use state::{BlockState, BlockStateInvalid};
pub use transfer::TransferContext;
use crate::block_manager::{ use crate::block_manager::{
state::KvBlockManagerState as BlockManager, state::KvBlockManagerState as BlockManager,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
mod context;
mod cuda; mod cuda;
mod memcpy; mod memcpy;
mod nixl; mod nixl;
...@@ -29,12 +30,12 @@ use crate::block_manager::storage::{ ...@@ -29,12 +30,12 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use nixl_sys::XferOp::{Read, Write}; use nixl_sys::XferOp::{Read, Write};
use std::future::Future;
use std::ops::Range; use std::ops::Range;
use tokio::sync::oneshot;
pub use crate::block_manager::state::TransferContext;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote}; pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait; pub use async_trait::async_trait;
pub use context::TransferContext;
/// A block that can be the target of a write /// A block that can be the target of a write
pub trait Writable {} pub trait Writable {}
...@@ -149,19 +150,9 @@ pub trait WriteTo<Target> { ...@@ -149,19 +150,9 @@ pub trait WriteTo<Target> {
fn write_to( fn write_to(
&self, &self,
dst: &mut Vec<Target>, dst: &mut Vec<Target>,
notify: Option<String>, notify: bool,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<(), TransferError>; ) -> Result<Option<oneshot::Receiver<()>>, TransferError>;
/// A `write_to` implementation that expects a NIXL transfer.
///
/// `dst` holds the blocks to write into, `notify` is an optional notification
/// message for the transfer, and `ctx` is the shared transfer context.
///
/// # Returns
/// A boxed future that will complete when the transfer is complete.
///
/// # Errors
/// If the transfer strategy is not NIXL, this method will return an error.
fn nixl_write_to(
&self,
dst: &mut Vec<Target>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>> impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>>
...@@ -171,15 +162,25 @@ where ...@@ -171,15 +162,25 @@ where
fn write_to( fn write_to(
&self, &self,
dst: &mut Vec<WB>, dst: &mut Vec<WB>,
notify: Option<String>, notify: bool,
ctx: Arc<TransferContext>, ctx: Arc<TransferContext>,
) -> Result<(), TransferError> { ) -> Result<Option<oneshot::Receiver<()>>, TransferError> {
let (tx, rx) = oneshot::channel();
match RB::write_to_strategy() { match RB::write_to_strategy() {
TransferStrategy::Memcpy => { TransferStrategy::Memcpy => {
for (src, dst) in self.iter().zip(dst.iter_mut()) { for (src, dst) in self.iter().zip(dst.iter_mut()) {
// TODO: Unlike all other transfer strategies, this is fully blocking.
// We probably want some sort of thread pool to handle these.
memcpy::copy_block(src.as_ref(), dst)?; memcpy::copy_block(src.as_ref(), dst)?;
} }
Ok(())
if notify {
tx.send(()).unwrap();
Ok(Some(rx))
} else {
Ok(None)
}
} }
TransferStrategy::CudaAsyncH2D TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H | TransferStrategy::CudaAsyncD2H
...@@ -192,17 +193,27 @@ where ...@@ -192,17 +193,27 @@ where
RB::write_to_strategy(), RB::write_to_strategy(),
)?; )?;
} }
Ok(())
if notify {
let (tx, rx) = oneshot::channel();
ctx.cuda_event(tx)?;
Ok(Some(rx))
} else {
Ok(None)
}
} }
TransferStrategy::Nixl(transfer_type) => { TransferStrategy::Nixl(transfer_type) => {
std::mem::drop(nixl::write_blocks_to( let transfer_fut = nixl::write_blocks_to(self, dst, &ctx, transfer_type)?;
self,
dst, if notify {
ctx, ctx.async_rt_handle().spawn(async move {
notify, transfer_fut.await;
transfer_type, tx.send(()).unwrap();
)?); });
Ok(()) Ok(Some(rx))
} else {
Ok(None)
}
} }
_ => Err(TransferError::IncompatibleTypes(format!( _ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}", "Unsupported copy strategy: {:?}",
...@@ -210,28 +221,6 @@ where ...@@ -210,28 +221,6 @@ where
))), ))),
} }
} }
/// NIXL-only counterpart of `write_to`.
///
/// When the source block type's strategy is `TransferStrategy::Nixl`, the blocks are handed
/// to `nixl::write_blocks_to` and the future it produces is returned to the caller. Any other
/// strategy is rejected with `TransferError::IncompatibleTypes`.
fn nixl_write_to(
    &self,
    dst: &mut Vec<WB>,
    notify: Option<String>,
    ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
    match RB::write_to_strategy() {
        TransferStrategy::Nixl(transfer_type) => {
            // `?` converts the error from the NIXL layer into a `TransferError`.
            let transfer_fut = nixl::write_blocks_to(self, dst, ctx, notify, transfer_type)?;
            Ok(transfer_fut)
        }
        _ => Err(TransferError::IncompatibleTypes(format!(
            "Expected NIXL transfer strategy, got: {:?}",
            RB::write_to_strategy()
        ))),
    }
}
} }
#[derive(Default)] #[derive(Default)]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use cudarc::driver::{sys::CUevent_flags, CudaEvent, CudaStream};
use nixl_sys::Agent as NixlAgent;
use std::sync::Arc;
use std::thread::JoinHandle;
use tokio::runtime::Handle;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
/// Shared resources used to issue block transfers and to learn when they finish.
///
/// Holds the optional NIXL agent, the CUDA stream, a Tokio runtime handle, and the
/// channel / worker-thread pair that turns recorded CUDA events into oneshot notifications.
pub struct TransferContext {
/// Optional NIXL agent, shared behind an `Arc`.
nixl_agent: Arc<Option<NixlAgent>>,
/// CUDA stream on which `cuda_event` records completion events.
stream: Arc<CudaStream>,
/// Handle to a Tokio runtime, exposed through `async_rt_handle()`.
async_rt_handle: Handle,
/// Sends `(event, notifier)` pairs to the CUDA event worker thread.
cuda_event_tx: mpsc::UnboundedSender<(CudaEvent, oneshot::Sender<()>)>,
/// Join handle of the CUDA event worker thread; taken and joined when the context is dropped.
cuda_event_worker: Option<JoinHandle<()>>,
/// Cancelled when the context is dropped so the worker loop exits.
cancel_token: CancellationToken,
}
impl TransferContext {
pub fn new(
nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>,
async_rt_handle: Handle,
) -> Self {
let (cuda_event_tx, mut cuda_event_rx) =
mpsc::unbounded_channel::<(CudaEvent, oneshot::Sender<()>)>();
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
let cuda_event_worker = std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for CUDA event worker.");
runtime.block_on(async move {
loop {
tokio::select! {
Some((event, tx)) = cuda_event_rx.recv() => {
if let Err(e) = event.synchronize() {
tracing::error!("Error synchronizing CUDA event: {}", e);
}
let _ = tx.send(());
}
_ = cancel_token_clone.cancelled() => {
break;
}
}
}
});
});
Self {
nixl_agent,
stream,
async_rt_handle,
cuda_event_tx,
cuda_event_worker: Some(cuda_event_worker),
cancel_token,
}
}
pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.clone()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
pub fn async_rt_handle(&self) -> &Handle {
&self.async_rt_handle
}
pub fn cuda_event(&self, tx: oneshot::Sender<()>) -> Result<(), TransferError> {
let event = self
.stream
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))
.map_err(|e| TransferError::ExecutionError(e.to_string()))?;
self.cuda_event_tx
.send((event, tx))
.map_err(|_| TransferError::ExecutionError("CUDA event worker exited.".into()))?;
Ok(())
}
}
impl Drop for TransferContext {
    /// Stops the CUDA event worker and waits for its thread to finish.
    ///
    /// The cancellation token is cancelled first so the worker loop can exit; the thread is
    /// then joined, and a failed join is logged rather than propagated.
    fn drop(&mut self) {
        self.cancel_token.cancel();

        let join_outcome = self.cuda_event_worker.take().map(|worker| worker.join());
        if let Some(Err(e)) = join_outcome {
            tracing::error!("Error joining CUDA event worker: {:?}", e);
        }
    }
}
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
use super::*; use super::*;
use anyhow::Result; use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList}; use nixl_sys::{MemoryRegion, NixlDescriptor, XferDescList};
use std::future::{poll_fn, Future}; use std::future::Future;
use std::task::Poll;
fn append_xfer_request<Source, Destination>( fn append_xfer_request<Source, Destination>(
src: &Arc<Source>, src: &Arc<Source>,
...@@ -87,8 +86,7 @@ where ...@@ -87,8 +86,7 @@ where
pub fn write_blocks_to<Source, Destination>( pub fn write_blocks_to<Source, Destination>(
src: &[Arc<Source>], src: &[Arc<Source>],
dst: &mut [Destination], dst: &mut [Destination],
ctx: Arc<TransferContext>, ctx: &Arc<TransferContext>,
notify: Option<String>,
transfer_type: NixlTransfer, transfer_type: NixlTransfer,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>> ) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where where
...@@ -136,26 +134,27 @@ where ...@@ -136,26 +134,27 @@ where
None, None,
)?; )?;
let mut xfer_args = OptArgs::new()?; let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;
if let Some(notify) = notify { if still_pending {
xfer_args.set_has_notification(true)?; Ok(Box::new(Box::pin(async move {
xfer_args.set_notification_message(notify.as_bytes())?; let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
loop {
match nixl_agent.get_xfer_status(&xfer_req) {
Ok(false) => break, // Transfer is complete.
Ok(true) => tokio::time::sleep(std::time::Duration::from_millis(5)).await, // Transfer is still in progress.
Err(e) => {
tracing::error!("Error getting transfer status: {}", e);
break;
}
}
}
})))
} else {
Ok(Box::new(std::future::ready(())))
} }
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
Ok(Box::new(poll_fn(move |_cx| {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
}
})))
} }
...@@ -44,9 +44,8 @@ ...@@ -44,9 +44,8 @@
//! The kind of offloads/onboards they perform is dictated by the source and target arguments //! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods. //! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
use super::block::{BlockError, BlockMetadata, BlockState, ImmutableBlock}; use super::block::{BlockError, BlockMetadata, BlockState, ImmutableBlock, TransferContext};
use super::pool::BlockPoolError; use super::pool::BlockPoolError;
use super::state::TransferContext;
use super::storage::{Cuda, Storage}; use super::storage::{Cuda, Storage};
use super::{BlockPool, DeviceStorage, DiskStorage, PinnedStorage}; use super::{BlockPool, DeviceStorage, DiskStorage, PinnedStorage};
use nixl_sys::Agent as NixlAgent; use nixl_sys::Agent as NixlAgent;
...@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_offload_transfer_ctx = Arc::new(TransferContext::new( let device_offload_transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(), nixl_agent.clone(),
cuda_ctx.new_stream()?, cuda_ctx.new_stream()?,
async_rt_handle.clone(),
)); ));
// Device -> Host offload // Device -> Host offload
...@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager::new( CudaTransferManager::new(
device_offload_transfer_ctx, device_offload_transfer_ctx,
MAX_CONCURRENT_TRANSFERS, MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
), )?,
MAX_TRANSFER_BATCH_SIZE, MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
...@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let transfer_ctx = Arc::new(TransferContext::new( let transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(), nixl_agent.clone(),
cuda_ctx.new_stream()?, cuda_ctx.new_stream()?,
async_rt_handle.clone(),
)); ));
// Host -> Disk offload // Host -> Disk offload
...@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS, MAX_CONCURRENT_TRANSFERS,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
), )?,
MAX_TRANSFER_BATCH_SIZE, MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
...@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager::new( CudaTransferManager::new(
transfer_ctx.clone(), transfer_ctx.clone(),
MAX_CONCURRENT_TRANSFERS, MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
), )?,
MAX_TRANSFER_BATCH_SIZE, MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
...@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS, MAX_CONCURRENT_TRANSFERS,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
), )?,
MAX_TRANSFER_BATCH_SIZE, MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle, &async_rt_handle,
cancellation_token.clone(), cancellation_token.clone(),
...@@ -549,8 +552,10 @@ mod tests { ...@@ -549,8 +552,10 @@ mod tests {
let agent = NixlAgent::new("offload-manager").unwrap(); let agent = NixlAgent::new("offload-manager").unwrap();
let (_, ucx_params) = agent.get_plugin_params("UCX").unwrap(); let (_, ucx_params) = agent.get_plugin_params("UCX").unwrap();
let (_, gds_params) = agent.get_plugin_params("GDS").unwrap(); let (_, gds_params) = agent.get_plugin_params("GDS").unwrap();
let (_, posix_params) = agent.get_plugin_params("POSIX").unwrap();
agent.create_backend("UCX", &ucx_params).unwrap(); agent.create_backend("UCX", &ucx_params).unwrap();
agent.create_backend("GDS", &gds_params).unwrap(); agent.create_backend("GDS", &gds_params).unwrap();
agent.create_backend("POSIX", &posix_params).unwrap();
Arc::new(Some(agent)) Arc::new(Some(agent))
}; };
} }
......
...@@ -41,23 +41,21 @@ ...@@ -41,23 +41,21 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::thread::spawn;
use tokio::runtime::Handle; use tokio::runtime::Handle;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{ use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy}, transfer::{WriteTo, WriteToStrategy},
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, WritableBlock, BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, TransferContext,
WritableBlock,
}; };
use crate::block_manager::pool::BlockPoolError; use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::state::TransferContext;
use crate::block_manager::storage::{Local, Storage}; use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool; use crate::block_manager::BlockPool;
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{stream::FuturesUnordered, StreamExt}; use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult; use super::BlockResult;
...@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
let blocks = target_pool.register_blocks_blocking(targets)?; let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator { if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?; completion_indicator
.send(Ok(blocks))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
} }
Ok(()) Ok(())
...@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad ...@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
} }
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> { pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>, pending_transfer_q: mpsc::Sender<(
PendingTransfer<Source, Target, Metadata>,
tokio::sync::oneshot::Receiver<()>,
)>,
transfer_ctx: Arc<TransferContext>, transfer_ctx: Arc<TransferContext>,
} }
...@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub fn new( pub fn new(
transfer_ctx: Arc<TransferContext>, transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize, max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
) -> Self { ) -> Result<Self> {
let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>( let (tx, mut rx) = mpsc::channel::<(
max_concurrent_transfers, PendingTransfer<Source, Target, Metadata>,
); tokio::sync::oneshot::Receiver<()>,
)>(max_concurrent_transfers);
spawn(move || {
while let Some((pending_transfer, event)) = rx.blocking_recv() { CriticalTaskExecutionHandle::new_with_runtime(
// Wait for the event. move |cancel_token| async move {
event.synchronize()?; loop {
// Only finalize the transfer after the event is signaled. tokio::select! {
match pending_transfer.handle_complete() { Some((pending_transfer, notify)) = rx.recv() => {
Ok(_) => {} // Wait for the event.
Err(e) => { notify.await.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// The only case where this can fail is if the progress engine is shutdown. // Only finalize the transfer after the event is signaled.
// This is not a problem, so we can just ignore it. match pending_transfer.handle_complete() {
tracing::warn!("Error handling transfer completion: {:?}", e); Ok(_) => {}
} Err(e) => {
} // The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}
// Flush any remaining transfers. _ = cancel_token.cancelled() => {
if cancellation_token.is_cancelled() { return Ok(());
while rx.blocking_recv().is_some() {} }
break; }
} }
} },
Ok::<(), anyhow::Error>(()) cancellation_token.clone(),
}); "Cuda Transfer Manager",
runtime,
Self { )?
.detach();
Ok(Self {
pending_transfer_q: tx, pending_transfer_q: tx,
transfer_ctx, transfer_ctx,
} })
} }
} }
...@@ -214,22 +226,23 @@ where ...@@ -214,22 +226,23 @@ where
&self, &self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
pending_transfer.sources.write_to( let notify = pending_transfer
&mut pending_transfer.targets, .sources
None, .write_to(
self.transfer_ctx.clone(), &mut pending_transfer.targets,
)?; true,
self.transfer_ctx.clone(),
// Use a cuda event to record the completion of the transfers. )?
let event = self .ok_or_else(|| {
.transfer_ctx anyhow::anyhow!(
.stream() "write_to returned None when notify was true. This should never happen!"
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?; )
})?;
// Send the pending transfer and event to the worker thread. // Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available. // If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q self.pending_transfer_q
.send((pending_transfer, event)) .send((pending_transfer, notify))
.await?; .await?;
Ok(()) Ok(())
...@@ -247,42 +260,46 @@ impl DiskTransferManager { ...@@ -247,42 +260,46 @@ impl DiskTransferManager {
max_concurrent_transfers: usize, max_concurrent_transfers: usize,
runtime: &Handle, runtime: &Handle,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
) -> Self { ) -> Result<Self> {
let (futures_tx, mut futures_rx) = mpsc::channel(1); let (futures_tx, mut futures_rx) = mpsc::channel(1);
runtime.spawn(async move { CriticalTaskExecutionHandle::new_with_runtime(
// Keep track of our pending transfers. move |cancel_token| async move {
// Consume the futures as they complete, while also receiving new ones. // Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new(); let mut pending_transfers = FuturesUnordered::new();
loop { loop {
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => { _ = cancel_token.cancelled() => {
// Flush remaining transfers. return Ok(());
while (pending_transfers.next().await).is_some() {} }
return;
}
Some(future) = futures_rx.recv() => { Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity. // If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_concurrent_transfers { while pending_transfers.len() >= max_concurrent_transfers {
pending_transfers.next().await; pending_transfers.next().await;
}
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future);
}
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
} }
// Once we have capacity, push the new future onto the queue.
pending_transfers.push(future);
}
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
} }
} }
} },
}); cancellation_token.clone(),
"Disk Transfer Manager",
Self { runtime,
)?
.detach();
Ok(Self {
futures_tx, futures_tx,
transfer_ctx, transfer_ctx,
} })
} }
} }
...@@ -303,14 +320,21 @@ where ...@@ -303,14 +320,21 @@ where
&self, &self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>, mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
let future = pending_transfer.sources.nixl_write_to( let notify = pending_transfer
&mut pending_transfer.targets, .sources
None, .write_to(
self.transfer_ctx.clone(), &mut pending_transfer.targets,
)?; true,
self.transfer_ctx.clone(),
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
let completion_future = async move { let completion_future = async move {
let _ = future.await; let _ = notify.await;
match pending_transfer.handle_complete() { match pending_transfer.handle_complete() {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
......
...@@ -21,29 +21,9 @@ use super::{ ...@@ -21,29 +21,9 @@ use super::{
config::NixlOptions, config::NixlOptions,
events::{EventManager, NullEventManager}, events::{EventManager, NullEventManager},
}; };
use cudarc::driver::CudaStream;
use std::sync::Arc; use std::sync::Arc;
use tokio::runtime::Handle; use tokio::runtime::Handle;
/// Resources shared by block transfers: an optional NIXL agent and a CUDA stream.
pub struct TransferContext {
/// Optional NIXL agent, shared behind an `Arc`.
nixl_agent: Arc<Option<NixlAgent>>,
/// CUDA stream, shared behind an `Arc`.
stream: Arc<CudaStream>,
}
impl TransferContext {
pub fn new(nixl_agent: Arc<Option<NixlAgent>>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream }
}
pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.clone()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
#[allow(dead_code)] #[allow(dead_code)]
pub struct KvBlockManagerState<Metadata: BlockMetadata> { pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID, worker_id: WorkerID,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment