Unverified Commit 25c711f8 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Integrate KVBM with `CriticalTaskHandle` (#1321)

parent 8deb3ea4
......@@ -56,6 +56,7 @@ use tokio::sync::{
mpsc::{self, error::TryRecvError},
Mutex,
};
use tokio_util::sync::CancellationToken;
use anyhow::Result;
use std::any::Any;
......@@ -70,6 +71,8 @@ use pending::{
};
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
const MAX_CONCURRENT_TRANSFERS: usize = 4;
const MAX_TRANSFER_BATCH_SIZE: usize = 16;
......@@ -99,6 +102,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
nixl_agent: Arc<Option<NixlAgent>>,
async_rt_handle: Handle,
cancellation_token: CancellationToken,
) -> Result<Arc<Self>> {
let (device_offload_tx, device_offload_rx) = mpsc::unbounded_channel();
let (host_offload_tx, host_offload_rx) = mpsc::unbounded_channel();
......@@ -128,21 +132,29 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
));
// Device -> Host offload
let device_clone = this.device.clone();
let host_clone = this.host.clone();
async_rt_handle.spawn(async move {
let res = OffloadManager::offload_worker(
device_clone,
host_clone,
device_offload_rx,
Arc::new(TransferBatcher::new(
CudaTransferManager::new(device_offload_transfer_ctx, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await;
tracing::warn!("Offload worker terminated: {:?}", res);
});
let device_to_host_task = OffloadManager::offload_worker(
this.device.clone(),
this.host.clone(),
device_offload_rx,
Arc::new(TransferBatcher::new(
CudaTransferManager::new(
device_offload_transfer_ctx,
MAX_CONCURRENT_TRANSFERS,
cancellation_token.clone(),
),
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
)),
cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| device_to_host_task,
cancellation_token.clone(),
"Device -> Host offload worker",
&async_rt_handle,
)?
.detach();
let transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(),
......@@ -150,58 +162,81 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
));
// Host -> Disk offload
let host_clone = this.host.clone();
let disk_clone = this.disk.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
let res = OffloadManager::offload_worker(
host_clone,
disk_clone,
host_offload_rx,
Arc::new(TransferBatcher::new(
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await;
tracing::warn!("Offload worker terminated: {:?}", res);
});
let host_to_disk_task = OffloadManager::offload_worker(
this.host.clone(),
this.disk.clone(),
host_offload_rx,
Arc::new(TransferBatcher::new(
DiskTransferManager::new(
transfer_ctx.clone(),
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
)),
cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| host_to_disk_task,
cancellation_token.clone(),
"Host -> Disk offload worker",
&async_rt_handle,
)?
.detach();
// Host -> Device onboarding
let host_clone = this.host.clone();
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
let res = OffloadManager::onboard_worker(
host_clone,
device_clone,
host_onboard_rx,
Arc::new(TransferBatcher::new(
CudaTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await;
tracing::warn!("Onboard worker terminated: {:?}", res);
});
let host_to_device_task = OffloadManager::onboard_worker(
this.host.clone(),
this.device.clone(),
host_onboard_rx,
Arc::new(TransferBatcher::new(
CudaTransferManager::new(
transfer_ctx.clone(),
MAX_CONCURRENT_TRANSFERS,
cancellation_token.clone(),
),
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
)),
cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| host_to_device_task,
cancellation_token.clone(),
"Host -> Device onboarding worker",
&async_rt_handle,
)?
.detach();
// Disk -> Device onboarding
let disk_clone = this.disk.clone();
let device_clone = this.device.clone();
let transfer_ctx_clone = transfer_ctx.clone();
async_rt_handle.spawn(async move {
let res = OffloadManager::onboard_worker(
disk_clone,
device_clone,
disk_onboard_rx,
Arc::new(TransferBatcher::new(
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS),
MAX_TRANSFER_BATCH_SIZE,
)),
)
.await;
tracing::warn!("Onboard worker terminated: {:?}", res);
});
let disk_to_device_task = OffloadManager::onboard_worker(
this.disk.clone(),
this.device.clone(),
disk_onboard_rx,
Arc::new(TransferBatcher::new(
DiskTransferManager::new(
transfer_ctx.clone(),
MAX_CONCURRENT_TRANSFERS,
&async_rt_handle,
cancellation_token.clone(),
),
MAX_TRANSFER_BATCH_SIZE,
&async_rt_handle,
cancellation_token.clone(),
)),
cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| disk_to_device_task,
cancellation_token.clone(),
"Disk -> Device onboarding worker",
&async_rt_handle,
)?
.detach();
Ok(this_clone)
}
......@@ -211,6 +246,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
cancellation_token: CancellationToken,
) -> Result<()> {
if source_pool.is_none() || target_pool.is_none() {
return Ok(());
......@@ -222,6 +258,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let mut queue = BTreeSet::new();
loop {
if cancellation_token.is_cancelled() {
return Ok(());
}
// Try to check the offload queue.
loop {
match offload_rx.try_recv() {
......@@ -231,7 +271,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
Err(TryRecvError::Empty) => {
break;
}
Err(_) => return Ok(()),
Err(e) => return Err(e.into()),
}
}
......@@ -280,8 +320,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
} else {
// Await the next request.
if let Some(request) = offload_rx.recv().await {
queue.insert(request);
tokio::select! {
_ = cancellation_token.cancelled() => return Ok(()),
Some(request) = offload_rx.recv() => {
queue.insert(request);
}
}
}
}
......@@ -292,40 +335,45 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
cancellation_token: CancellationToken,
) -> Result<()> {
if source_pool.is_none() || target_pool.is_none() {
return Ok(());
}
let target_pool = target_pool.as_ref().unwrap();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => return Ok::<(), anyhow::Error>(()),
Some(request) = onboard_rx.recv() => {
// Try to allocate blocks on the device.
let target_blocks = match target_pool.allocate_blocks(request.blocks.len()).await {
Ok(blocks) => blocks,
Err(err) => {
request.response_tx.send(Err(err))?;
continue;
}
};
// Loop on incoming requests
while let Some(request) = onboard_rx.recv().await {
// Try to allocate blocks on the device.
let target_blocks = match target_pool.allocate_blocks(request.blocks.len()).await {
Ok(blocks) => blocks,
Err(err) => {
request.response_tx.send(Err(err))?;
continue;
let sources = request
.blocks
.iter()
.map(|b| b.mutable_block().clone())
.collect();
transfer_manager
.enqueue_transfer(PendingTransfer::new(
sources,
target_blocks,
Some(request.response_tx),
target_pool.clone(),
))
.await?;
Ok::<(), anyhow::Error>(())
}
};
let sources = request
.blocks
.iter()
.map(|b| b.mutable_block().clone())
.collect();
transfer_manager
.enqueue_transfer(PendingTransfer::new(
sources,
target_blocks,
Some(request.response_tx),
target_pool.clone(),
))
.await?;
}?;
}
Ok(())
}
pub async fn offload<S: Storage>(
......@@ -568,6 +616,7 @@ mod tests {
device_pool.clone(),
agent_arc,
async_rt_handle,
CancellationToken::new(),
)?;
Ok((manager, device_pool, host_pool, disk_pool))
......
......@@ -42,7 +42,9 @@ use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::thread::spawn;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy},
......@@ -60,6 +62,8 @@ use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
/// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The block being copied from.
......@@ -153,7 +157,11 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata>
{
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
cancellation_token: CancellationToken,
) -> Self {
let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(
max_concurrent_transfers,
);
......@@ -171,6 +179,12 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
tracing::warn!("Error handling transfer completion: {:?}", e);
}
}
// Flush any remaining transfers.
if cancellation_token.is_cancelled() {
while rx.blocking_recv().is_some() {}
break;
}
}
Ok::<(), anyhow::Error>(())
});
......@@ -228,16 +242,28 @@ pub struct DiskTransferManager {
}
impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self {
pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move {
runtime.spawn(async move {
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
// Flush remaining transfers.
while (pending_transfers.next().await).is_some() {}
return;
}
Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_concurrent_transfers {
......@@ -249,10 +275,6 @@ impl DiskTransferManager {
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more
}
else => {
// Both branches are pending, wait for one to become ready
tokio::task::yield_now().await;
}
}
}
});
......@@ -317,6 +339,8 @@ where
{
transfer_manager: Manager,
max_transfer_batch_size: usize,
runtime: Handle,
cancellation_token: CancellationToken,
_phantom: PhantomData<(Source, Target, Metadata)>,
}
......@@ -327,10 +351,17 @@ where
Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>,
{
pub fn new(transfer_manager: Manager, max_transfer_batch_size: usize) -> Self {
pub fn new(
transfer_manager: Manager,
max_transfer_batch_size: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
Self {
transfer_manager,
max_transfer_batch_size,
runtime: runtime.clone(),
cancellation_token,
_phantom: PhantomData,
}
}
......@@ -391,25 +422,40 @@ where
}
if let Some(completion_indicator) = completion_indicator {
tokio::spawn(async move {
let mut results = Vec::new();
for indicator in indicators.into_iter() {
// Await each sub-transfer, and append the results to our final results.
let result = match indicator.await.unwrap() {
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
return;
CriticalTaskExecutionHandle::new_with_runtime(
move |cancel_token| async move {
let mut results = Vec::new();
for indicator in indicators.into_iter() {
// Await each sub-transfer, and append the results to our final results.
tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(());
}
Ok(indicator) = indicator => {
let result = match indicator {
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
return Ok(());
}
};
results.extend(result);
}
}
};
results.extend(result);
}
}
// Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results))?;
// Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results)).unwrap();
});
Ok(())
},
self.cancellation_token.clone(),
"Transfer Batcher",
&self.runtime,
)?.detach();
}
Ok(())
......
......@@ -212,6 +212,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool.clone(),
nixl_agent.clone(),
async_rt_handle,
cancellation_token.clone(),
)?;
let state = Arc::new(Self {
......
......@@ -17,6 +17,7 @@
use anyhow::{Context, Result};
use std::future::Future;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
......@@ -41,20 +42,34 @@ pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send
pub struct CriticalTaskExecutionHandle {
monitor_task: JoinHandle<()>,
graceful_shutdown_token: CancellationToken,
result_receiver: oneshot::Receiver<Result<()>>,
result_receiver: Option<oneshot::Receiver<Result<()>>>,
detached: bool,
}
impl CriticalTaskExecutionHandle {
pub fn new<Fut>(
task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
parent_token: CancellationToken,
description: &str,
) -> Result<Self>
where
Fut: Future<Output = Result<()>> + Send + 'static,
{
Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
}
/// Create a new [CriticalTaskExecutionHandle] for a critical task.
///
/// # Arguments
/// * `task_fn` - A function that takes a cancellation token and returns the critical task future
/// * `parent_token` - Token that will be cancelled if this critical task fails
/// * `description` - Description for logging purposes
pub async fn new<Fut>(
/// * `runtime` - The runtime to use for the task.
pub fn new_with_runtime<Fut>(
task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
parent_token: CancellationToken,
description: &str,
runtime: &Handle,
) -> Result<Self>
where
Fut: Future<Output = Result<()>> + Send + 'static,
......@@ -68,7 +83,7 @@ impl CriticalTaskExecutionHandle {
let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
let description_clone = description.to_string();
let task = tokio::spawn(async move {
let task = runtime.spawn(async move {
let future = task_fn(graceful_shutdown_token_clone);
match future.await {
......@@ -92,7 +107,7 @@ impl CriticalTaskExecutionHandle {
let parent_token_monitor = parent_token_clone.clone();
let description_monitor = description.clone();
tokio::spawn(async move {
runtime.spawn(async move {
let result = match main_task_handle.await {
Ok(task_result) => {
// Task completed normally (success or error)
......@@ -147,7 +162,8 @@ impl CriticalTaskExecutionHandle {
Ok(Self {
monitor_task,
graceful_shutdown_token,
result_receiver,
result_receiver: Some(result_receiver),
detached: false,
})
}
......@@ -179,13 +195,28 @@ impl CriticalTaskExecutionHandle {
/// - `Err(...)` if the task failed or panicked, preserving the original error
///
/// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
pub async fn join(self) -> Result<()> {
match self.result_receiver.await {
pub async fn join(mut self) -> Result<()> {
self.detached = true;
let result = match self.result_receiver.take().unwrap().await {
Ok(task_result) => task_result,
Err(_) => {
// This should rarely happen - means monitor task was dropped/cancelled
Err(anyhow::anyhow!("Critical task monitor was cancelled"))
}
};
result
}
/// Detach the task. This allows the task to continue running after the handle is dropped.
pub fn detach(mut self) {
self.detached = true;
}
}
impl Drop for CriticalTaskExecutionHandle {
fn drop(&mut self) {
if !self.detached {
panic!("Critical task was not detached prior to drop!");
}
}
}
......@@ -218,7 +249,6 @@ mod tests {
parent_token.clone(),
"test-success-task",
)
.await
.unwrap();
// Task should complete successfully
......@@ -245,7 +275,6 @@ mod tests {
parent_token.clone(),
"test-failure-task",
)
.await
.unwrap();
// Task should fail and cancel parent token
......@@ -284,7 +313,6 @@ mod tests {
parent_token.clone(),
"test-panic-task",
)
.await
.unwrap();
// Panic should be caught and converted to error
......@@ -328,7 +356,6 @@ mod tests {
parent_token.clone(),
"test-graceful-shutdown",
)
.await
.unwrap();
// Let task do some work
......@@ -381,7 +408,6 @@ mod tests {
parent_token.clone(),
"long-running-task",
)
.await
.unwrap();
let handle2 = CriticalTaskExecutionHandle::new(
......@@ -393,7 +419,6 @@ mod tests {
parent_token.clone(),
"failing-task",
)
.await
.unwrap();
// Wait for task 2 to fail
......@@ -432,7 +457,6 @@ mod tests {
parent_token,
"status-test-task",
)
.await
.unwrap();
// Initially task should be running
......@@ -479,7 +503,6 @@ mod tests {
parent_token,
"select-pattern-task",
)
.await
.unwrap();
// Cancel after a short time
......@@ -511,7 +534,6 @@ mod tests {
parent_token,
"long-task",
)
.await
.unwrap();
// Test with timeout
......@@ -532,7 +554,7 @@ mod tests {
// - Demonstrates true "critical task" behavior with immediate failure propagation
let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new(
let handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move {
tokio::time::sleep(Duration::from_millis(50)).await;
panic!("Critical failure!");
......@@ -540,7 +562,6 @@ mod tests {
parent_token.clone(),
"immediate-panic-task",
)
.await
.unwrap();
// Wait for the panic to be detected by monitor task
......@@ -551,6 +572,7 @@ mod tests {
parent_token.is_cancelled(),
"Parent token should be cancelled immediately when critical task panics"
);
assert!(handle.join().await.is_err());
}
#[tokio::test]
......@@ -563,7 +585,7 @@ mod tests {
// - Demonstrates consistent critical failure behavior
let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new(
let handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move {
tokio::time::sleep(Duration::from_millis(50)).await;
anyhow::bail!("Critical error!");
......@@ -571,7 +593,6 @@ mod tests {
parent_token.clone(),
"immediate-error-task",
)
.await
.unwrap();
// Don't call join() - just wait for the error to be detected
......@@ -582,5 +603,21 @@ mod tests {
parent_token.is_cancelled(),
"Parent token should be cancelled immediately when critical task errors"
);
assert!(handle.join().await.is_err());
}
#[tokio::test]
#[should_panic]
async fn test_task_detach() {
// Dropping without detaching should panic
let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move { Ok(()) },
parent_token,
"test-detach-task",
)
.unwrap();
// Dropping without detaching should panic
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment