Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
312ee8e2
Unverified
Commit
312ee8e2
authored
Jun 09, 2025
by
jthomson04
Committed by
GitHub
Jun 09, 2025
Browse files
feat: Restructure the KVBM WriteTo trait (#1363)
parent
3d499705
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
292 additions
and
178 deletions
+292
-178
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+1
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+37
-48
lib/llm/src/block_manager/block/transfer/context.rs
lib/llm/src/block_manager/block/transfer/context.rs
+116
-0
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+25
-26
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+11
-6
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+102
-78
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+0
-20
No files found.
lib/llm/src/block_manager/block.rs
View file @
312ee8e2
...
@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor;
...
@@ -24,6 +24,7 @@ use nixl_sys::NixlDescriptor;
pub
use
registry
::{
GlobalRegistry
,
RegistrationHandle
};
pub
use
registry
::{
GlobalRegistry
,
RegistrationHandle
};
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
pub
use
transfer
::
TransferContext
;
use
crate
::
block_manager
::{
use
crate
::
block_manager
::{
state
::
KvBlockManagerState
as
BlockManager
,
state
::
KvBlockManagerState
as
BlockManager
,
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
312ee8e2
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
mod
context
;
mod
cuda
;
mod
cuda
;
mod
memcpy
;
mod
memcpy
;
mod
nixl
;
mod
nixl
;
...
@@ -29,12 +30,12 @@ use crate::block_manager::storage::{
...
@@ -29,12 +30,12 @@ use crate::block_manager::storage::{
use
cudarc
::
driver
::
CudaStream
;
use
cudarc
::
driver
::
CudaStream
;
use
nixl_sys
::
XferOp
::{
Read
,
Write
};
use
nixl_sys
::
XferOp
::{
Read
,
Write
};
use
std
::
future
::
Future
;
use
std
::
ops
::
Range
;
use
std
::
ops
::
Range
;
use
tokio
::
sync
::
oneshot
;
pub
use
crate
::
block_manager
::
state
::
TransferContext
;
pub
use
crate
::
block_manager
::
storage
::{
CudaAccessible
,
Local
,
Remote
};
pub
use
crate
::
block_manager
::
storage
::{
CudaAccessible
,
Local
,
Remote
};
pub
use
async_trait
::
async_trait
;
pub
use
async_trait
::
async_trait
;
pub
use
context
::
TransferContext
;
/// A block that can be the target of a write
/// A block that can be the target of a write
pub
trait
Writable
{}
pub
trait
Writable
{}
...
@@ -149,19 +150,9 @@ pub trait WriteTo<Target> {
...
@@ -149,19 +150,9 @@ pub trait WriteTo<Target> {
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
Vec
<
Target
>
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
bool
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
;
)
->
Result
<
Option
<
oneshot
::
Receiver
<
()
>>
,
TransferError
>
;
/// A write_to implementation that expects a NIXL transfer.
/// If the transfer strategy is not NIXL, this method will return an error.
/// Returns a future that will complete when the transfer is complete.
fn
nixl_write_to
(
&
self
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
}
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
Vec
<
Arc
<
RB
>>
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
Vec
<
Arc
<
RB
>>
...
@@ -171,15 +162,25 @@ where
...
@@ -171,15 +162,25 @@ where
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
Vec
<
WB
>
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
bool
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
{
)
->
Result
<
Option
<
oneshot
::
Receiver
<
()
>>
,
TransferError
>
{
let
(
tx
,
rx
)
=
oneshot
::
channel
();
match
RB
::
write_to_strategy
()
{
match
RB
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
{
TransferStrategy
::
Memcpy
=>
{
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
// TODO: Unlike all other transfer strategies, this is fully blocking.
// We probably want some sort of thread pool to handle these.
memcpy
::
copy_block
(
src
.as_ref
(),
dst
)
?
;
memcpy
::
copy_block
(
src
.as_ref
(),
dst
)
?
;
}
}
Ok
(())
if
notify
{
tx
.send
(())
.unwrap
();
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
}
TransferStrategy
::
CudaAsyncH2D
TransferStrategy
::
CudaAsyncH2D
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2H
...
@@ -192,17 +193,27 @@ where
...
@@ -192,17 +193,27 @@ where
RB
::
write_to_strategy
(),
RB
::
write_to_strategy
(),
)
?
;
)
?
;
}
}
Ok
(())
if
notify
{
let
(
tx
,
rx
)
=
oneshot
::
channel
();
ctx
.cuda_event
(
tx
)
?
;
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
}
TransferStrategy
::
Nixl
(
transfer_type
)
=>
{
TransferStrategy
::
Nixl
(
transfer_type
)
=>
{
std
::
mem
::
drop
(
nixl
::
write_blocks_to
(
let
transfer_fut
=
nixl
::
write_blocks_to
(
self
,
dst
,
&
ctx
,
transfer_type
)
?
;
self
,
dst
,
if
notify
{
ctx
,
ctx
.async_rt_handle
()
.spawn
(
async
move
{
notify
,
transfer_fut
.await
;
transfer_type
,
tx
.send
(())
.unwrap
();
)
?
);
});
Ok
(())
Ok
(
Some
(
rx
))
}
else
{
Ok
(
None
)
}
}
}
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Unsupported copy strategy: {:?}"
,
"Unsupported copy strategy: {:?}"
,
...
@@ -210,28 +221,6 @@ where
...
@@ -210,28 +221,6 @@ where
))),
))),
}
}
}
}
fn
nixl_write_to
(
&
self
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
if
let
TransferStrategy
::
Nixl
(
transfer_type
)
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
,
transfer_type
,
)
?
)
}
else
{
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Expected NIXL transfer strategy, got: {:?}"
,
RB
::
write_to_strategy
()
)))
?
}
}
}
}
#[derive(Default)]
#[derive(Default)]
...
...
lib/llm/src/block_manager/block/transfer/context.rs
0 → 100644
View file @
312ee8e2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
super
::
*
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
,
CudaStream
};
use
nixl_sys
::
Agent
as
NixlAgent
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
JoinHandle
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::{
mpsc
,
oneshot
};
use
tokio_util
::
sync
::
CancellationToken
;
pub
struct
TransferContext
{
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
async_rt_handle
:
Handle
,
cuda_event_tx
:
mpsc
::
UnboundedSender
<
(
CudaEvent
,
oneshot
::
Sender
<
()
>
)
>
,
cuda_event_worker
:
Option
<
JoinHandle
<
()
>>
,
cancel_token
:
CancellationToken
,
}
impl
TransferContext
{
pub
fn
new
(
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
async_rt_handle
:
Handle
,
)
->
Self
{
let
(
cuda_event_tx
,
mut
cuda_event_rx
)
=
mpsc
::
unbounded_channel
::
<
(
CudaEvent
,
oneshot
::
Sender
<
()
>
)
>
();
let
cancel_token
=
CancellationToken
::
new
();
let
cancel_token_clone
=
cancel_token
.clone
();
let
cuda_event_worker
=
std
::
thread
::
spawn
(
move
||
{
let
runtime
=
tokio
::
runtime
::
Builder
::
new_current_thread
()
.enable_all
()
.build
()
.expect
(
"Failed to build Tokio runtime for CUDA event worker."
);
runtime
.block_on
(
async
move
{
loop
{
tokio
::
select!
{
Some
((
event
,
tx
))
=
cuda_event_rx
.recv
()
=>
{
if
let
Err
(
e
)
=
event
.synchronize
()
{
tracing
::
error!
(
"Error synchronizing CUDA event: {}"
,
e
);
}
let
_
=
tx
.send
(());
}
_
=
cancel_token_clone
.cancelled
()
=>
{
break
;
}
}
}
});
});
Self
{
nixl_agent
,
stream
,
async_rt_handle
,
cuda_event_tx
,
cuda_event_worker
:
Some
(
cuda_event_worker
),
cancel_token
,
}
}
pub
fn
nixl_agent
(
&
self
)
->
Arc
<
Option
<
NixlAgent
>>
{
self
.nixl_agent
.clone
()
}
pub
fn
stream
(
&
self
)
->
&
Arc
<
CudaStream
>
{
&
self
.stream
}
pub
fn
async_rt_handle
(
&
self
)
->
&
Handle
{
&
self
.async_rt_handle
}
pub
fn
cuda_event
(
&
self
,
tx
:
oneshot
::
Sender
<
()
>
)
->
Result
<
(),
TransferError
>
{
let
event
=
self
.stream
.record_event
(
Some
(
CUevent_flags
::
CU_EVENT_BLOCKING_SYNC
))
.map_err
(|
e
|
TransferError
::
ExecutionError
(
e
.to_string
()))
?
;
self
.cuda_event_tx
.send
((
event
,
tx
))
.map_err
(|
_
|
TransferError
::
ExecutionError
(
"CUDA event worker exited."
.into
()))
?
;
Ok
(())
}
}
impl
Drop
for
TransferContext
{
fn
drop
(
&
mut
self
)
{
self
.cancel_token
.cancel
();
if
let
Some
(
handle
)
=
self
.cuda_event_worker
.take
()
{
if
let
Err
(
e
)
=
handle
.join
()
{
tracing
::
error!
(
"Error joining CUDA event worker: {:?}"
,
e
);
}
}
}
}
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
312ee8e2
...
@@ -16,9 +16,8 @@
...
@@ -16,9 +16,8 @@
use
super
::
*
;
use
super
::
*
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
};
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
XferDescList
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
future
::
Future
;
use
std
::
task
::
Poll
;
fn
append_xfer_request
<
Source
,
Destination
>
(
fn
append_xfer_request
<
Source
,
Destination
>
(
src
:
&
Arc
<
Source
>
,
src
:
&
Arc
<
Source
>
,
...
@@ -87,8 +86,7 @@ where
...
@@ -87,8 +86,7 @@ where
pub
fn
write_blocks_to
<
Source
,
Destination
>
(
pub
fn
write_blocks_to
<
Source
,
Destination
>
(
src
:
&
[
Arc
<
Source
>
],
src
:
&
[
Arc
<
Source
>
],
dst
:
&
mut
[
Destination
],
dst
:
&
mut
[
Destination
],
ctx
:
Arc
<
TransferContext
>
,
ctx
:
&
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
transfer_type
:
NixlTransfer
,
transfer_type
:
NixlTransfer
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
where
...
@@ -136,26 +134,27 @@ where
...
@@ -136,26 +134,27 @@ where
None
,
None
,
)
?
;
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
let
still_pending
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
None
)
?
;
if
let
Some
(
notify
)
=
notify
{
if
still_pending
{
xfer_args
.set_has_notification
(
true
)
?
;
Ok
(
Box
::
new
(
Box
::
pin
(
async
move
{
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
loop
{
match
nixl_agent
.get_xfer_status
(
&
xfer_req
)
{
Ok
(
false
)
=>
break
,
// Transfer is complete.
Ok
(
true
)
=>
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
5
))
.await
,
// Transfer is still in progress.
Err
(
e
)
=>
{
tracing
::
error!
(
"Error getting transfer status: {}"
,
e
);
break
;
}
}
}
})))
}
else
{
Ok
(
Box
::
new
(
std
::
future
::
ready
(())))
}
}
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
}
else
{
Poll
::
Pending
}
})))
}
}
lib/llm/src/block_manager/offload.rs
View file @
312ee8e2
...
@@ -44,9 +44,8 @@
...
@@ -44,9 +44,8 @@
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
//! of the [`OffloadManager::offload`] and [`OffloadManager::onboard`] methods.
use
super
::
block
::{
BlockError
,
BlockMetadata
,
BlockState
,
ImmutableBlock
};
use
super
::
block
::{
BlockError
,
BlockMetadata
,
BlockState
,
ImmutableBlock
,
TransferContext
};
use
super
::
pool
::
BlockPoolError
;
use
super
::
pool
::
BlockPoolError
;
use
super
::
state
::
TransferContext
;
use
super
::
storage
::{
Cuda
,
Storage
};
use
super
::
storage
::{
Cuda
,
Storage
};
use
super
::{
BlockPool
,
DeviceStorage
,
DiskStorage
,
PinnedStorage
};
use
super
::{
BlockPool
,
DeviceStorage
,
DiskStorage
,
PinnedStorage
};
use
nixl_sys
::
Agent
as
NixlAgent
;
use
nixl_sys
::
Agent
as
NixlAgent
;
...
@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -129,6 +128,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_offload_transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
let
device_offload_transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
nixl_agent
.clone
(),
nixl_agent
.clone
(),
cuda_ctx
.new_stream
()
?
,
cuda_ctx
.new_stream
()
?
,
async_rt_handle
.clone
(),
));
));
// Device -> Host offload
// Device -> Host offload
...
@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -140,8 +140,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager
::
new
(
CudaTransferManager
::
new
(
device_offload_transfer_ctx
,
device_offload_transfer_ctx
,
MAX_CONCURRENT_TRANSFERS
,
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
...
@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -159,6 +160,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
nixl_agent
.clone
(),
nixl_agent
.clone
(),
cuda_ctx
.new_stream
()
?
,
cuda_ctx
.new_stream
()
?
,
async_rt_handle
.clone
(),
));
));
// Host -> Disk offload
// Host -> Disk offload
...
@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -172,7 +174,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS
,
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
...
@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -196,8 +198,9 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
CudaTransferManager
::
new
(
CudaTransferManager
::
new
(
transfer_ctx
.clone
(),
transfer_ctx
.clone
(),
MAX_CONCURRENT_TRANSFERS
,
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
...
@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -223,7 +226,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
MAX_CONCURRENT_TRANSFERS
,
MAX_CONCURRENT_TRANSFERS
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
),
)
?
,
MAX_TRANSFER_BATCH_SIZE
,
MAX_TRANSFER_BATCH_SIZE
,
&
async_rt_handle
,
&
async_rt_handle
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
...
@@ -549,8 +552,10 @@ mod tests {
...
@@ -549,8 +552,10 @@ mod tests {
let
agent
=
NixlAgent
::
new
(
"offload-manager"
)
.unwrap
();
let
agent
=
NixlAgent
::
new
(
"offload-manager"
)
.unwrap
();
let
(
_
,
ucx_params
)
=
agent
.get_plugin_params
(
"UCX"
)
.unwrap
();
let
(
_
,
ucx_params
)
=
agent
.get_plugin_params
(
"UCX"
)
.unwrap
();
let
(
_
,
gds_params
)
=
agent
.get_plugin_params
(
"GDS"
)
.unwrap
();
let
(
_
,
gds_params
)
=
agent
.get_plugin_params
(
"GDS"
)
.unwrap
();
let
(
_
,
posix_params
)
=
agent
.get_plugin_params
(
"POSIX"
)
.unwrap
();
agent
.create_backend
(
"UCX"
,
&
ucx_params
)
.unwrap
();
agent
.create_backend
(
"UCX"
,
&
ucx_params
)
.unwrap
();
agent
.create_backend
(
"GDS"
,
&
gds_params
)
.unwrap
();
agent
.create_backend
(
"GDS"
,
&
gds_params
)
.unwrap
();
agent
.create_backend
(
"POSIX"
,
&
posix_params
)
.unwrap
();
Arc
::
new
(
Some
(
agent
))
Arc
::
new
(
Some
(
agent
))
};
};
}
}
...
...
lib/llm/src/block_manager/offload/pending.rs
View file @
312ee8e2
...
@@ -41,23 +41,21 @@
...
@@ -41,23 +41,21 @@
use
std
::
marker
::
PhantomData
;
use
std
::
marker
::
PhantomData
;
use
std
::
pin
::
Pin
;
use
std
::
pin
::
Pin
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
sync
::
mpsc
;
use
tokio
::
sync
::
mpsc
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
crate
::
block_manager
::
block
::{
use
crate
::
block_manager
::
block
::{
transfer
::{
WriteTo
,
WriteToStrategy
},
transfer
::{
WriteTo
,
WriteToStrategy
},
BlockError
,
BlockExt
,
BlockMetadata
,
BlockState
,
MutableBlock
,
ReadableBlock
,
WritableBlock
,
BlockError
,
BlockExt
,
BlockMetadata
,
BlockState
,
MutableBlock
,
ReadableBlock
,
TransferContext
,
WritableBlock
,
};
};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
state
::
TransferContext
;
use
crate
::
block_manager
::
storage
::{
Local
,
Storage
};
use
crate
::
block_manager
::
storage
::{
Local
,
Storage
};
use
crate
::
block_manager
::
BlockPool
;
use
crate
::
block_manager
::
BlockPool
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
futures
::{
stream
::
FuturesUnordered
,
StreamExt
};
use
futures
::{
stream
::
FuturesUnordered
,
StreamExt
};
use
super
::
BlockResult
;
use
super
::
BlockResult
;
...
@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...
@@ -110,7 +108,9 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
completion_indicator
.send
(
Ok
(
blocks
))
?
;
completion_indicator
.send
(
Ok
(
blocks
))
.map_err
(|
_
|
BlockPoolError
::
ProgressEngineShutdown
)
?
;
}
}
Ok
(())
Ok
(())
...
@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
...
@@ -150,7 +150,10 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
}
}
pub
struct
CudaTransferManager
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
pub
struct
CudaTransferManager
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
pending_transfer_q
:
mpsc
::
Sender
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
,
pending_transfer_q
:
mpsc
::
Sender
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
tokio
::
sync
::
oneshot
::
Receiver
<
()
>
,
)
>
,
transfer_ctx
:
Arc
<
TransferContext
>
,
transfer_ctx
:
Arc
<
TransferContext
>
,
}
}
...
@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...
@@ -160,39 +163,48 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
pub
fn
new
(
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
,
max_concurrent_transfers
:
usize
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
)
->
Result
<
Self
>
{
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
max_concurrent_transfers
,
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
);
tokio
::
sync
::
oneshot
::
Receiver
<
()
>
,
)
>
(
max_concurrent_transfers
);
spawn
(
move
||
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
CriticalTaskExecutionHandle
::
new_with_runtime
(
// Wait for the event.
move
|
cancel_token
|
async
move
{
event
.synchronize
()
?
;
loop
{
// Only finalize the transfer after the event is signaled.
tokio
::
select!
{
match
pending_transfer
.handle_complete
()
{
Some
((
pending_transfer
,
notify
))
=
rx
.recv
()
=>
{
Ok
(
_
)
=>
{}
// Wait for the event.
Err
(
e
)
=>
{
notify
.await
.map_err
(|
_
|
BlockPoolError
::
ProgressEngineShutdown
)
?
;
// The only case where this can fail is if the progress engine is shutdown.
// Only finalize the transfer after the event is signaled.
// This is not a problem, so we can just ignore it.
match
pending_transfer
.handle_complete
()
{
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
Ok
(
_
)
=>
{}
}
Err
(
e
)
=>
{
}
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
}
// Flush any remaining transfers.
_
=
cancel_token
.cancelled
()
=>
{
if
cancellation_token
.is_cancelled
()
{
return
Ok
(());
while
rx
.blocking_recv
()
.is_some
()
{
}
}
break
;
}
}
}
}
},
Ok
::
<
(),
anyhow
::
Error
>
(())
cancellation_token
.clone
(),
});
"Cuda Transfer Manager"
,
runtime
,
Self
{
)
?
.detach
();
Ok
(
Self
{
pending_transfer_q
:
tx
,
pending_transfer_q
:
tx
,
transfer_ctx
,
transfer_ctx
,
}
}
)
}
}
}
}
...
@@ -214,22 +226,23 @@ where
...
@@ -214,22 +226,23 @@ where
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
pending_transfer
.sources
.write_to
(
let
notify
=
pending_transfer
&
mut
pending_transfer
.targets
,
.sources
None
,
.write_to
(
self
.transfer_ctx
.clone
(),
&
mut
pending_transfer
.targets
,
)
?
;
true
,
self
.transfer_ctx
.clone
(),
// Use a cuda event to record the completion of the transfers.
)
?
let
event
=
self
.ok_or_else
(||
{
.transfer_ctx
anyhow
::
anyhow!
(
.stream
()
"write_to returned None when notify was true. This should never happen!"
.record_event
(
Some
(
CUevent_flags
::
CU_EVENT_BLOCKING_SYNC
))
?
;
)
})
?
;
// Send the pending transfer and event to the worker thread.
// Send the pending transfer and event to the worker thread.
// If the queue is full, we block the worker until space becomes available.
// If the queue is full, we block the worker until space becomes available.
self
.pending_transfer_q
self
.pending_transfer_q
.send
((
pending_transfer
,
event
))
.send
((
pending_transfer
,
notify
))
.await
?
;
.await
?
;
Ok
(())
Ok
(())
...
@@ -247,42 +260,46 @@ impl DiskTransferManager {
...
@@ -247,42 +260,46 @@ impl DiskTransferManager {
max_concurrent_transfers
:
usize
,
max_concurrent_transfers
:
usize
,
runtime
:
&
Handle
,
runtime
:
&
Handle
,
cancellation_token
:
CancellationToken
,
cancellation_token
:
CancellationToken
,
)
->
Self
{
)
->
Result
<
Self
>
{
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
runtime
.spawn
(
async
move
{
CriticalTaskExecutionHandle
::
new_with_runtime
(
// Keep track of our pending transfers.
move
|
cancel_token
|
async
move
{
// Consume the futures as they complete, while also receiving new ones.
// Keep track of our pending transfers.
// Consume the futures as they complete, while also receiving new ones.
let
mut
pending_transfers
=
FuturesUnordered
::
new
();
let
mut
pending_transfers
=
FuturesUnordered
::
new
();
loop
{
loop
{
tokio
::
select!
{
tokio
::
select!
{
_
=
cancellation_token
.cancelled
()
=>
{
_
=
cancel_token
.cancelled
()
=>
{
// Flush remaining transfers.
return
Ok
(());
while
(
pending_transfers
.next
()
.await
)
.is_some
()
{}
}
return
;
}
Some
(
future
)
=
futures_rx
.recv
()
=>
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
// If we're at max size, block the worker thread on the next() call until we have capacity.
// If we're at max size, block the worker thread on the next() call until we have capacity.
while
pending_transfers
.len
()
>=
max_concurrent_transfers
{
while
pending_transfers
.len
()
>=
max_concurrent_transfers
{
pending_transfers
.next
()
.await
;
pending_transfers
.next
()
.await
;
}
// Once we have capacity, push the new future onto the queue.
pending_transfers
.push
(
future
);
}
Some
(
_
)
=
pending_transfers
.next
(),
if
!
pending_transfers
.is_empty
()
=>
{
// A transfer completed, just continue to process more
}
}
// Once we have capacity, push the new future onto the queue.
pending_transfers
.push
(
future
);
}
Some
(
_
)
=
pending_transfers
.next
(),
if
!
pending_transfers
.is_empty
()
=>
{
// A transfer completed, just continue to process more
}
}
}
}
}
},
});
cancellation_token
.clone
(),
"Disk Transfer Manager"
,
Self
{
runtime
,
)
?
.detach
();
Ok
(
Self
{
futures_tx
,
futures_tx
,
transfer_ctx
,
transfer_ctx
,
}
}
)
}
}
}
}
...
@@ -303,14 +320,21 @@ where
...
@@ -303,14 +320,21 @@ where
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
let
future
=
pending_transfer
.sources
.nixl_write_to
(
let
notify
=
pending_transfer
&
mut
pending_transfer
.targets
,
.sources
None
,
.write_to
(
self
.transfer_ctx
.clone
(),
&
mut
pending_transfer
.targets
,
)
?
;
true
,
self
.transfer_ctx
.clone
(),
)
?
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"write_to returned None when notify was true. This should never happen!"
)
})
?
;
let
completion_future
=
async
move
{
let
completion_future
=
async
move
{
let
_
=
future
.await
;
let
_
=
notify
.await
;
match
pending_transfer
.handle_complete
()
{
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
Err
(
e
)
=>
{
...
...
lib/llm/src/block_manager/state.rs
View file @
312ee8e2
...
@@ -21,29 +21,9 @@ use super::{
...
@@ -21,29 +21,9 @@ use super::{
config
::
NixlOptions
,
config
::
NixlOptions
,
events
::{
EventManager
,
NullEventManager
},
events
::{
EventManager
,
NullEventManager
},
};
};
use
cudarc
::
driver
::
CudaStream
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
tokio
::
runtime
::
Handle
;
use
tokio
::
runtime
::
Handle
;
pub
struct
TransferContext
{
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
,
}
impl
TransferContext
{
pub
fn
new
(
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
stream
:
Arc
<
CudaStream
>
)
->
Self
{
Self
{
nixl_agent
,
stream
}
}
pub
fn
nixl_agent
(
&
self
)
->
Arc
<
Option
<
NixlAgent
>>
{
self
.nixl_agent
.clone
()
}
pub
fn
stream
(
&
self
)
->
&
Arc
<
CudaStream
>
{
&
self
.stream
}
}
#[allow(dead_code)]
#[allow(dead_code)]
pub
struct
KvBlockManagerState
<
Metadata
:
BlockMetadata
>
{
pub
struct
KvBlockManagerState
<
Metadata
:
BlockMetadata
>
{
worker_id
:
WorkerID
,
worker_id
:
WorkerID
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment