"lib/vscode:/vscode.git/clone" did not exist on "acbdabc464fc6dca467a4eac1bf870a31d64b079"
Unverified Commit 312ee8e2 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Restructure the KVBM WriteTo trait (#1363)

parent 3d499705
......@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor;
pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid};
pub use transfer::TransferContext;
use crate::block_manager::{
state::KvBlockManagerState as BlockManager,
......
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod context;
mod cuda;
mod memcpy;
mod nixl;
......@@ -29,12 +30,12 @@ use crate::block_manager::storage::{
use cudarc::driver::CudaStream;
use nixl_sys::XferOp::{Read, Write};
use std::future::Future;
use std::ops::Range;
use tokio::sync::oneshot;
pub use crate::block_manager::state::TransferContext;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait;
pub use context::TransferContext;
/// A block that can be the target of a write
pub trait Writable {}
......@@ -149,19 +150,9 @@ pub trait WriteTo<Target> {
fn write_to(
&self,
dst: &mut Vec<Target>,
notify: Option<String>,
notify: bool,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError>;
/// A write_to implementation that expects a NIXL transfer.
/// If the transfer strategy is not NIXL, this method will return an error.
/// Returns a future that will complete when the transfer is complete.
fn nixl_write_to(
&self,
dst: &mut Vec<Target>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError>;
) -> Result<Option<oneshot::Receiver<()>>, TransferError>;
}
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for Vec<Arc<RB>>
......@@ -171,15 +162,25 @@ where
fn write_to(
&self,
dst: &mut Vec<WB>,
notify: Option<String>,
notify: bool,
ctx: Arc<TransferContext>,
) -> Result<(), TransferError> {
) -> Result<Option<oneshot::Receiver<()>>, TransferError> {
let (tx, rx) = oneshot::channel();
match RB::write_to_strategy() {
TransferStrategy::Memcpy => {
for (src, dst) in self.iter().zip(dst.iter_mut()) {
// TODO: Unlike all other transfer strategies, this is fully blocking.
// We probably want some sort of thread pool to handle these.
memcpy::copy_block(src.as_ref(), dst)?;
}
Ok(())
if notify {
tx.send(()).unwrap();
Ok(Some(rx))
} else {
Ok(None)
}
}
TransferStrategy::CudaAsyncH2D
| TransferStrategy::CudaAsyncD2H
......@@ -192,17 +193,27 @@ where
RB::write_to_strategy(),
)?;
}
Ok(())
if notify {
let (tx, rx) = oneshot::channel();
ctx.cuda_event(tx)?;
Ok(Some(rx))
} else {
Ok(None)
}
}
TransferStrategy::Nixl(transfer_type) => {
std::mem::drop(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?);
Ok(())
let transfer_fut = nixl::write_blocks_to(self, dst, &ctx, transfer_type)?;
if notify {
ctx.async_rt_handle().spawn(async move {
transfer_fut.await;
tx.send(()).unwrap();
});
Ok(Some(rx))
} else {
Ok(None)
}
}
_ => Err(TransferError::IncompatibleTypes(format!(
"Unsupported copy strategy: {:?}",
......@@ -210,28 +221,6 @@ where
))),
}
}
fn nixl_write_to(
&self,
dst: &mut Vec<WB>,
notify: Option<String>,
ctx: Arc<TransferContext>,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>, TransferError> {
if let TransferStrategy::Nixl(transfer_type) = RB::write_to_strategy() {
Ok(nixl::write_blocks_to(
self,
dst,
ctx,
notify,
transfer_type,
)?)
} else {
Err(TransferError::IncompatibleTypes(format!(
"Expected NIXL transfer strategy, got: {:?}",
RB::write_to_strategy()
)))?
}
}
}
#[derive(Default)]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use cudarc::driver::{sys::CUevent_flags, CudaEvent, CudaStream};
use nixl_sys::Agent as NixlAgent;
use std::sync::Arc;
use std::thread::JoinHandle;
use tokio::runtime::Handle;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
pub struct TransferContext {
nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>,
async_rt_handle: Handle,
cuda_event_tx: mpsc::UnboundedSender<(CudaEvent, oneshot::Sender<()>)>,
cuda_event_worker: Option<JoinHandle<()>>,
cancel_token: CancellationToken,
}
impl TransferContext {
pub fn new(
nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>,
async_rt_handle: Handle,
) -> Self {
let (cuda_event_tx, mut cuda_event_rx) =
mpsc::unbounded_channel::<(CudaEvent, oneshot::Sender<()>)>();
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
let cuda_event_worker = std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to build Tokio runtime for CUDA event worker.");
runtime.block_on(async move {
loop {
tokio::select! {
Some((event, tx)) = cuda_event_rx.recv() => {
if let Err(e) = event.synchronize() {
tracing::error!("Error synchronizing CUDA event: {}", e);
}
let _ = tx.send(());
}
_ = cancel_token_clone.cancelled() => {
break;
}
}
}
});
});
Self {
nixl_agent,
stream,
async_rt_handle,
cuda_event_tx,
cuda_event_worker: Some(cuda_event_worker),
cancel_token,
}
}
pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.clone()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
pub fn async_rt_handle(&self) -> &Handle {
&self.async_rt_handle
}
pub fn cuda_event(&self, tx: oneshot::Sender<()>) -> Result<(), TransferError> {
let event = self
.stream
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))
.map_err(|e| TransferError::ExecutionError(e.to_string()))?;
self.cuda_event_tx
.send((event, tx))
.map_err(|_| TransferError::ExecutionError("CUDA event worker exited.".into()))?;
Ok(())
}
}
impl Drop for TransferContext {
fn drop(&mut self) {
self.cancel_token.cancel();
if let Some(handle) = self.cuda_event_worker.take() {
if let Err(e) = handle.join() {
tracing::error!("Error joining CUDA event worker: {:?}", e);
}
}
}
}
......@@ -16,9 +16,8 @@
use super::*;
use anyhow::Result;
use nixl_sys::{MemoryRegion, NixlDescriptor, OptArgs, XferDescList};
use std::future::{poll_fn, Future};
use std::task::Poll;
use nixl_sys::{MemoryRegion, NixlDescriptor, XferDescList};
use std::future::Future;
fn append_xfer_request<Source, Destination>(
src: &Arc<Source>,
......@@ -87,8 +86,7 @@ where
pub fn write_blocks_to<Source, Destination>(
src: &[Arc<Source>],
dst: &mut [Destination],
ctx: Arc<TransferContext>,
notify: Option<String>,
ctx: &Arc<TransferContext>,
transfer_type: NixlTransfer,
) -> Result<Box<dyn Future<Output = ()> + Send + Sync + Unpin>>
where
......@@ -136,26 +134,27 @@ where
None,
)?;
let mut xfer_args = OptArgs::new()?;
let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;
if let Some(notify) = notify {
xfer_args.set_has_notification(true)?;
xfer_args.set_notification_message(notify.as_bytes())?;
}
let _ = nixl_agent.post_xfer_req(&xfer_req, Some(&xfer_args))?;
Ok(Box::new(poll_fn(move |_cx| {
if still_pending {
Ok(Box::new(Box::pin(async move {
let nixl_agent = nixl_agent_arc
.as_ref()
.as_ref()
.expect("NIXL agent not found");
// The nixl agent returns true if the transfer is still in progress.
if !nixl_agent.get_xfer_status(&xfer_req).unwrap() {
Poll::Ready(())
} else {
Poll::Pending
loop {
match nixl_agent.get_xfer_status(&xfer_req) {
Ok(false) => break, // Transfer is complete.
Ok(true) => tokio::time::sleep(std::time::Duration::from_millis(5)).await, // Transfer is still in progress.
Err(e) => {
tracing::error!("Error getting transfer status: {}", e);
break;
}
}
}
})))
} else {
Ok(Box::new(std::future::ready(())))
}
}
......@@ -44,9 +44,8 @@
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
use super::block::{BlockError, BlockMetadata, BlockState, ImmutableBlock};
use super::block::{BlockError, BlockMetadata, BlockState, ImmutableBlock, TransferContext};
use super::pool::BlockPoolError;
use super::state::TransferContext;
use super::storage::{Cuda, Storage};
use super::{BlockPool, DeviceStorage, DiskStorage, PinnedStorage};
use nixl_sys::Agent as NixlAgent;
......@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let device_offload_transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(),
cuda_ctx.new_stream()?,
async_rt_handle.clone(),
));
// Device -> Host offload
......@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager::new(
device_offload_transfer_ctx,
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
)?,
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
......@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(),
cuda_ctx.new_stream()?,
async_rt_handle.clone(),
));
// Host -> Disk offload
......@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
)?,
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
......@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager::new(
transfer_ctx.clone(),
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
)?,
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
......@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
)?,
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
......@@ -549,8 +552,10 @@ mod tests {
let agent = NixlAgent::new("offload-manager").unwrap();
let (_, ucx_params) = agent.get_plugin_params("UCX").unwrap();
let (_, gds_params) = agent.get_plugin_params("GDS").unwrap();
let (_, posix_params) = agent.get_plugin_params("POSIX").unwrap();
agent.create_backend("UCX", &ucx_params).unwrap();
agent.create_backend("GDS", &gds_params).unwrap();
agent.create_backend("POSIX", &posix_params).unwrap();
Arc::new(Some(agent))
};
}
......
......@@ -41,23 +41,21 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::thread::spawn;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy},
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, WritableBlock,
BlockError, BlockExt, BlockMetadata, BlockState, MutableBlock, ReadableBlock, TransferContext,
WritableBlock,
};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::state::TransferContext;
use crate::block_manager::storage::{Local, Storage};
use crate::block_manager::BlockPool;
use anyhow::Result;
use async_trait::async_trait;
use cudarc::driver::{sys::CUevent_flags, CudaEvent};
use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult;
......@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = completion_indicator {
completion_indicator.send(Ok(blocks))?;
completion_indicator
.send(Ok(blocks))
.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
}
Ok(())
......@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
}
pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>,
pending_transfer_q: mpsc::Sender<(
PendingTransfer<Source, Target, Metadata>,
tokio::sync::oneshot::Receiver<()>,
)>,
transfer_ctx: Arc<TransferContext>,
}
......@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(
max_concurrent_transfers,
);
) -> Result<Self> {
let (tx, mut rx) = mpsc::channel::<(
PendingTransfer<Source, Target, Metadata>,
tokio::sync::oneshot::Receiver<()>,
)>(max_concurrent_transfers);
spawn(move || {
while let Some((pending_transfer, event)) = rx.blocking_recv() {
CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move {
loop {
tokio::select! {
Some((pending_transfer, notify)) = rx.recv() => {
// Wait for the event.
event.synchronize()?;
notify.await.map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
// Only finalize the transfer after the event is signaled.
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
// The only case where this can fail is if the progress engine is shutdown.
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
}
// Flush any remaining transfers.
if cancellation_token.is_cancelled() {
while rx.blocking_recv().is_some() {}
break;
_ = cancel_token.cancelled() => {
return Ok(());
}
}
}
Ok::<(), anyhow::Error>(())
});
},
cancellation_token.clone(),
"Cuda Transfer Manager",
runtime,
)?
.detach();
Self {
Ok(Self {
pending_transfer_q: tx,
transfer_ctx,
}
})
}
}
......@@ -214,22 +226,23 @@ where
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
pending_transfer.sources.write_to(
let notify = pending_transfer
.sources
.write_to(
&mut pending_transfer.targets,
None,
true,
self.transfer_ctx.clone(),
)?;
// Use a cuda event to record the completion of the transfers.
let event = self
.transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self.pending_transfer_q
.send((pending_transfer, event))
.send((pending_transfer, notify))
.await?;
Ok(())
......@@ -247,10 +260,11 @@ impl DiskTransferManager {
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
) -> Result<Self> {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
runtime.spawn(async move {
CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move {
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
......@@ -258,10 +272,8 @@ impl DiskTransferManager {
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
// Flush remaining transfers.
while (pending_transfers.next().await).is_some() {}
return;
_ = cancel_token.cancelled() => {
return Ok(());
}
Some(future) = futures_rx.recv() => {
......@@ -277,12 +289,17 @@ impl DiskTransferManager {
}
}
}
});
},
cancellation_token.clone(),
"Disk Transfer Manager",
runtime,
)?
.detach();
Self {
Ok(Self {
futures_tx,
transfer_ctx,
}
})
}
}
......@@ -303,14 +320,21 @@ where
&self,
mut pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
let future = pending_transfer.sources.nixl_write_to(
let notify = pending_transfer
.sources
.write_to(
&mut pending_transfer.targets,
None,
true,
self.transfer_ctx.clone(),
)?;
)?
.ok_or_else(|| {
anyhow::anyhow!(
"write_to returned None when notify was true. This should never happen!"
)
})?;
let completion_future = async move {
let _ = future.await;
let _ = notify.await;
match pending_transfer.handle_complete() {
Ok(_) => {}
Err(e) => {
......
......@@ -21,29 +21,9 @@ use super::{
config::NixlOptions,
events::{EventManager, NullEventManager},
};
use cudarc::driver::CudaStream;
use std::sync::Arc;
use tokio::runtime::Handle;
pub struct TransferContext {
nixl_agent: Arc<Option<NixlAgent>>,
stream: Arc<CudaStream>,
}
impl TransferContext {
pub fn new(nixl_agent: Arc<Option<NixlAgent>>, stream: Arc<CudaStream>) -> Self {
Self { nixl_agent, stream }
}
pub fn nixl_agent(&self) -> Arc<Option<NixlAgent>> {
self.nixl_agent.clone()
}
pub fn stream(&self) -> &Arc<CudaStream> {
&self.stream
}
}
#[allow(dead_code)]
pub struct KvBlockManagerState<Metadata: BlockMetadata> {
worker_id: WorkerID,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment