Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
312ee8e2
Unverified
Commit
312ee8e2
authored
Jun 09, 2025
by
jthomson04
Committed by
GitHub
Jun 09, 2025
Browse files
feat: Restructure the KVBM WriteTo trait (#1363)
parent
3d499705
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
292 additions
and
178 deletions
+292
-178
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+1
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+37
-48
lib/llm/src/block_manager/block/transfer/context.rs
lib/llm/src/block_manager/block/transfer/context.rs
+116
-0
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+25
-26
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+11
-6
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+102
-78
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+0
-20
No files found.
lib/llm/src/block_manager/block.rs
View file @
312ee8e2
...
...
@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor;
pub
use
registry
::{
GlobalRegistry
,
RegistrationHandle
};
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
pub
use
transfer
::
TransferContext
;
use
crate
::
block_manager
::{
state
::
KvBlockManagerState
as
BlockManager
,
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
312ee8e2
...
...
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod
context
;
mod
cuda
;
mod
memcpy
;
mod
nixl
;
...
...
@@ -29,12 +30,12 @@ use crate::block_manager::storage::{
use
cudarc
::
driver
::
CudaStream
;
use
nixl_sys
::
XferOp
::{
Read
,
Write
};
use
std
::
future
::
Future
;
use
std
::
ops
::
Range
;
use
tokio
::
sync
::
oneshot
;
pub
use
crate
::
block_manager
::
state
::
TransferContext
;
pub
use
crate
::
block_manager
::
storage
::{
CudaAccessible
,
Local
,
Remote
};
pub
use
async_trait
::
async_trait
;
pub
use
context
::
TransferContext
;
/// A block that can be the target of a write
pub
trait
Writable
{}
...
...
@@ -149,19 +150,9 @@ pub trait WriteTo<Target> {
fn
write_to
(
&
self
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
bool
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
;
/// A write_to implementation that expects a NIXL transfer.
/// If the transfer strategy is not NIXL, this method will return an error.
/// Returns a future that will complete when the transfer is complete.
fn
nixl_write_to
(
&
self
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
)
->
Result
<
Option
<
oneshot
::
Receiver
<
()
>>
,
TransferError
>
;
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
Vec
<
Arc
<
RB
>>
...
...
@@ -171,15 +162,25 @@ where
fn
write_to
(
&
self
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
bool
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
{
)
->
Result
<
Option
<
oneshot
::
Receiver
<
()
>>
,
TransferError
>
{
let
(
tx
,
rx
)
=
oneshot
::
channel
();
match
RB
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
{
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
// TODO: Unlike all other transfer strategies, this is fully blocking.
// We probably want some sort of thread pool to handle these.
memcpy
::
copy_block
(
src
.as_ref
(),
dst
)
?
;
}
Ok
(())
if
notify
{
tx
.send
(())
.unwrap
();
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
TransferStrategy
::
CudaAsyncH2D
|
TransferStrategy
::
CudaAsyncD2H
...
...
@@ -192,17 +193,27 @@ where
RB
::
write_to_strategy
(),
)
?
;
}
Ok
(())
if
notify
{
let
(
tx
,
rx
)
=
oneshot
::
channel
();
ctx
.cuda_event
(
tx
)
?
;
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
TransferStrategy
::
Nixl
(
transfer_type
)
=>
{
std
::
mem
::
drop
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
,
transfer_type
,
)
?
);
Ok
(())
let
transfer_fut
=
nixl
::
write_blocks_to
(
self
,
dst
,
&
ctx
,
transfer_type
)
?
;
if
notify
{
ctx
.async_rt_handle
()
.spawn
(
async
move
{
transfer_fut
.await
;
tx
.send
(())
.unwrap
();
});
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Unsupported copy strategy: {:?}"
,
...
...
@@ -210,28 +221,6 @@ where
))),
}
}
fn
nixl_write_to
(
&
self
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
if
let
TransferStrategy
::
Nixl
(
transfer_type
)
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
,
transfer_type
,
)
?
)
}
else
{
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Expected NIXL transfer strategy, got: {:?}"
,
RB
::
write_to_strategy
()
)))
?
}
}
}
#[derive(Default)]
...
...
lib/llm/src/block_manager/block/transfer/context.rs
0 → 100644
View file @
312ee8e2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
,
CudaStream
};
use
nixl_sys
::
Agent
as
NixlAgent
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
JoinHandle
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
use
tokio_util
::
sync
::
CancellationToken
;
pub
struct
TransferContext
{
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
async_rt_handle
:
Handle
,
cuda_event_tx
:
mpsc
::
UnboundedSender
<
(
CudaEvent
,
oneshot
::
Sender
<
()
>
)
>
,
cuda_event_worker
:
Option
<
JoinHandle
<
()
>>
,
cancel_token
:
CancellationToken
,
}
impl
TransferContext
{
pub
fn
new
(
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
async_rt_handle
:
Handle
,
)
->
Self
{
let
(
cuda_event_tx
,
mut
cuda_event_rx
)
=
mpsc
::
unbounded_channel
::
<
(
CudaEvent
,
oneshot
::
Sender
<
()
>
)
>
();
let
cancel_token
=
CancellationToken
::
new
();
let
cancel_token_clone
=
cancel_token
.clone
();
let
cuda_event_worker
=
std
::
thread
::
spawn
(
move
||
{
let
runtime
=
tokio
::
runtime
::
Builder
::
new_current_thread
()
.enable_all
()
.build
()
.expect
(
"Failed to build Tokio runtime for CUDA event worker."
);
runtime
.block_on
(
async
move
{
loop
{
tokio
::
select!
{
Some
((
event
,
tx
))
=
cuda_event_rx
.recv
()
=>
{
if
let
Err
(
e
)
=
event
.synchronize
()
{
tracing
::
error!
(
"Error synchronizing CUDA event: {}"
,
e
);
}
let
_
=
tx
.send
(());
}
_
=
cancel_token_clone
.cancelled
()
=>
{
break
;
}
}
}
});
});
Self
{
nixl_agent
,
stream
,
async_rt_handle
,
cuda_event_tx
,
cuda_event_worker
:
Some
(
cuda_event_worker
),
cancel_token
,
}
}
pub
fn
nixl_agent
(
&
self
)
->
Arc
<
Option
<
NixlAgent
>>
{
self
.nixl_agent
.clone
()
}
pub
fn
stream
(
&
self
)
->
&
Arc
<
CudaStream
>
{
&
self
.stream
}
pub
fn
async_rt_handle
(
&
self
)
->
&
Handle
{
&
self
.async_rt_handle
}
pub
fn
cuda_event
(
&
self
,
tx
:
oneshot
::
Sender
<
()
>
)
->
Result
<
(),
TransferError
>
{
let
event
=
self
.stream
.record_event
(
Some
(
CUevent_flags
::
CU_EVENT_BLOCKING_SYNC
))
.map_err
(|
e
|
TransferError
::
ExecutionError
(
e
.to_string
()))
?
;
self
.cuda_event_tx
.send
((
event
,
tx
))
.map_err
(|
_
|
TransferError
::
ExecutionError
(
"CUDA event worker exited."
.into
()))
?
;
Ok
(())
}
}
impl
Drop
for
TransferContext
{
fn
drop
(
&
mut
self
)
{
self
.cancel_token
.cancel
();
if
let
Some
(
handle
)
=
self
.cuda_event_worker
.take
()
{
if
let
Err
(
e
)
=
handle
.join
()
{
tracing
::
error!
(
"Error joining CUDA event worker: {:?}"
,
e
);
}
}
}
}
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
312ee8e2
...
...
@@ -16,9 +16,8 @@
use
super
::
*
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
task
::
Poll
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
XferDescList
};
use
std
::
future
::
Future
;
fn
append_xfer_request
<
Source
,
Destination
>
(
src
:
&
Arc
<
Source
>
,
...
...
@@ -87,8 +86,7 @@ where
pub
fn
write_blocks_to
<
Source
,
Destination
>
(
src
:
&
[
Arc
<
Source
>
],
dst
:
&
mut
[
Destination
],
ctx
:
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
ctx
:
&
Arc
<
TransferContext
>
,
transfer_type
:
NixlTransfer
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
...
...
@@ -136,26 +134,27 @@ where
None
,
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
(
)
?
;
let
still_pending
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
None
)
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
if
still_pending
{
Ok
(
Box
::
new
(
Box
::
pin
(
async
move
{
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
}
else
{
Poll
::
Pending
loop
{
match
nixl_agent
.get_xfer_status
(
&
xfer_req
)
{
Ok
(
false
)
=>
break
,
// Transfer is complete.
Ok
(
true
)
=>
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
5
))
.await
,
// Transfer is still in progress.
Err
(
e
)
=>
{
tracing
::
error!
(
"Error getting transfer status: {}"
,
e
);
break
;
}
}
}
})))
}
else
{
Ok
(
Box
::
new
(
std
::
future
::
ready
(())))
}
}
lib/llm/src/block_manager/offload.rs
View file @
312ee8e2
...
...
@@ -44,9 +44,8 @@
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
use
super
::
block
::{
BlockError
,
BlockMetadata
,
BlockState
,
ImmutableBlock
};
use
super
::
block
::{
BlockError
,
BlockMetadata
,
BlockState
,
ImmutableBlock
,
TransferContext
};
use
super
::
pool
::
BlockPoolError
;
use
super
::
state
::
TransferContext
;
use
super
::
storage
::{
Cuda
,
Storage
};
use
super
::{
BlockPool
,
DeviceStorage
,
DiskStorage
,
PinnedStorage
};
use
nixl_sys
::
Agent
as
NixlAgent
;
...
...
@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_offload_transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
nixl_agent
.clone
(),
cuda_ctx
.new_stream
()
?
,
async_rt_handle
.clone
(),
));
// Device -> Host offload
...
...
@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager
::
new
(
device_offload_transfer_ctx
,
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
cancellation_token
.clone
(),
...
...
@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
nixl_agent
.clone
(),
cuda_ctx
.new_stream
()
?
,
async_rt_handle
.clone
(),
));
// Host -> Disk offload
...
...
@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
cancellation_token
.clone
(),
...
...
@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager
::
new
(
transfer_ctx
.clone
(),
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
cancellation_token
.clone
(),
...
...
@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
cancellation_token
.clone
(),
...
...
@@ -549,8 +552,10 @@ mod tests {
let
agent
=
NixlAgent
::
new
(
"offload-manager"
)
.unwrap
();
let
(
_
,
ucx_params
)
=
agent
.get_plugin_params
(
"UCX"
)
.unwrap
();
let
(
_
,
gds_params
)
=
agent
.get_plugin_params
(
"GDS"
)
.unwrap
();
let
(
_
,
posix_params
)
=
agent
.get_plugin_params
(
"POSIX"
)
.unwrap
();
agent
.create_backend
(
"UCX"
,
&
ucx_params
)
.unwrap
();
agent
.create_backend
(
"GDS"
,
&
gds_params
)
.unwrap
();
agent
.create_backend
(
"POSIX"
,
&
posix_params
)
.unwrap
();
Arc
::
new
(
Some
(
agent
))
};
}
...
...
lib/llm/src/block_manager/offload/pending.rs
View file @
312ee8e2
...
...
@@ -41,23 +41,21 @@
use
std
::
marker
::
PhantomData
;
use
std
::
pin
::
Pin
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
crate
::
block_manager
::
block
::{
transfer
::{
WriteTo
,
WriteToStrategy
},
BlockError
,
BlockExt
,
BlockMetadata
,
BlockState
,
MutableBlock
,
ReadableBlock
,
WritableBlock
,
BlockError
,
BlockExt
,
BlockMetadata
,
BlockState
,
MutableBlock
,
ReadableBlock
,
TransferContext
,
WritableBlock
,
};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
state
::
TransferContext
;
use
crate
::
block_manager
::
storage
::{
Local
,
Storage
};
use
crate
::
block_manager
::
BlockPool
;
use
anyhow
::
Result
;
use
async_trait
::
async_trait
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
futures
::{
stream
::
FuturesUnordered
,
StreamExt
};
use
super
::
BlockResult
;
...
...
@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
completion_indicator
.send
(
Ok
(
blocks
))
?
;
completion_indicator
.send
(
Ok
(
blocks
))
.map_err
(|
_
|
BlockPoolError
::
ProgressEngineShutdown
)
?
;
}
Ok
(())
...
...
@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
}
pub
struct
CudaTransferManager
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
pending_transfer_q
:
mpsc
::
Sender
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
,
pending_transfer_q
:
mpsc
::
Sender
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
tokio
::
sync
::
oneshot
::
Receiver
<
()
>
,
)
>
,
transfer_ctx
:
Arc
<
TransferContext
>
,
}
...
...
@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
max_concurrent_transfers
,
);
)
->
Result
<
Self
>
{
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
tokio
::
sync
::
oneshot
::
Receiver
<
()
>
,
)
>
(
max_concurrent_transfers
);
spawn
(
move
||
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
CriticalTaskExecutionHandle
::
new_with_runtime
(
move
|
cancel_token
|
async
move
{
loop
{
tokio
::
select!
{
Some
((
pending_transfer
,
notify
))
=
rx
.recv
()
=>
{
// Wait for the event.
event
.synchronize
(
)
?
;
notify
.await
.map_err
(|
_
|
BlockPoolError
::
ProgressEngineShutdown
)
?
;
// Only finalize the transfer after the event is signaled.
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
// The only case where this can fail is if the progress engine is shutdown.
// The only case where this can fail is if the progress engine is
being
shutdown.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
}
// Flush any remaining transfers.
if
cancellation_token
.is_cancelled
()
{
while
rx
.blocking_recv
()
.is_some
()
{}
break
;
_
=
cancel_token
.cancelled
()
=>
{
return
Ok
(());
}
}
}
Ok
::
<
(),
anyhow
::
Error
>
(())
});
},
cancellation_token
.clone
(),
"Cuda Transfer Manager"
,
runtime
,
)
?
.detach
();
Self
{
Ok
(
Self
{
pending_transfer_q
:
tx
,
transfer_ctx
,
}
}
)
}
}
...
...
@@ -214,22 +226,23 @@ where
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
pending_transfer
.sources
.write_to
(
let
notify
=
pending_transfer
.sources
.write_to
(
&
mut
pending_transfer
.targets
,
Non
e
,
tru
e
,
self
.transfer_ctx
.clone
(),
)
?
;
// Use a cuda event to record the completion of the transfers.
let
event
=
self
.transfer_ctx
.stream
()
.record_event
(
Some
(
CUevent_flags
::
CU_EVENT_BLOCKING_SYNC
))
?
;
)
?
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"write_to returned None when notify was true. This should never happen!"
)
})
?
;
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
self
.pending_transfer_q
.send
((
pending_transfer
,
event
))
.send
((
pending_transfer
,
notify
))
.await
?
;
Ok
(())
...
...
@@ -247,10 +260,11 @@ impl DiskTransferManager {
max_concurrent_transfers
:
usize
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
)
->
Result
<
Self
>
{
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
runtime
.spawn
(
async
move
{
CriticalTaskExecutionHandle
::
new_with_runtime
(
move
|
cancel_token
|
async
move
{
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
...
...
@@ -258,10 +272,8 @@ impl DiskTransferManager {
loop
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
// Flush remaining transfers.
while
(
pending_transfers
.next
()
.await
)
.is_some
()
{}
return
;
_
=
cancel_token
.cancelled
()
=>
{
return
Ok
(());
}
Some
(
future
)
=
futures_rx
.recv
()
=>
{
...
...
@@ -277,12 +289,17 @@ impl DiskTransferManager {
}
}
}
});
},
cancellation_token
.clone
(),
"Disk Transfer Manager"
,
runtime
,
)
?
.detach
();
Self
{
Ok
(
Self
{
futures_tx
,
transfer_ctx
,
}
}
)
}
}
...
...
@@ -303,14 +320,21 @@ where
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
let
future
=
pending_transfer
.sources
.nixl_write_to
(
let
notify
=
pending_transfer
.sources
.write_to
(
&
mut
pending_transfer
.targets
,
Non
e
,
tru
e
,
self
.transfer_ctx
.clone
(),
)
?
;
)
?
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"write_to returned None when notify was true. This should never happen!"
)
})
?
;
let
completion_future
=
async
move
{
let
_
=
future
.await
;
let
_
=
notify
.await
;
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
...
...
lib/llm/src/block_manager/state.rs
View file @
312ee8e2
...
...
@@ -21,29 +21,9 @@ use super::{
config
::
NixlOptions
,
events
::{
EventManager
,
NullEventManager
},
};
use
cudarc
::
driver
::
CudaStream
;
use
std
::
sync
::
Arc
;
use
tokio
::
runtime
::
Handle
;
pub
struct
TransferContext
{
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
}
impl
TransferContext
{
pub
fn
new
(
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
)
->
Self
{
Self
{
nixl_agent
,
stream
}
}
pub
fn
nixl_agent
(
&
self
)
->
Arc
<
Option
<
NixlAgent
>>
{
self
.nixl_agent
.clone
()
}
pub
fn
stream
(
&
self
)
->
&
Arc
<
CudaStream
>
{
&
self
.stream
}
}
#[allow(dead_code)]
pub
struct
KvBlockManagerState
<
Metadata
:
BlockMetadata
>
{
worker_id
:
WorkerID
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment