Commit 25c711f8 (unverified), authored by jthomson04 and committed by GitHub
Browse files

feat: Integrate KVBM with `CriticalTaskHandle` (#1321)

parent 8deb3ea4
...@@ -56,6 +56,7 @@ use tokio::sync::{ ...@@ -56,6 +56,7 @@ use tokio::sync::{
mpsc::{self, error::TryRecvError}, mpsc::{self, error::TryRecvError},
Mutex, Mutex,
}; };
use tokio_util::sync::CancellationToken;
use anyhow::Result; use anyhow::Result;
use std::any::Any; use std::any::Any;
...@@ -70,6 +71,8 @@ use pending::{ ...@@ -70,6 +71,8 @@ use pending::{
}; };
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest}; use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
const MAX_CONCURRENT_TRANSFERS: usize = 4; const MAX_CONCURRENT_TRANSFERS: usize = 4;
const MAX_TRANSFER_BATCH_SIZE: usize = 16; const MAX_TRANSFER_BATCH_SIZE: usize = 16;
...@@ -99,6 +102,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -99,6 +102,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>, device: Option<Arc<BlockPool<DeviceStorage, Metadata>>>,
nixl_agent: Arc<Option<NixlAgent>>, nixl_agent: Arc<Option<NixlAgent>>,
async_rt_handle: Handle, async_rt_handle: Handle,
cancellation_token: CancellationToken,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let (device_offload_tx, device_offload_rx) = mpsc::unbounded_channel(); let (device_offload_tx, device_offload_rx) = mpsc::unbounded_channel();
let (host_offload_tx, host_offload_rx) = mpsc::unbounded_channel(); let (host_offload_tx, host_offload_rx) = mpsc::unbounded_channel();
...@@ -128,21 +132,29 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -128,21 +132,29 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
)); ));
// Device -> Host offload // Device -> Host offload
let device_clone = this.device.clone(); let device_to_host_task = OffloadManager::offload_worker(
let host_clone = this.host.clone(); this.device.clone(),
async_rt_handle.spawn(async move { this.host.clone(),
let res = OffloadManager::offload_worker( device_offload_rx,
device_clone, Arc::new(TransferBatcher::new(
host_clone, CudaTransferManager::new(
device_offload_rx, device_offload_transfer_ctx,
Arc::new(TransferBatcher::new( MAX_CONCURRENT_TRANSFERS,
CudaTransferManager::new(device_offload_transfer_ctx, MAX_CONCURRENT_TRANSFERS), cancellation_token.clone(),
MAX_TRANSFER_BATCH_SIZE, ),
)), MAX_TRANSFER_BATCH_SIZE,
) &async_rt_handle,
.await; cancellation_token.clone(),
tracing::warn!("Offload worker terminated: {:?}", res); )),
}); cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| device_to_host_task,
cancellation_token.clone(),
"Device -> Host offload worker",
&async_rt_handle,
)?
.detach();
let transfer_ctx = Arc::new(TransferContext::new( let transfer_ctx = Arc::new(TransferContext::new(
nixl_agent.clone(), nixl_agent.clone(),
...@@ -150,58 +162,81 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -150,58 +162,81 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
)); ));
// Host -> Disk offload // Host -> Disk offload
let host_clone = this.host.clone(); let host_to_disk_task = OffloadManager::offload_worker(
let disk_clone = this.disk.clone(); this.host.clone(),
let transfer_ctx_clone = transfer_ctx.clone(); this.disk.clone(),
async_rt_handle.spawn(async move { host_offload_rx,
let res = OffloadManager::offload_worker( Arc::new(TransferBatcher::new(
host_clone, DiskTransferManager::new(
disk_clone, transfer_ctx.clone(),
host_offload_rx, MAX_CONCURRENT_TRANSFERS,
Arc::new(TransferBatcher::new( &async_rt_handle,
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS), cancellation_token.clone(),
MAX_TRANSFER_BATCH_SIZE, ),
)), MAX_TRANSFER_BATCH_SIZE,
) &async_rt_handle,
.await; cancellation_token.clone(),
tracing::warn!("Offload worker terminated: {:?}", res); )),
}); cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| host_to_disk_task,
cancellation_token.clone(),
"Host -> Disk offload worker",
&async_rt_handle,
)?
.detach();
// Host -> Device onboarding // Host -> Device onboarding
let host_clone = this.host.clone(); let host_to_device_task = OffloadManager::onboard_worker(
let device_clone = this.device.clone(); this.host.clone(),
let transfer_ctx_clone = transfer_ctx.clone(); this.device.clone(),
async_rt_handle.spawn(async move { host_onboard_rx,
let res = OffloadManager::onboard_worker( Arc::new(TransferBatcher::new(
host_clone, CudaTransferManager::new(
device_clone, transfer_ctx.clone(),
host_onboard_rx, MAX_CONCURRENT_TRANSFERS,
Arc::new(TransferBatcher::new( cancellation_token.clone(),
CudaTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS), ),
MAX_TRANSFER_BATCH_SIZE, MAX_TRANSFER_BATCH_SIZE,
)), &async_rt_handle,
) cancellation_token.clone(),
.await; )),
tracing::warn!("Onboard worker terminated: {:?}", res); cancellation_token.clone(),
}); );
CriticalTaskExecutionHandle::new_with_runtime(
|_| host_to_device_task,
cancellation_token.clone(),
"Host -> Device onboarding worker",
&async_rt_handle,
)?
.detach();
// Disk -> Device onboarding // Disk -> Device onboarding
let disk_clone = this.disk.clone(); let disk_to_device_task = OffloadManager::onboard_worker(
let device_clone = this.device.clone(); this.disk.clone(),
let transfer_ctx_clone = transfer_ctx.clone(); this.device.clone(),
async_rt_handle.spawn(async move { disk_onboard_rx,
let res = OffloadManager::onboard_worker( Arc::new(TransferBatcher::new(
disk_clone, DiskTransferManager::new(
device_clone, transfer_ctx.clone(),
disk_onboard_rx, MAX_CONCURRENT_TRANSFERS,
Arc::new(TransferBatcher::new( &async_rt_handle,
DiskTransferManager::new(transfer_ctx_clone, MAX_CONCURRENT_TRANSFERS), cancellation_token.clone(),
MAX_TRANSFER_BATCH_SIZE, ),
)), MAX_TRANSFER_BATCH_SIZE,
) &async_rt_handle,
.await; cancellation_token.clone(),
tracing::warn!("Onboard worker terminated: {:?}", res); )),
}); cancellation_token.clone(),
);
CriticalTaskExecutionHandle::new_with_runtime(
|_| disk_to_device_task,
cancellation_token.clone(),
"Disk -> Device onboarding worker",
&async_rt_handle,
)?
.detach();
Ok(this_clone) Ok(this_clone)
} }
...@@ -211,6 +246,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -211,6 +246,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool: Option<Arc<BlockPool<Target, Metadata>>>, target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>, mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>, transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
cancellation_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
if source_pool.is_none() || target_pool.is_none() { if source_pool.is_none() || target_pool.is_none() {
return Ok(()); return Ok(());
...@@ -222,6 +258,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -222,6 +258,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let mut queue = BTreeSet::new(); let mut queue = BTreeSet::new();
loop { loop {
if cancellation_token.is_cancelled() {
return Ok(());
}
// Try to check the offload queue. // Try to check the offload queue.
loop { loop {
match offload_rx.try_recv() { match offload_rx.try_recv() {
...@@ -231,7 +271,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -231,7 +271,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
Err(TryRecvError::Empty) => { Err(TryRecvError::Empty) => {
break; break;
} }
Err(_) => return Ok(()), Err(e) => return Err(e.into()),
} }
} }
...@@ -280,8 +320,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -280,8 +320,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
} }
} else { } else {
// Await the next request. // Await the next request.
if let Some(request) = offload_rx.recv().await { tokio::select! {
queue.insert(request); _ = cancellation_token.cancelled() => return Ok(()),
Some(request) = offload_rx.recv() => {
queue.insert(request);
}
} }
} }
} }
...@@ -292,40 +335,45 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -292,40 +335,45 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool: Option<Arc<BlockPool<Target, Metadata>>>, target_pool: Option<Arc<BlockPool<Target, Metadata>>>,
mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>, mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Metadata>>,
transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>, transfer_manager: Arc<dyn TransferManager<Source, Target, Metadata>>,
cancellation_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
if source_pool.is_none() || target_pool.is_none() { if source_pool.is_none() || target_pool.is_none() {
return Ok(()); return Ok(());
} }
let target_pool = target_pool.as_ref().unwrap(); let target_pool = target_pool.as_ref().unwrap();
loop {
tokio::select! {
_ = cancellation_token.cancelled() => return Ok::<(), anyhow::Error>(()),
Some(request) = onboard_rx.recv() => {
// Try to allocate blocks on the device.
let target_blocks = match target_pool.allocate_blocks(request.blocks.len()).await {
Ok(blocks) => blocks,
Err(err) => {
request.response_tx.send(Err(err))?;
continue;
}
};
// Loop on incoming requests let sources = request
while let Some(request) = onboard_rx.recv().await { .blocks
// Try to allocate blocks on the device. .iter()
let target_blocks = match target_pool.allocate_blocks(request.blocks.len()).await { .map(|b| b.mutable_block().clone())
Ok(blocks) => blocks, .collect();
Err(err) => {
request.response_tx.send(Err(err))?; transfer_manager
continue; .enqueue_transfer(PendingTransfer::new(
sources,
target_blocks,
Some(request.response_tx),
target_pool.clone(),
))
.await?;
Ok::<(), anyhow::Error>(())
} }
}; }?;
let sources = request
.blocks
.iter()
.map(|b| b.mutable_block().clone())
.collect();
transfer_manager
.enqueue_transfer(PendingTransfer::new(
sources,
target_blocks,
Some(request.response_tx),
target_pool.clone(),
))
.await?;
} }
Ok(())
} }
pub async fn offload<S: Storage>( pub async fn offload<S: Storage>(
...@@ -568,6 +616,7 @@ mod tests { ...@@ -568,6 +616,7 @@ mod tests {
device_pool.clone(), device_pool.clone(),
agent_arc, agent_arc,
async_rt_handle, async_rt_handle,
CancellationToken::new(),
)?; )?;
Ok((manager, device_pool, host_pool, disk_pool)) Ok((manager, device_pool, host_pool, disk_pool))
......
...@@ -42,7 +42,9 @@ use std::marker::PhantomData; ...@@ -42,7 +42,9 @@ use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::thread::spawn; use std::thread::spawn;
use tokio::runtime::Handle;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::block_manager::block::{ use crate::block_manager::block::{
transfer::{WriteTo, WriteToStrategy}, transfer::{WriteTo, WriteToStrategy},
...@@ -60,6 +62,8 @@ use futures::{stream::FuturesUnordered, StreamExt}; ...@@ -60,6 +62,8 @@ use futures::{stream::FuturesUnordered, StreamExt};
use super::BlockResult; use super::BlockResult;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
/// Manage a set of pending transfers. /// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> { pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The block being copied from. /// The block being copied from.
...@@ -153,7 +157,11 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block ...@@ -153,7 +157,11 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
CudaTransferManager<Source, Target, Metadata> CudaTransferManager<Source, Target, Metadata>
{ {
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self { pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
cancellation_token: CancellationToken,
) -> Self {
let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>( let (tx, mut rx) = mpsc::channel::<(PendingTransfer<Source, Target, Metadata>, CudaEvent)>(
max_concurrent_transfers, max_concurrent_transfers,
); );
...@@ -171,6 +179,12 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata> ...@@ -171,6 +179,12 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
tracing::warn!("Error handling transfer completion: {:?}", e); tracing::warn!("Error handling transfer completion: {:?}", e);
} }
} }
// Flush any remaining transfers.
if cancellation_token.is_cancelled() {
while rx.blocking_recv().is_some() {}
break;
}
} }
Ok::<(), anyhow::Error>(()) Ok::<(), anyhow::Error>(())
}); });
...@@ -228,16 +242,28 @@ pub struct DiskTransferManager { ...@@ -228,16 +242,28 @@ pub struct DiskTransferManager {
} }
impl DiskTransferManager { impl DiskTransferManager {
pub fn new(transfer_ctx: Arc<TransferContext>, max_concurrent_transfers: usize) -> Self { pub fn new(
transfer_ctx: Arc<TransferContext>,
max_concurrent_transfers: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
let (futures_tx, mut futures_rx) = mpsc::channel(1); let (futures_tx, mut futures_rx) = mpsc::channel(1);
tokio::spawn(async move { runtime.spawn(async move {
// Keep track of our pending transfers. // Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones. // Consume the futures as they complete, while also receiving new ones.
let mut pending_transfers = FuturesUnordered::new(); let mut pending_transfers = FuturesUnordered::new();
loop { loop {
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => {
// Flush remaining transfers.
while (pending_transfers.next().await).is_some() {}
return;
}
Some(future) = futures_rx.recv() => { Some(future) = futures_rx.recv() => {
// If we're at max size, block the worker thread on the next() call until we have capacity. // If we're at max size, block the worker thread on the next() call until we have capacity.
while pending_transfers.len() >= max_concurrent_transfers { while pending_transfers.len() >= max_concurrent_transfers {
...@@ -249,10 +275,6 @@ impl DiskTransferManager { ...@@ -249,10 +275,6 @@ impl DiskTransferManager {
Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => { Some(_) = pending_transfers.next(), if !pending_transfers.is_empty() => {
// A transfer completed, just continue to process more // A transfer completed, just continue to process more
} }
else => {
// Both branches are pending, wait for one to become ready
tokio::task::yield_now().await;
}
} }
} }
}); });
...@@ -317,6 +339,8 @@ where ...@@ -317,6 +339,8 @@ where
{ {
transfer_manager: Manager, transfer_manager: Manager,
max_transfer_batch_size: usize, max_transfer_batch_size: usize,
runtime: Handle,
cancellation_token: CancellationToken,
_phantom: PhantomData<(Source, Target, Metadata)>, _phantom: PhantomData<(Source, Target, Metadata)>,
} }
...@@ -327,10 +351,17 @@ where ...@@ -327,10 +351,17 @@ where
Metadata: BlockMetadata, Metadata: BlockMetadata,
Manager: TransferManager<Source, Target, Metadata>, Manager: TransferManager<Source, Target, Metadata>,
{ {
pub fn new(transfer_manager: Manager, max_transfer_batch_size: usize) -> Self { pub fn new(
transfer_manager: Manager,
max_transfer_batch_size: usize,
runtime: &Handle,
cancellation_token: CancellationToken,
) -> Self {
Self { Self {
transfer_manager, transfer_manager,
max_transfer_batch_size, max_transfer_batch_size,
runtime: runtime.clone(),
cancellation_token,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
...@@ -391,25 +422,40 @@ where ...@@ -391,25 +422,40 @@ where
} }
if let Some(completion_indicator) = completion_indicator { if let Some(completion_indicator) = completion_indicator {
tokio::spawn(async move { CriticalTaskExecutionHandle::new_with_runtime(
let mut results = Vec::new(); move |cancel_token| async move {
let mut results = Vec::new();
for indicator in indicators.into_iter() {
// Await each sub-transfer, and append the results to our final results. for indicator in indicators.into_iter() {
let result = match indicator.await.unwrap() { // Await each sub-transfer, and append the results to our final results.
Ok(result) => result, tokio::select! {
Err(e) => { _ = cancel_token.cancelled() => {
tracing::error!("Error receiving transfer results: {:?}", e); return Ok(());
completion_indicator.send(Err(e)).unwrap(); }
return;
Ok(indicator) = indicator => {
let result = match indicator {
Ok(result) => result,
Err(e) => {
tracing::error!("Error receiving transfer results: {:?}", e);
completion_indicator.send(Err(e)).unwrap();
return Ok(());
}
};
results.extend(result);
}
} }
}; }
results.extend(result);
} // Send the final results to the top-level completion indicator.
completion_indicator.send(Ok(results))?;
// Send the final results to the top-level completion indicator. Ok(())
completion_indicator.send(Ok(results)).unwrap(); },
}); self.cancellation_token.clone(),
"Transfer Batcher",
&self.runtime,
)?.detach();
} }
Ok(()) Ok(())
......
...@@ -212,6 +212,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -212,6 +212,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool.clone(), device_pool.clone(),
nixl_agent.clone(), nixl_agent.clone(),
async_rt_handle, async_rt_handle,
cancellation_token.clone(),
)?; )?;
let state = Arc::new(Self { let state = Arc::new(Self {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::future::Future; use std::future::Future;
use tokio::runtime::Handle;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -41,20 +42,34 @@ pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send ...@@ -41,20 +42,34 @@ pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send
pub struct CriticalTaskExecutionHandle { pub struct CriticalTaskExecutionHandle {
monitor_task: JoinHandle<()>, monitor_task: JoinHandle<()>,
graceful_shutdown_token: CancellationToken, graceful_shutdown_token: CancellationToken,
result_receiver: oneshot::Receiver<Result<()>>, result_receiver: Option<oneshot::Receiver<Result<()>>>,
detached: bool,
} }
impl CriticalTaskExecutionHandle { impl CriticalTaskExecutionHandle {
/// Create a new [CriticalTaskExecutionHandle] on the current Tokio runtime.
///
/// Convenience wrapper around [`Self::new_with_runtime`]: it obtains a runtime
/// handle through `Handle::try_current()` and forwards all other arguments
/// unchanged.
///
/// # Arguments
/// * `task_fn` - A function that takes a cancellation token and returns the critical task future
/// * `parent_token` - Token that will be cancelled if this critical task fails
/// * `description` - Description for logging purposes
///
/// # Errors
/// Propagates the error from `Handle::try_current()` when no runtime handle can
/// be obtained, and any error returned by [`Self::new_with_runtime`].
pub fn new<Fut>(
    task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
    parent_token: CancellationToken,
    description: &str,
) -> Result<Self>
where
    Fut: Future<Output = Result<()>> + Send + 'static,
{
    // Resolve the ambient runtime first so a missing runtime surfaces as an
    // error before anything is spawned.
    let current_runtime = Handle::try_current()?;
    Self::new_with_runtime(task_fn, parent_token, description, &current_runtime)
}
/// Create a new [CriticalTaskExecutionHandle] for a critical task. /// Create a new [CriticalTaskExecutionHandle] for a critical task.
/// ///
/// # Arguments /// # Arguments
/// * `task_fn` - A function that takes a cancellation token and returns the critical task future /// * `task_fn` - A function that takes a cancellation token and returns the critical task future
/// * `parent_token` - Token that will be cancelled if this critical task fails /// * `parent_token` - Token that will be cancelled if this critical task fails
/// * `description` - Description for logging purposes /// * `description` - Description for logging purposes
pub async fn new<Fut>( /// * `runtime` - The runtime to use for the task.
pub fn new_with_runtime<Fut>(
task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static, task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
parent_token: CancellationToken, parent_token: CancellationToken,
description: &str, description: &str,
runtime: &Handle,
) -> Result<Self> ) -> Result<Self>
where where
Fut: Future<Output = Result<()>> + Send + 'static, Fut: Future<Output = Result<()>> + Send + 'static,
...@@ -68,7 +83,7 @@ impl CriticalTaskExecutionHandle { ...@@ -68,7 +83,7 @@ impl CriticalTaskExecutionHandle {
let graceful_shutdown_token_clone = graceful_shutdown_token.clone(); let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
let description_clone = description.to_string(); let description_clone = description.to_string();
let task = tokio::spawn(async move { let task = runtime.spawn(async move {
let future = task_fn(graceful_shutdown_token_clone); let future = task_fn(graceful_shutdown_token_clone);
match future.await { match future.await {
...@@ -92,7 +107,7 @@ impl CriticalTaskExecutionHandle { ...@@ -92,7 +107,7 @@ impl CriticalTaskExecutionHandle {
let parent_token_monitor = parent_token_clone.clone(); let parent_token_monitor = parent_token_clone.clone();
let description_monitor = description.clone(); let description_monitor = description.clone();
tokio::spawn(async move { runtime.spawn(async move {
let result = match main_task_handle.await { let result = match main_task_handle.await {
Ok(task_result) => { Ok(task_result) => {
// Task completed normally (success or error) // Task completed normally (success or error)
...@@ -147,7 +162,8 @@ impl CriticalTaskExecutionHandle { ...@@ -147,7 +162,8 @@ impl CriticalTaskExecutionHandle {
Ok(Self { Ok(Self {
monitor_task, monitor_task,
graceful_shutdown_token, graceful_shutdown_token,
result_receiver, result_receiver: Some(result_receiver),
detached: false,
}) })
} }
...@@ -179,13 +195,28 @@ impl CriticalTaskExecutionHandle { ...@@ -179,13 +195,28 @@ impl CriticalTaskExecutionHandle {
/// - `Err(...)` if the task failed or panicked, preserving the original error /// - `Err(...)` if the task failed or panicked, preserving the original error
/// ///
/// Note: Both errors and panics trigger parent cancellation immediately via the monitor task. /// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
pub async fn join(self) -> Result<()> { pub async fn join(mut self) -> Result<()> {
match self.result_receiver.await { self.detached = true;
let result = match self.result_receiver.take().unwrap().await {
Ok(task_result) => task_result, Ok(task_result) => task_result,
Err(_) => { Err(_) => {
// This should rarely happen - means monitor task was dropped/cancelled // This should rarely happen - means monitor task was dropped/cancelled
Err(anyhow::anyhow!("Critical task monitor was cancelled")) Err(anyhow::anyhow!("Critical task monitor was cancelled"))
} }
};
result
}
/// Detach the task. This allows the task to continue running after the handle is dropped.
///
/// Consumes the handle and marks it as detached. The `Drop` implementation
/// panics while `detached` is `false`, so setting the flag here lets `self`
/// be dropped quietly when this method returns.
pub fn detach(mut self) {
// Only the flag changes; `self` goes out of scope right after with `detached == true`.
self.detached = true;
}
}
impl Drop for CriticalTaskExecutionHandle {
fn drop(&mut self) {
if !self.detached {
panic!("Critical task was not detached prior to drop!");
} }
} }
} }
...@@ -218,7 +249,6 @@ mod tests { ...@@ -218,7 +249,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"test-success-task", "test-success-task",
) )
.await
.unwrap(); .unwrap();
// Task should complete successfully // Task should complete successfully
...@@ -245,7 +275,6 @@ mod tests { ...@@ -245,7 +275,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"test-failure-task", "test-failure-task",
) )
.await
.unwrap(); .unwrap();
// Task should fail and cancel parent token // Task should fail and cancel parent token
...@@ -284,7 +313,6 @@ mod tests { ...@@ -284,7 +313,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"test-panic-task", "test-panic-task",
) )
.await
.unwrap(); .unwrap();
// Panic should be caught and converted to error // Panic should be caught and converted to error
...@@ -328,7 +356,6 @@ mod tests { ...@@ -328,7 +356,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"test-graceful-shutdown", "test-graceful-shutdown",
) )
.await
.unwrap(); .unwrap();
// Let task do some work // Let task do some work
...@@ -381,7 +408,6 @@ mod tests { ...@@ -381,7 +408,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"long-running-task", "long-running-task",
) )
.await
.unwrap(); .unwrap();
let handle2 = CriticalTaskExecutionHandle::new( let handle2 = CriticalTaskExecutionHandle::new(
...@@ -393,7 +419,6 @@ mod tests { ...@@ -393,7 +419,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"failing-task", "failing-task",
) )
.await
.unwrap(); .unwrap();
// Wait for task 2 to fail // Wait for task 2 to fail
...@@ -432,7 +457,6 @@ mod tests { ...@@ -432,7 +457,6 @@ mod tests {
parent_token, parent_token,
"status-test-task", "status-test-task",
) )
.await
.unwrap(); .unwrap();
// Initially task should be running // Initially task should be running
...@@ -479,7 +503,6 @@ mod tests { ...@@ -479,7 +503,6 @@ mod tests {
parent_token, parent_token,
"select-pattern-task", "select-pattern-task",
) )
.await
.unwrap(); .unwrap();
// Cancel after a short time // Cancel after a short time
...@@ -511,7 +534,6 @@ mod tests { ...@@ -511,7 +534,6 @@ mod tests {
parent_token, parent_token,
"long-task", "long-task",
) )
.await
.unwrap(); .unwrap();
// Test with timeout // Test with timeout
...@@ -532,7 +554,7 @@ mod tests { ...@@ -532,7 +554,7 @@ mod tests {
// - Demonstrates true "critical task" behavior with immediate failure propagation // - Demonstrates true "critical task" behavior with immediate failure propagation
let parent_token = CancellationToken::new(); let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new( let handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move { |_cancel_token| async move {
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
panic!("Critical failure!"); panic!("Critical failure!");
...@@ -540,7 +562,6 @@ mod tests { ...@@ -540,7 +562,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"immediate-panic-task", "immediate-panic-task",
) )
.await
.unwrap(); .unwrap();
// Wait for the panic to be detected by monitor task // Wait for the panic to be detected by monitor task
...@@ -551,6 +572,7 @@ mod tests { ...@@ -551,6 +572,7 @@ mod tests {
parent_token.is_cancelled(), parent_token.is_cancelled(),
"Parent token should be cancelled immediately when critical task panics" "Parent token should be cancelled immediately when critical task panics"
); );
assert!(handle.join().await.is_err());
} }
#[tokio::test] #[tokio::test]
...@@ -563,7 +585,7 @@ mod tests { ...@@ -563,7 +585,7 @@ mod tests {
// - Demonstrates consistent critical failure behavior // - Demonstrates consistent critical failure behavior
let parent_token = CancellationToken::new(); let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new( let handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move { |_cancel_token| async move {
tokio::time::sleep(Duration::from_millis(50)).await; tokio::time::sleep(Duration::from_millis(50)).await;
anyhow::bail!("Critical error!"); anyhow::bail!("Critical error!");
...@@ -571,7 +593,6 @@ mod tests { ...@@ -571,7 +593,6 @@ mod tests {
parent_token.clone(), parent_token.clone(),
"immediate-error-task", "immediate-error-task",
) )
.await
.unwrap(); .unwrap();
// Don't call join() - just wait for the error to be detected // Don't call join() - just wait for the error to be detected
...@@ -582,5 +603,21 @@ mod tests { ...@@ -582,5 +603,21 @@ mod tests {
parent_token.is_cancelled(), parent_token.is_cancelled(),
"Parent token should be cancelled immediately when critical task errors" "Parent token should be cancelled immediately when critical task errors"
); );
assert!(handle.join().await.is_err());
}
#[tokio::test]
#[should_panic]
async fn test_task_detach() {
// Dropping without detaching should panic
let parent_token = CancellationToken::new();
let _handle = CriticalTaskExecutionHandle::new(
|_cancel_token| async move { Ok(()) },
parent_token,
"test-detach-task",
)
.unwrap();
// Dropping without detaching should panic
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment