Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
25c711f8
Unverified
Commit
25c711f8
authored
Jun 03, 2025
by
jthomson04
Committed by
GitHub
Jun 04, 2025
Browse files
feat: Integrate KVBM with `CriticalTaskHandle` (#1321)
parent
8deb3ea4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
269 additions
and
136 deletions
+269
-136
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+140
-91
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+71
-25
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+1
-0
lib/runtime/src/utils/task.rs
lib/runtime/src/utils/task.rs
+57
-20
No files found.
lib/llm/src/block_manager/offload.rs
View file @
25c711f8
...
@@ -56,6 +56,7 @@ use tokio::sync::{
...
@@ -56,6 +56,7 @@ use tokio::sync::{
mpsc
::{
self
,
error
::
TryRecvError
},
mpsc
::{
self
,
error
::
TryRecvError
},
Mutex
,
Mutex
,
};
};
use
tokio_util
::
sync
::
CancellationToken
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
std
::
any
::
Any
;
use
std
::
any
::
Any
;
...
@@ -70,6 +71,8 @@ use pending::{
...
@@ -70,6 +71,8 @@ use pending::{
};
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
use
dynamo_runtime
::
utils
::
task
::
CriticalTaskExecutionHandle
;
const
MAX_CONCURRENT_TRANSFERS
:
usize
=
4
;
const
MAX_CONCURRENT_TRANSFERS
:
usize
=
4
;
const
MAX_TRANSFER_BATCH_SIZE
:
usize
=
16
;
const
MAX_TRANSFER_BATCH_SIZE
:
usize
=
16
;
...
@@ -99,6 +102,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -99,6 +102,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
async_rt_handle
:
Handle
,
async_rt_handle
:
Handle
,
cancellation_token
:
CancellationToken
,
)
->
Result
<
Arc
<
Self
>>
{
)
->
Result
<
Arc
<
Self
>>
{
let
(
device_offload_tx
,
device_offload_rx
)
=
mpsc
::
unbounded_channel
();
let
(
device_offload_tx
,
device_offload_rx
)
=
mpsc
::
unbounded_channel
();
let
(
host_offload_tx
,
host_offload_rx
)
=
mpsc
::
unbounded_channel
();
let
(
host_offload_tx
,
host_offload_rx
)
=
mpsc
::
unbounded_channel
();
...
@@ -128,21 +132,29 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -128,21 +132,29 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
));
));
// Device -> Host offload
// Device -> Host offload
let
device_clone
=
this
.device
.clone
();
let
device_to_host_task
=
OffloadManager
::
offload_worker
(
let
host_clone
=
this
.host
.clone
();
this
.device
.clone
(),
async_rt_handle
.spawn
(
async
move
{
this
.host
.clone
(),
let
res
=
OffloadManager
::
offload_worker
(
device_offload_rx
,
device_clone
,
Arc
::
new
(
TransferBatcher
::
new
(
host_clone
,
CudaTransferManager
::
new
(
device_offload_rx
,
device_offload_transfer_ctx
,
Arc
::
new
(
TransferBatcher
::
new
(
MAX_CONCURRENT_TRANSFERS
,
CudaTransferManager
::
new
(
device_offload_transfer_ctx
,
MAX_CONCURRENT_TRANSFERS
),
cancellation_token
.clone
(),
MAX_TRANSFER_BATCH_SIZE
,
),
)),
MAX_TRANSFER_BATCH_SIZE
,
)
&
async_rt_handle
,
.await
;
cancellation_token
.clone
(),
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
)),
});
cancellation_token
.clone
(),
);
CriticalTaskExecutionHandle
::
new_with_runtime
(
|
_
|
device_to_host_task
,
cancellation_token
.clone
(),
"Device -> Host offload worker"
,
&
async_rt_handle
,
)
?
.detach
();
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
nixl_agent
.clone
(),
nixl_agent
.clone
(),
...
@@ -150,58 +162,81 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -150,58 +162,81 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
));
));
// Host -> Disk offload
// Host -> Disk offload
let
host_clone
=
this
.host
.clone
();
let
host_to_disk_task
=
OffloadManager
::
offload_worker
(
let
disk_clone
=
this
.disk
.clone
();
this
.host
.clone
(),
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
this
.disk
.clone
(),
async_rt_handle
.spawn
(
async
move
{
host_offload_rx
,
let
res
=
OffloadManager
::
offload_worker
(
Arc
::
new
(
TransferBatcher
::
new
(
host_clone
,
DiskTransferManager
::
new
(
disk_clone
,
transfer_ctx
.clone
(),
host_offload_rx
,
MAX_CONCURRENT_TRANSFERS
,
Arc
::
new
(
TransferBatcher
::
new
(
&
async_rt_handle
,
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
cancellation_token
.clone
(),
MAX_TRANSFER_BATCH_SIZE
,
),
)),
MAX_TRANSFER_BATCH_SIZE
,
)
&
async_rt_handle
,
.await
;
cancellation_token
.clone
(),
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
)),
});
cancellation_token
.clone
(),
);
CriticalTaskExecutionHandle
::
new_with_runtime
(
|
_
|
host_to_disk_task
,
cancellation_token
.clone
(),
"Host -> Disk offload worker"
,
&
async_rt_handle
,
)
?
.detach
();
// Host -> Device onboarding
// Host -> Device onboarding
let
host_clone
=
this
.host
.clone
();
let
host_to_device_task
=
OffloadManager
::
onboard_worker
(
let
device_clone
=
this
.device
.clone
();
this
.host
.clone
(),
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
this
.device
.clone
(),
async_rt_handle
.spawn
(
async
move
{
host_onboard_rx
,
let
res
=
OffloadManager
::
onboard_worker
(
Arc
::
new
(
TransferBatcher
::
new
(
host_clone
,
CudaTransferManager
::
new
(
device_clone
,
transfer_ctx
.clone
(),
host_onboard_rx
,
MAX_CONCURRENT_TRANSFERS
,
Arc
::
new
(
TransferBatcher
::
new
(
cancellation_token
.clone
(),
CudaTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
),
MAX_TRANSFER_BATCH_SIZE
,
MAX_TRANSFER_BATCH_SIZE
,
)),
&
async_rt_handle
,
)
cancellation_token
.clone
(),
.await
;
)),
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
cancellation_token
.clone
(),
});
);
CriticalTaskExecutionHandle
::
new_with_runtime
(
|
_
|
host_to_device_task
,
cancellation_token
.clone
(),
"Host -> Device onboarding worker"
,
&
async_rt_handle
,
)
?
.detach
();
// Disk -> Device onboarding
// Disk -> Device onboarding
let
disk_clone
=
this
.disk
.clone
();
let
disk_to_device_task
=
OffloadManager
::
onboard_worker
(
let
device_clone
=
this
.device
.clone
();
this
.disk
.clone
(),
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
this
.device
.clone
(),
async_rt_handle
.spawn
(
async
move
{
disk_onboard_rx
,
let
res
=
OffloadManager
::
onboard_worker
(
Arc
::
new
(
TransferBatcher
::
new
(
disk_clone
,
DiskTransferManager
::
new
(
device_clone
,
transfer_ctx
.clone
(),
disk_onboard_rx
,
MAX_CONCURRENT_TRANSFERS
,
Arc
::
new
(
TransferBatcher
::
new
(
&
async_rt_handle
,
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
cancellation_token
.clone
(),
MAX_TRANSFER_BATCH_SIZE
,
),
)),
MAX_TRANSFER_BATCH_SIZE
,
)
&
async_rt_handle
,
.await
;
cancellation_token
.clone
(),
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
)),
});
cancellation_token
.clone
(),
);
CriticalTaskExecutionHandle
::
new_with_runtime
(
|
_
|
disk_to_device_task
,
cancellation_token
.clone
(),
"Disk -> Device onboarding worker"
,
&
async_rt_handle
,
)
?
.detach
();
Ok
(
this_clone
)
Ok
(
this_clone
)
}
}
...
@@ -211,6 +246,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -211,6 +246,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
cancellation_token
:
CancellationToken
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
...
@@ -222,6 +258,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -222,6 +258,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
mut
queue
=
BTreeSet
::
new
();
let
mut
queue
=
BTreeSet
::
new
();
loop
{
loop
{
if
cancellation_token
.is_cancelled
()
{
return
Ok
(());
}
// Try to check the offload queue.
// Try to check the offload queue.
loop
{
loop
{
match
offload_rx
.try_recv
()
{
match
offload_rx
.try_recv
()
{
...
@@ -231,7 +271,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -231,7 +271,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
Err
(
TryRecvError
::
Empty
)
=>
{
Err
(
TryRecvError
::
Empty
)
=>
{
break
;
break
;
}
}
Err
(
_
)
=>
return
Ok
(
()),
Err
(
e
)
=>
return
Err
(
e
.into
()),
}
}
}
}
...
@@ -280,8 +320,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -280,8 +320,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
}
else
{
}
else
{
// Await the next request.
// Await the next request.
if
let
Some
(
request
)
=
offload_rx
.recv
()
.await
{
tokio
::
select!
{
queue
.insert
(
request
);
_
=
cancellation_token
.cancelled
()
=>
return
Ok
(()),
Some
(
request
)
=
offload_rx
.recv
()
=>
{
queue
.insert
(
request
);
}
}
}
}
}
}
}
...
@@ -292,40 +335,45 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -292,40 +335,45 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
cancellation_token
:
CancellationToken
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
}
}
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
loop
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
return
Ok
::
<
(),
anyhow
::
Error
>
(()),
Some
(
request
)
=
onboard_rx
.recv
()
=>
{
// Try to allocate blocks on the device.
let
target_blocks
=
match
target_pool
.allocate_blocks
(
request
.blocks
.len
())
.await
{
Ok
(
blocks
)
=>
blocks
,
Err
(
err
)
=>
{
request
.response_tx
.send
(
Err
(
err
))
?
;
continue
;
}
};
// Loop on incoming requests
let
sources
=
request
while
let
Some
(
request
)
=
onboard_rx
.recv
()
.await
{
.blocks
// Try to allocate blocks on the device.
.iter
()
let
target_blocks
=
match
target_pool
.allocate_blocks
(
request
.blocks
.len
())
.await
{
.map
(|
b
|
b
.mutable_block
()
.clone
())
Ok
(
blocks
)
=>
blocks
,
.collect
();
Err
(
err
)
=>
{
request
.response_tx
.send
(
Err
(
err
))
?
;
transfer_manager
continue
;
.enqueue_transfer
(
PendingTransfer
::
new
(
sources
,
target_blocks
,
Some
(
request
.response_tx
),
target_pool
.clone
(),
))
.await
?
;
Ok
::
<
(),
anyhow
::
Error
>
(())
}
}
};
}
?
;
let
sources
=
request
.blocks
.iter
()
.map
(|
b
|
b
.mutable_block
()
.clone
())
.collect
();
transfer_manager
.enqueue_transfer
(
PendingTransfer
::
new
(
sources
,
target_blocks
,
Some
(
request
.response_tx
),
target_pool
.clone
(),
))
.await
?
;
}
}
Ok
(())
}
}
pub
async
fn
offload
<
S
:
Storage
>
(
pub
async
fn
offload
<
S
:
Storage
>
(
...
@@ -568,6 +616,7 @@ mod tests {
...
@@ -568,6 +616,7 @@ mod tests {
device_pool
.clone
(),
device_pool
.clone
(),
agent_arc
,
agent_arc
,
async_rt_handle
,
async_rt_handle
,
CancellationToken
::
new
(),
)
?
;
)
?
;
Ok
((
manager
,
device_pool
,
host_pool
,
disk_pool
))
Ok
((
manager
,
device_pool
,
host_pool
,
disk_pool
))
...
...
lib/llm/src/block_manager/offload/pending.rs
View file @
25c711f8
...
@@ -42,7 +42,9 @@ use std::marker::PhantomData;
...
@@ -42,7 +42,9 @@ use std::marker::PhantomData;
use
std
::
pin
::
Pin
;
use
std
::
pin
::
Pin
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
std
::
thread
::
spawn
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::
mpsc
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
crate
::
block_manager
::
block
::{
use
crate
::
block_manager
::
block
::{
transfer
::{
WriteTo
,
WriteToStrategy
},
transfer
::{
WriteTo
,
WriteToStrategy
},
...
@@ -60,6 +62,8 @@ use futures::{stream::FuturesUnordered, StreamExt};
...
@@ -60,6 +62,8 @@ use futures::{stream::FuturesUnordered, StreamExt};
use
super
::
BlockResult
;
use
super
::
BlockResult
;
use
dynamo_runtime
::
utils
::
task
::
CriticalTaskExecutionHandle
;
/// Manage a set of pending transfers.
/// Manage a set of pending transfers.
pub
struct
PendingTransfer
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
pub
struct
PendingTransfer
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
/// The block being copied from.
/// The block being copied from.
...
@@ -153,7 +157,11 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
...
@@ -153,7 +157,11 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
{
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
max_concurrent_transfers
,
max_concurrent_transfers
,
);
);
...
@@ -171,6 +179,12 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...
@@ -171,6 +179,12 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
}
}
// Flush any remaining transfers.
if
cancellation_token
.is_cancelled
()
{
while
rx
.blocking_recv
()
.is_some
()
{}
break
;
}
}
}
Ok
::
<
(),
anyhow
::
Error
>
(())
Ok
::
<
(),
anyhow
::
Error
>
(())
});
});
...
@@ -228,16 +242,28 @@ pub struct DiskTransferManager {
...
@@ -228,16 +242,28 @@ pub struct DiskTransferManager {
}
}
impl
DiskTransferManager
{
impl
DiskTransferManager
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
tokio
::
spawn
(
async
move
{
runtime
.
spawn
(
async
move
{
// Keep track of our pending transfers.
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
// Consume the futures as they complete, while also receiving new ones.
let
mut
pending_transfers
=
FuturesUnordered
::
new
();
let
mut
pending_transfers
=
FuturesUnordered
::
new
();
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
// Flush remaining transfers.
while
(
pending_transfers
.next
()
.await
)
.is_some
()
{}
return
;
}
Some
(
future
)
=
futures_rx
.recv
()
=>
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
// If we're at max size, block the worker thread on the next() call until we have capacity.
// If we're at max size, block the worker thread on the next() call until we have capacity.
while
pending_transfers
.len
()
>=
max_concurrent_transfers
{
while
pending_transfers
.len
()
>=
max_concurrent_transfers
{
...
@@ -249,10 +275,6 @@ impl DiskTransferManager {
...
@@ -249,10 +275,6 @@ impl DiskTransferManager {
Some
(
_
)
=
pending_transfers
.next
(),
if
!
pending_transfers
.is_empty
()
=>
{
Some
(
_
)
=
pending_transfers
.next
(),
if
!
pending_transfers
.is_empty
()
=>
{
// A transfer completed, just continue to process more
// A transfer completed, just continue to process more
}
}
else
=>
{
// Both branches are pending, wait for one to become ready
tokio
::
task
::
yield_now
()
.await
;
}
}
}
}
}
});
});
...
@@ -317,6 +339,8 @@ where
...
@@ -317,6 +339,8 @@ where
{
{
transfer_manager
:
Manager
,
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
,
max_transfer_batch_size
:
usize
,
runtime
:
Handle
,
cancellation_token
:
CancellationToken
,
_
phantom
:
PhantomData
<
(
Source
,
Target
,
Metadata
)
>
,
_
phantom
:
PhantomData
<
(
Source
,
Target
,
Metadata
)
>
,
}
}
...
@@ -327,10 +351,17 @@ where
...
@@ -327,10 +351,17 @@ where
Metadata
:
BlockMetadata
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
{
pub
fn
new
(
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
)
->
Self
{
pub
fn
new
(
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
Self
{
Self
{
transfer_manager
,
transfer_manager
,
max_transfer_batch_size
,
max_transfer_batch_size
,
runtime
:
runtime
.clone
(),
cancellation_token
,
_
phantom
:
PhantomData
,
_
phantom
:
PhantomData
,
}
}
}
}
...
@@ -391,25 +422,40 @@ where
...
@@ -391,25 +422,40 @@ where
}
}
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
tokio
::
spawn
(
async
move
{
CriticalTaskExecutionHandle
::
new_with_runtime
(
let
mut
results
=
Vec
::
new
();
move
|
cancel_token
|
async
move
{
let
mut
results
=
Vec
::
new
();
for
indicator
in
indicators
.into_iter
()
{
// Await each sub-transfer, and append the results to our final results.
for
indicator
in
indicators
.into_iter
()
{
let
result
=
match
indicator
.await
.unwrap
()
{
// Await each sub-transfer, and append the results to our final results.
Ok
(
result
)
=>
result
,
tokio
::
select!
{
Err
(
e
)
=>
{
_
=
cancel_token
.cancelled
()
=>
{
tracing
::
error!
(
"Error receiving transfer results: {:?}"
,
e
);
return
Ok
(());
completion_indicator
.send
(
Err
(
e
))
.unwrap
();
}
return
;
Ok
(
indicator
)
=
indicator
=>
{
let
result
=
match
indicator
{
Ok
(
result
)
=>
result
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Error receiving transfer results: {:?}"
,
e
);
completion_indicator
.send
(
Err
(
e
))
.unwrap
();
return
Ok
(());
}
};
results
.extend
(
result
);
}
}
}
};
}
results
.extend
(
result
);
}
// Send the final results to the top-level completion indicator.
completion_indicator
.send
(
Ok
(
results
))
?
;
// Send the final results to the top-level completion indicator.
Ok
(())
completion_indicator
.send
(
Ok
(
results
))
.unwrap
();
},
});
self
.cancellation_token
.clone
(),
"Transfer Batcher"
,
&
self
.runtime
,
)
?
.detach
();
}
}
Ok
(())
Ok
(())
...
...
lib/llm/src/block_manager/state.rs
View file @
25c711f8
...
@@ -212,6 +212,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -212,6 +212,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool
.clone
(),
device_pool
.clone
(),
nixl_agent
.clone
(),
nixl_agent
.clone
(),
async_rt_handle
,
async_rt_handle
,
cancellation_token
.clone
(),
)
?
;
)
?
;
let
state
=
Arc
::
new
(
Self
{
let
state
=
Arc
::
new
(
Self
{
...
...
lib/runtime/src/utils/task.rs
View file @
25c711f8
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
use
anyhow
::{
Context
,
Result
};
use
anyhow
::{
Context
,
Result
};
use
std
::
future
::
Future
;
use
std
::
future
::
Future
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::
oneshot
;
use
tokio
::
sync
::
oneshot
;
use
tokio
::
task
::
JoinHandle
;
use
tokio
::
task
::
JoinHandle
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
...
@@ -41,20 +42,34 @@ pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send
...
@@ -41,20 +42,34 @@ pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send
pub
struct
CriticalTaskExecutionHandle
{
pub
struct
CriticalTaskExecutionHandle
{
monitor_task
:
JoinHandle
<
()
>
,
monitor_task
:
JoinHandle
<
()
>
,
graceful_shutdown_token
:
CancellationToken
,
graceful_shutdown_token
:
CancellationToken
,
result_receiver
:
oneshot
::
Receiver
<
Result
<
()
>>
,
result_receiver
:
Option
<
oneshot
::
Receiver
<
Result
<
()
>>>
,
detached
:
bool
,
}
}
impl
CriticalTaskExecutionHandle
{
impl
CriticalTaskExecutionHandle
{
pub
fn
new
<
Fut
>
(
task_fn
:
impl
FnOnce
(
CancellationToken
)
->
Fut
+
Send
+
'static
,
parent_token
:
CancellationToken
,
description
:
&
str
,
)
->
Result
<
Self
>
where
Fut
:
Future
<
Output
=
Result
<
()
>>
+
Send
+
'static
,
{
Self
::
new_with_runtime
(
task_fn
,
parent_token
,
description
,
&
Handle
::
try_current
()
?
)
}
/// Create a new [CriticalTaskExecutionHandle] for a critical task.
/// Create a new [CriticalTaskExecutionHandle] for a critical task.
///
///
/// # Arguments
/// # Arguments
/// * `task_fn` - A function that takes a cancellation token and returns the critical task future
/// * `task_fn` - A function that takes a cancellation token and returns the critical task future
/// * `parent_token` - Token that will be cancelled if this critical task fails
/// * `parent_token` - Token that will be cancelled if this critical task fails
/// * `description` - Description for logging purposes
/// * `description` - Description for logging purposes
pub
async
fn
new
<
Fut
>
(
/// * `runtime` - The runtime to use for the task.
pub
fn
new_with_runtime
<
Fut
>
(
task_fn
:
impl
FnOnce
(
CancellationToken
)
->
Fut
+
Send
+
'static
,
task_fn
:
impl
FnOnce
(
CancellationToken
)
->
Fut
+
Send
+
'static
,
parent_token
:
CancellationToken
,
parent_token
:
CancellationToken
,
description
:
&
str
,
description
:
&
str
,
runtime
:
&
Handle
,
)
->
Result
<
Self
>
)
->
Result
<
Self
>
where
where
Fut
:
Future
<
Output
=
Result
<
()
>>
+
Send
+
'static
,
Fut
:
Future
<
Output
=
Result
<
()
>>
+
Send
+
'static
,
...
@@ -68,7 +83,7 @@ impl CriticalTaskExecutionHandle {
...
@@ -68,7 +83,7 @@ impl CriticalTaskExecutionHandle {
let
graceful_shutdown_token_clone
=
graceful_shutdown_token
.clone
();
let
graceful_shutdown_token_clone
=
graceful_shutdown_token
.clone
();
let
description_clone
=
description
.to_string
();
let
description_clone
=
description
.to_string
();
let
task
=
tokio
::
spawn
(
async
move
{
let
task
=
runtime
.
spawn
(
async
move
{
let
future
=
task_fn
(
graceful_shutdown_token_clone
);
let
future
=
task_fn
(
graceful_shutdown_token_clone
);
match
future
.await
{
match
future
.await
{
...
@@ -92,7 +107,7 @@ impl CriticalTaskExecutionHandle {
...
@@ -92,7 +107,7 @@ impl CriticalTaskExecutionHandle {
let
parent_token_monitor
=
parent_token_clone
.clone
();
let
parent_token_monitor
=
parent_token_clone
.clone
();
let
description_monitor
=
description
.clone
();
let
description_monitor
=
description
.clone
();
tokio
::
spawn
(
async
move
{
runtime
.
spawn
(
async
move
{
let
result
=
match
main_task_handle
.await
{
let
result
=
match
main_task_handle
.await
{
Ok
(
task_result
)
=>
{
Ok
(
task_result
)
=>
{
// Task completed normally (success or error)
// Task completed normally (success or error)
...
@@ -147,7 +162,8 @@ impl CriticalTaskExecutionHandle {
...
@@ -147,7 +162,8 @@ impl CriticalTaskExecutionHandle {
Ok
(
Self
{
Ok
(
Self
{
monitor_task
,
monitor_task
,
graceful_shutdown_token
,
graceful_shutdown_token
,
result_receiver
,
result_receiver
:
Some
(
result_receiver
),
detached
:
false
,
})
})
}
}
...
@@ -179,13 +195,28 @@ impl CriticalTaskExecutionHandle {
...
@@ -179,13 +195,28 @@ impl CriticalTaskExecutionHandle {
/// - `Err(...)` if the task failed or panicked, preserving the original error
/// - `Err(...)` if the task failed or panicked, preserving the original error
///
///
/// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
/// Note: Both errors and panics trigger parent cancellation immediately via the monitor task.
pub
async
fn
join
(
self
)
->
Result
<
()
>
{
pub
async
fn
join
(
mut
self
)
->
Result
<
()
>
{
match
self
.result_receiver
.await
{
self
.detached
=
true
;
let
result
=
match
self
.result_receiver
.take
()
.unwrap
()
.await
{
Ok
(
task_result
)
=>
task_result
,
Ok
(
task_result
)
=>
task_result
,
Err
(
_
)
=>
{
Err
(
_
)
=>
{
// This should rarely happen - means monitor task was dropped/cancelled
// This should rarely happen - means monitor task was dropped/cancelled
Err
(
anyhow
::
anyhow!
(
"Critical task monitor was cancelled"
))
Err
(
anyhow
::
anyhow!
(
"Critical task monitor was cancelled"
))
}
}
};
result
}
/// Detach the task. This allows the task to continue running after the handle is dropped.
pub
fn
detach
(
mut
self
)
{
self
.detached
=
true
;
}
}
impl
Drop
for
CriticalTaskExecutionHandle
{
fn
drop
(
&
mut
self
)
{
if
!
self
.detached
{
panic!
(
"Critical task was not detached prior to drop!"
);
}
}
}
}
}
}
...
@@ -218,7 +249,6 @@ mod tests {
...
@@ -218,7 +249,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"test-success-task"
,
"test-success-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Task should complete successfully
// Task should complete successfully
...
@@ -245,7 +275,6 @@ mod tests {
...
@@ -245,7 +275,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"test-failure-task"
,
"test-failure-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Task should fail and cancel parent token
// Task should fail and cancel parent token
...
@@ -284,7 +313,6 @@ mod tests {
...
@@ -284,7 +313,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"test-panic-task"
,
"test-panic-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Panic should be caught and converted to error
// Panic should be caught and converted to error
...
@@ -328,7 +356,6 @@ mod tests {
...
@@ -328,7 +356,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"test-graceful-shutdown"
,
"test-graceful-shutdown"
,
)
)
.await
.unwrap
();
.unwrap
();
// Let task do some work
// Let task do some work
...
@@ -381,7 +408,6 @@ mod tests {
...
@@ -381,7 +408,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"long-running-task"
,
"long-running-task"
,
)
)
.await
.unwrap
();
.unwrap
();
let
handle2
=
CriticalTaskExecutionHandle
::
new
(
let
handle2
=
CriticalTaskExecutionHandle
::
new
(
...
@@ -393,7 +419,6 @@ mod tests {
...
@@ -393,7 +419,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"failing-task"
,
"failing-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Wait for task 2 to fail
// Wait for task 2 to fail
...
@@ -432,7 +457,6 @@ mod tests {
...
@@ -432,7 +457,6 @@ mod tests {
parent_token
,
parent_token
,
"status-test-task"
,
"status-test-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Initially task should be running
// Initially task should be running
...
@@ -479,7 +503,6 @@ mod tests {
...
@@ -479,7 +503,6 @@ mod tests {
parent_token
,
parent_token
,
"select-pattern-task"
,
"select-pattern-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Cancel after a short time
// Cancel after a short time
...
@@ -511,7 +534,6 @@ mod tests {
...
@@ -511,7 +534,6 @@ mod tests {
parent_token
,
parent_token
,
"long-task"
,
"long-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Test with timeout
// Test with timeout
...
@@ -532,7 +554,7 @@ mod tests {
...
@@ -532,7 +554,7 @@ mod tests {
// - Demonstrates true "critical task" behavior with immediate failure propagation
// - Demonstrates true "critical task" behavior with immediate failure propagation
let
parent_token
=
CancellationToken
::
new
();
let
parent_token
=
CancellationToken
::
new
();
let
_
handle
=
CriticalTaskExecutionHandle
::
new
(
let
handle
=
CriticalTaskExecutionHandle
::
new
(
|
_
cancel_token
|
async
move
{
|
_
cancel_token
|
async
move
{
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
50
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
50
))
.await
;
panic!
(
"Critical failure!"
);
panic!
(
"Critical failure!"
);
...
@@ -540,7 +562,6 @@ mod tests {
...
@@ -540,7 +562,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"immediate-panic-task"
,
"immediate-panic-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Wait for the panic to be detected by monitor task
// Wait for the panic to be detected by monitor task
...
@@ -551,6 +572,7 @@ mod tests {
...
@@ -551,6 +572,7 @@ mod tests {
parent_token
.is_cancelled
(),
parent_token
.is_cancelled
(),
"Parent token should be cancelled immediately when critical task panics"
"Parent token should be cancelled immediately when critical task panics"
);
);
assert
!
(
handle
.join
()
.await
.is_err
());
}
}
#[tokio::test]
#[tokio::test]
...
@@ -563,7 +585,7 @@ mod tests {
...
@@ -563,7 +585,7 @@ mod tests {
// - Demonstrates consistent critical failure behavior
// - Demonstrates consistent critical failure behavior
let
parent_token
=
CancellationToken
::
new
();
let
parent_token
=
CancellationToken
::
new
();
let
_
handle
=
CriticalTaskExecutionHandle
::
new
(
let
handle
=
CriticalTaskExecutionHandle
::
new
(
|
_
cancel_token
|
async
move
{
|
_
cancel_token
|
async
move
{
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
50
))
.await
;
tokio
::
time
::
sleep
(
Duration
::
from_millis
(
50
))
.await
;
anyhow
::
bail!
(
"Critical error!"
);
anyhow
::
bail!
(
"Critical error!"
);
...
@@ -571,7 +593,6 @@ mod tests {
...
@@ -571,7 +593,6 @@ mod tests {
parent_token
.clone
(),
parent_token
.clone
(),
"immediate-error-task"
,
"immediate-error-task"
,
)
)
.await
.unwrap
();
.unwrap
();
// Don't call join() - just wait for the error to be detected
// Don't call join() - just wait for the error to be detected
...
@@ -582,5 +603,21 @@ mod tests {
...
@@ -582,5 +603,21 @@ mod tests {
parent_token
.is_cancelled
(),
parent_token
.is_cancelled
(),
"Parent token should be cancelled immediately when critical task errors"
"Parent token should be cancelled immediately when critical task errors"
);
);
assert
!
(
handle
.join
()
.await
.is_err
());
}
#[tokio::test]
#[should_panic]
async
fn
test_task_detach
()
{
// Dropping without detaching should panic
let
parent_token
=
CancellationToken
::
new
();
let
_
handle
=
CriticalTaskExecutionHandle
::
new
(
|
_
cancel_token
|
async
move
{
Ok
(())
},
parent_token
,
"test-detach-task"
,
)
.unwrap
();
// Dropping without detaching should panic
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment