Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
5d5080ba
Unverified
Commit
5d5080ba
authored
May 22, 2025
by
jthomson04
Committed by
GitHub
May 22, 2025
Browse files
feat: Various KVBM improvements (#1134)
parent
d3b0cae1
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
429 additions
and
244 deletions
+429
-244
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+10
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+8
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+23
-11
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+66
-99
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+118
-73
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+163
-44
lib/llm/src/block_manager/offload/request.rs
lib/llm/src/block_manager/offload/request.rs
+18
-1
lib/llm/src/block_manager/pool/inactive.rs
lib/llm/src/block_manager/pool/inactive.rs
+4
-0
lib/llm/src/block_manager/pool/state.rs
lib/llm/src/block_manager/pool/state.rs
+3
-2
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+13
-13
lib/llm/src/block_manager/storage/disk.rs
lib/llm/src/block_manager/storage/disk.rs
+3
-1
No files found.
lib/llm/src/block_manager.rs
View file @
5d5080ba
...
@@ -192,11 +192,21 @@ mod tests {
...
@@ -192,11 +192,21 @@ mod tests {
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
// Check if we're already in a Tokio runtime context
let
async_runtime
=
if
tokio
::
runtime
::
Handle
::
try_current
()
.is_ok
()
{
None
// If we're already in a runtime, don't create a new one
}
else
{
// Only create a new runtime if not already in one
Some
(
Arc
::
new
(
tokio
::
runtime
::
Runtime
::
new
()
.unwrap
()))
};
let
config
=
KvBlockManagerConfig
::
builder
()
let
config
=
KvBlockManagerConfig
::
builder
()
.runtime
(
.runtime
(
KvManagerRuntimeConfig
::
builder
()
KvManagerRuntimeConfig
::
builder
()
.worker_id
(
worker_id
)
.worker_id
(
worker_id
)
.enable_nixl
()
.enable_nixl
()
.async_runtime
(
async_runtime
)
.build
()
.build
()
.unwrap
(),
.unwrap
(),
)
)
...
...
lib/llm/src/block_manager/block.rs
View file @
5d5080ba
...
@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
...
@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
/// Resets the metadata to the default value
/// Resets the metadata to the default value
/// If called, the [BlockMetadata::is_reset()] should return true
/// If called, the [BlockMetadata::is_reset()] should return true
fn
reset_metadata
(
&
mut
self
);
fn
reset_metadata
(
&
mut
self
);
/// The offload priority of the block. Higher priority blocks are offloaded first.
/// If the block should not be offloaded, return None.
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
;
}
}
/// Marker trait for types that are mutable blocks
/// Marker trait for types that are mutable blocks
...
@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
...
@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
fn
reset_metadata
(
&
mut
self
)
{
fn
reset_metadata
(
&
mut
self
)
{
self
.priority
=
0
;
self
.priority
=
0
;
}
}
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
{
Some
(
self
.priority
as
u64
)
}
}
}
/// Collection that holds shared storage and layout
/// Collection that holds shared storage and layout
#[derive(Debug)]
#[derive(Debug)]
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
5d5080ba
...
@@ -133,7 +133,7 @@ where
...
@@ -133,7 +133,7 @@ where
pub
trait
WriteTo
<
Target
>
{
pub
trait
WriteTo
<
Target
>
{
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
Target
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
;
)
->
Result
<
(),
TransferError
>
;
...
@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
...
@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
/// Returns a future that will complete when the transfer is complete.
/// Returns a future that will complete when the transfer is complete.
fn
nixl_write_to
(
fn
nixl_write_to
(
&
self
,
&
self
,
dst
:
&
mut
Target
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
}
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
RB
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
Vec
<
Arc
<
RB
>>
where
where
RB
:
WriteToStrategy
<
WB
>
+
Local
,
RB
:
WriteToStrategy
<
WB
>
+
Local
,
{
{
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
WB
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
{
)
->
Result
<
(),
TransferError
>
{
match
Self
::
write_to_strategy
()
{
match
RB
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
memcpy
::
copy_block
(
self
,
dst
),
TransferStrategy
::
Memcpy
=>
{
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
memcpy
::
copy_block
(
src
.as_ref
(),
dst
)
?
;
}
Ok
(())
}
TransferStrategy
::
CudaAsyncH2D
TransferStrategy
::
CudaAsyncH2D
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2D
=>
{
|
TransferStrategy
::
CudaAsyncD2D
=>
{
cuda
::
copy_block
(
self
,
dst
,
ctx
.stream
()
.as_ref
(),
RB
::
write_to_strategy
())
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
cuda
::
copy_block
(
src
.as_ref
(),
dst
,
ctx
.stream
()
.as_ref
(),
RB
::
write_to_strategy
(),
)
?
;
}
Ok
(())
}
}
TransferStrategy
::
NixlWrite
=>
{
TransferStrategy
::
NixlWrite
=>
{
std
::
mem
::
drop
(
nixl
::
write_block_to
(
self
,
dst
,
ctx
,
notify
)
?
);
std
::
mem
::
drop
(
nixl
::
write_block
s
_to
(
self
,
dst
,
ctx
,
notify
)
?
);
Ok
(())
Ok
(())
}
}
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
...
@@ -175,17 +188,16 @@ where
...
@@ -175,17 +188,16 @@ where
RB
::
write_to_strategy
()
RB
::
write_to_strategy
()
))),
))),
}
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
}
fn
nixl_write_to
(
fn
nixl_write_to
(
&
self
,
&
self
,
dst
:
&
mut
WB
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
if
let
TransferStrategy
::
NixlWrite
=
RB
::
write_to_strategy
()
{
if
let
TransferStrategy
::
NixlWrite
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_block_to
(
self
,
dst
,
ctx
,
notify
)
?
)
Ok
(
nixl
::
write_block
s
_to
(
self
,
dst
,
ctx
,
notify
)
?
)
}
else
{
}
else
{
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Expected NIXL transfer strategy, got: {:?}"
,
"Expected NIXL transfer strategy, got: {:?}"
,
...
...
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
5d5080ba
...
@@ -18,16 +18,14 @@ use super::*;
...
@@ -18,16 +18,14 @@ use super::*;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
ops
::
Range
;
use
std
::
task
::
Poll
;
use
std
::
task
::
Poll
;
/// Copy a block from a source to a destination using CUDA memcpy
fn
append_xfer_request
<
Source
,
Destination
>
(
pub
fn
write_block_to
<
'a
,
Source
,
Destination
>
(
src
:
&
Arc
<
Source
>
,
src
:
&
'a
Source
,
dst
:
&
mut
Destination
,
dst
:
&
'a
mut
Destination
,
src_dl
:
&
mut
XferDescList
,
ctx
:
Arc
<
TransferContext
>
,
dst_dl
:
&
mut
XferDescList
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
where
Source
:
BlockDataProvider
,
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
Destination
:
BlockDataProviderMut
,
...
@@ -36,17 +34,6 @@ where
...
@@ -36,17 +34,6 @@ where
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
// Keep the arc to use in the returned future.
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
...
@@ -64,74 +51,10 @@ where
...
@@ -64,74 +51,10 @@ where
)
?
;
)
?
;
}
}
let
xfer_req
=
nixl_agent
Ok
(())
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
)
.unwrap
();
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
// Return a future that completes when the transfer is complete.
// TODO: How efficient is this? Can we do better?
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
}
else
{
Poll
::
Pending
}
})))
}
else
{
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
write_layers_to
(
0
..
src_data
.num_layers
(),
src
,
dst
,
ctx
,
notify
)
for
layer_idx
in
0
..
src_data
.num_layers
()
{
}
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
pub
fn
write_layers_to
<
'a
,
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
src
:
&
'a
Source
,
dst
:
&
'a
mut
Destination
,
ctx
:
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
// #[cfg(debug_assertions)]
// {
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
// Destination::StorageType,
// >>::write_to_strategy();
// assert_eq!(strategy, expected_strategy);
// }
for
layer_idx
in
layer_range
{
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
...
@@ -156,6 +79,56 @@ where
...
@@ -156,6 +79,56 @@ where
}
}
}
}
}
}
Ok
(())
}
}
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
write_blocks_to
<
Source
,
Destination
>
(
src
:
&
[
Arc
<
Source
>
],
dst
:
&
mut
[
Destination
],
ctx
:
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
{
if
src
.is_empty
()
||
dst
.is_empty
()
{
return
Ok
(
Box
::
new
(
std
::
future
::
ready
(())));
}
assert_eq!
(
src
.len
(),
dst
.len
());
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
let
src_mem_type
=
src
.first
()
.unwrap
()
.block_data
(
private
::
PrivateToken
)
.storage_type
()
.nixl_mem_type
();
let
dst_mem_type
=
dst
.first
()
.unwrap
()
.block_data
(
private
::
PrivateToken
)
.storage_type
()
.nixl_mem_type
();
let
mut
src_dl
=
XferDescList
::
new
(
src_mem_type
)
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_mem_type
)
?
;
for
(
src
,
dst
)
in
src
.iter
()
.zip
(
dst
.iter_mut
())
{
append_xfer_request
(
src
,
dst
,
&
mut
src_dl
,
&
mut
dst_dl
)
?
;
}
debug_assert!
(
!
src_dl
.has_overlaps
()
?
&&
!
dst_dl
.has_overlaps
()
?
);
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
...
@@ -164,14 +137,6 @@ where
...
@@ -164,14 +137,6 @@ where
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
Some
(
&
xfer_args
),
)
?
;
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
...
@@ -179,6 +144,8 @@ where
...
@@ -179,6 +144,8 @@ where
.as_ref
()
.as_ref
()
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
Poll
::
Ready
(())
}
else
{
}
else
{
...
...
lib/llm/src/block_manager/offload.rs
View file @
5d5080ba
...
@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
...
@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
mod
pending
;
mod
pending
;
pub
mod
request
;
pub
mod
request
;
use
pending
::{
CudaTransferManager
,
DiskTransferManager
,
PendingTransfer
,
TransferManager
};
use
pending
::{
CudaTransferManager
,
DiskTransferManager
,
PendingTransfer
,
TransferBatcher
,
TransferManager
,
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
// TODO: This should be dynamic
const
MAX_CONCURRENT_TRANSFERS
:
usize
=
4
;
const
MAX_
OFFLOAD_STREAM_DEPTH
:
usize
=
4
;
const
MAX_
TRANSFER_BATCH_SIZE
:
usize
=
16
;
/// The offload manager handles all block transfers between different cache levels.
/// The offload manager handles all block transfers between different cache levels.
pub
struct
OffloadManager
<
Metadata
:
BlockMetadata
>
{
pub
struct
OffloadManager
<
Metadata
:
BlockMetadata
>
{
// Handles to the device, host, and disk pools.
// Handles to the device, host, and disk pools.
disk
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
/// Queue of offloading requests.
/// Queue of offloading requests.
device_offload_tx
:
mpsc
::
UnboundedSender
<
OffloadRequest
<
DeviceStorage
,
Metadata
>>
,
device_offload_tx
:
mpsc
::
UnboundedSender
<
OffloadRequest
<
DeviceStorage
,
Metadata
>>
,
...
@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
...
@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
impl
<
Metadata
:
BlockMetadata
>
OffloadManager
<
Metadata
>
{
impl
<
Metadata
:
BlockMetadata
>
OffloadManager
<
Metadata
>
{
pub
fn
new
(
pub
fn
new
(
disk
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
async_rt_handle
:
Handle
,
async_rt_handle
:
Handle
,
)
->
Result
<
Arc
<
Self
>>
{
)
->
Result
<
Arc
<
Self
>>
{
...
@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
host_clone
=
this
.host
.clone
();
let
host_clone
=
this
.host
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
offload_worker
(
let
res
=
OffloadManager
::
offload_worker
(
device_clone
,
device_clone
,
host_clone
,
host_clone
,
device_offload_rx
,
device_offload_rx
,
Arc
::
new
(
Cuda
Transfer
Manag
er
::
new
(
Arc
::
new
(
Transfer
Batch
er
::
new
(
device_offload_transfer_ctx
,
CudaTransferManager
::
new
(
device_offload_transfer_ctx
,
MAX_CONCURRENT_TRANSFERS
),
MAX_
OFFLOAD_STREAM_DEPTH
,
MAX_
TRANSFER_BATCH_SIZE
,
)),
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
});
});
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
...
@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
disk_clone
=
this
.disk
.clone
();
let
disk_clone
=
this
.disk
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
offload_worker
(
let
res
=
OffloadManager
::
offload_worker
(
host_clone
,
host_clone
,
disk_clone
,
disk_clone
,
host_offload_rx
,
host_offload_rx
,
Arc
::
new
(
Disk
Transfer
Manag
er
::
new
(
Arc
::
new
(
Transfer
Batch
er
::
new
(
transfer_ctx_clone
,
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
)
,
MAX_
OFFLOAD_STREAM_DEPTH
,
MAX_
TRANSFER_BATCH_SIZE
,
)),
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
});
});
// Host -> Device onboarding
// Host -> Device onboarding
...
@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
onboard_worker
(
let
res
=
OffloadManager
::
onboard_worker
(
host_clone
,
host_clone
,
device_clone
,
device_clone
,
host_onboard_rx
,
host_onboard_rx
,
Arc
::
new
(
CudaTransferManager
::
new
(
transfer_ctx_clone
,
16384
)),
Arc
::
new
(
TransferBatcher
::
new
(
CudaTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
MAX_TRANSFER_BATCH_SIZE
,
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
});
});
// Disk -> Device onboarding
// Disk -> Device onboarding
...
@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
onboard_worker
(
let
res
=
OffloadManager
::
onboard_worker
(
disk_clone
,
disk_clone
,
device_clone
,
device_clone
,
disk_onboard_rx
,
disk_onboard_rx
,
Arc
::
new
(
DiskTransferManager
::
new
(
transfer_ctx_clone
,
16384
)),
Arc
::
new
(
TransferBatcher
::
new
(
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
MAX_TRANSFER_BATCH_SIZE
,
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
});
});
Ok
(
this_clone
)
Ok
(
this_clone
)
}
}
async
fn
offload_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
async
fn
offload_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
source_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Source
,
Metadata
>>>
,
source_pool
:
Option
<
Arc
<
BlockPool
<
Source
,
Metadata
>>>
,
target_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
_arc
.is_none
()
||
target_pool
_arc
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
}
}
let
source_pool
=
source_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
source_pool
=
source_pool
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
let
mut
queue
=
BTreeSet
::
new
();
let
mut
queue
=
BTreeSet
::
new
();
...
@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
// Allocate a block from the host pool.
// Allocate a block from the host pool.
// TODO: The most likely error here is that the
hos
t pool is full.
// TODO: The most likely error here is that the
targe
t pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let
target_blocks
=
match
target_pool
.allocate_blocks
(
1
)
.await
{
let
target_blocks
=
match
target_pool
.allocate_blocks
(
1
)
.await
{
Ok
(
blocks
)
=>
blocks
,
Ok
(
blocks
)
=>
blocks
,
...
@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
if
let
Some
(
target_block
)
=
target_blocks
.into_iter
()
.next
()
{
if
let
Some
(
target_block
)
=
target_blocks
.into_iter
()
.next
()
{
transfer_manager
transfer_manager
.
begin
_transfer
(
PendingTransfer
::
new
(
.
enqueue
_transfer
(
PendingTransfer
::
new
(
vec!
[
block
],
vec!
[
block
],
vec!
[
target_block
],
vec!
[
target_block
],
None
,
None
,
target_pool
_arc
.clone
(),
target_pool
.clone
(),
))
))
.await
?
;
.await
?
;
}
}
...
@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
async
fn
onboard_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
async
fn
onboard_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
source_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Source
,
Metadata
>>>
,
source_pool
:
Option
<
Arc
<
BlockPool
<
Source
,
Metadata
>>>
,
target_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
_arc
.is_none
()
||
target_pool
_arc
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
}
}
let
target_pool
=
target_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
// Loop on incoming requests
// Loop on incoming requests
while
let
Some
(
request
)
=
onboard_rx
.recv
()
.await
{
while
let
Some
(
request
)
=
onboard_rx
.recv
()
.await
{
...
@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
.collect
();
.collect
();
transfer_manager
transfer_manager
.
begin
_transfer
(
PendingTransfer
::
new
(
.
enqueue
_transfer
(
PendingTransfer
::
new
(
sources
,
sources
,
target_blocks
,
target_blocks
,
Some
(
request
.response_tx
),
Some
(
request
.response_tx
),
target_pool
_arc
.clone
(),
target_pool
.clone
(),
))
))
.await
?
;
.await
?
;
}
}
...
@@ -478,9 +486,9 @@ mod tests {
...
@@ -478,9 +486,9 @@ mod tests {
const
BLOCK_SIZE
:
usize
=
4
;
const
BLOCK_SIZE
:
usize
=
4
;
type
DevicePool
=
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
BasicMetadata
>>>
;
type
DevicePool
=
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
BasicMetadata
>>>
;
type
HostPool
=
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
BasicMetadata
>>>
;
type
HostPool
=
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
BasicMetadata
>>>
;
type
DiskPool
=
Arc
<
Option
<
BlockPool
<
DiskStorage
,
BasicMetadata
>>>
;
type
DiskPool
=
Option
<
Arc
<
BlockPool
<
DiskStorage
,
BasicMetadata
>>>
;
lazy_static
::
lazy_static!
{
lazy_static
::
lazy_static!
{
static
ref
NIXL_AGENT
:
Arc
<
Option
<
NixlAgent
>>
=
{
static
ref
NIXL_AGENT
:
Arc
<
Option
<
NixlAgent
>>
=
{
...
@@ -521,16 +529,18 @@ mod tests {
...
@@ -521,16 +529,18 @@ mod tests {
device
.nixl_register
(
agent
,
None
)
?
;
device
.nixl_register
(
agent
,
None
)
?
;
let
device_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
device
,
42
,
0
)
?
.into_blocks
()
?
;
let
device_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
device
,
42
,
0
)
?
.into_blocks
()
?
;
let
device_pool
=
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
device_blocks
)
.build
()
?
));
let
device_pool
=
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
device_blocks
)
.build
()
?
,
));
let
host_pool
=
if
let
Some
(
host_blocks
)
=
host_blocks
{
let
host_pool
=
if
let
Some
(
host_blocks
)
=
host_blocks
{
config
.num_blocks
=
host_blocks
;
config
.num_blocks
=
host_blocks
;
let
mut
host
=
FullyContiguous
::
allocate
(
config
.clone
(),
&
PinnedAllocator
::
default
())
?
;
let
mut
host
=
FullyContiguous
::
allocate
(
config
.clone
(),
&
PinnedAllocator
::
default
())
?
;
host
.nixl_register
(
agent
,
None
)
?
;
host
.nixl_register
(
agent
,
None
)
?
;
let
host_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
host
,
42
,
0
)
?
.into_blocks
()
?
;
let
host_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
host
,
42
,
0
)
?
.into_blocks
()
?
;
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
host_blocks
)
.build
()
?
))
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
host_blocks
)
.build
()
?
))
}
else
{
}
else
{
Arc
::
new
(
None
)
None
};
};
let
disk_pool
=
if
let
Some
(
disk_blocks
)
=
disk_blocks
{
let
disk_pool
=
if
let
Some
(
disk_blocks
)
=
disk_blocks
{
...
@@ -538,9 +548,9 @@ mod tests {
...
@@ -538,9 +548,9 @@ mod tests {
let
mut
disk
=
FullyContiguous
::
allocate
(
config
,
&
DiskAllocator
)
?
;
let
mut
disk
=
FullyContiguous
::
allocate
(
config
,
&
DiskAllocator
)
?
;
disk
.nixl_register
(
agent
,
None
)
?
;
disk
.nixl_register
(
agent
,
None
)
?
;
let
disk_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
disk
,
42
,
0
)
?
.into_blocks
()
?
;
let
disk_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
disk
,
42
,
0
)
?
.into_blocks
()
?
;
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
disk_blocks
)
.build
()
?
))
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
disk_blocks
)
.build
()
?
))
}
else
{
}
else
{
Arc
::
new
(
None
)
None
};
};
let
async_rt_handle
=
Handle
::
current
();
let
async_rt_handle
=
Handle
::
current
();
...
@@ -558,7 +568,7 @@ mod tests {
...
@@ -558,7 +568,7 @@ mod tests {
/// Create a block in the 'RESET' state.
/// Create a block in the 'RESET' state.
async
fn
get_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
get_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
pool
.allocate_blocks
(
1
)
pool
.allocate_blocks
(
1
)
.await
?
.await
?
...
@@ -569,7 +579,7 @@ mod tests {
...
@@ -569,7 +579,7 @@ mod tests {
/// Create a block in the 'PARTIAL' state.
/// Create a block in the 'PARTIAL' state.
async
fn
partial_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
partial_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
token
:
u32
,
token
:
u32
,
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
let
mut
block
=
get_block
(
pool
)
.await
?
;
let
mut
block
=
get_block
(
pool
)
.await
?
;
...
@@ -580,7 +590,7 @@ mod tests {
...
@@ -580,7 +590,7 @@ mod tests {
/// Create a block in the 'COMPLETED' state.
/// Create a block in the 'COMPLETED' state.
async
fn
completed_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
completed_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
tokens
:
[
u32
;
BLOCK_SIZE
],
tokens
:
[
u32
;
BLOCK_SIZE
],
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
let
mut
block
=
get_block
(
pool
)
.await
?
;
let
mut
block
=
get_block
(
pool
)
.await
?
;
...
@@ -666,7 +676,7 @@ mod tests {
...
@@ -666,7 +676,7 @@ mod tests {
async
fn
test_offload_invalid_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_invalid_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
// Check blocks in the 'RESET' state.
// Check blocks in the 'RESET' state.
let
immutable_block
=
ImmutableBlock
::
new
(
Arc
::
new
(
get_block
(
device_pool
)
.await
?
));
let
immutable_block
=
ImmutableBlock
::
new
(
Arc
::
new
(
get_block
(
device_pool
)
.await
?
));
...
@@ -699,8 +709,8 @@ mod tests {
...
@@ -699,8 +709,8 @@ mod tests {
async
fn
test_offload_registered_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_registered_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
// Create a block and register it with the offload manager
// Create a block and register it with the offload manager
let
block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
...
@@ -742,8 +752,8 @@ mod tests {
...
@@ -742,8 +752,8 @@ mod tests {
async
fn
test_no_host_blocks_available
()
->
Result
<
()
>
{
async
fn
test_no_host_blocks_available
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
host_blocks
=
host_pool
.allocate_blocks
(
4
)
.await
?
;
let
host_blocks
=
host_pool
.allocate_blocks
(
4
)
.await
?
;
assert_eq!
(
host_blocks
.len
(),
4
);
assert_eq!
(
host_blocks
.len
(),
4
);
...
@@ -790,8 +800,8 @@ mod tests {
...
@@ -790,8 +800,8 @@ mod tests {
async
fn
test_onboard
()
->
Result
<
()
>
{
async
fn
test_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
// Allocate and fill a block on the host.
// Allocate and fill a block on the host.
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
...
@@ -844,8 +854,8 @@ mod tests {
...
@@ -844,8 +854,8 @@ mod tests {
async
fn
test_offload_onboard
()
->
Result
<
()
>
{
async
fn
test_offload_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_device_block
=
device_pool
let
immutable_device_block
=
device_pool
...
@@ -913,8 +923,8 @@ mod tests {
...
@@ -913,8 +923,8 @@ mod tests {
async
fn
test_onboard_err_handling
()
->
Result
<
()
>
{
async
fn
test_onboard_err_handling
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_host_block
=
host_pool
let
immutable_host_block
=
host_pool
...
@@ -942,7 +952,7 @@ mod tests {
...
@@ -942,7 +952,7 @@ mod tests {
async
fn
test_offload_onboard_no_host_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_onboard_no_host_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_device_block
=
device_pool
let
immutable_device_block
=
device_pool
...
@@ -961,8 +971,8 @@ mod tests {
...
@@ -961,8 +971,8 @@ mod tests {
async
fn
test_offload_disk
()
->
Result
<
()
>
{
async
fn
test_offload_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
))
?
;
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
))
?
;
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_host_block
=
host_pool
let
immutable_host_block
=
host_pool
...
@@ -996,8 +1006,8 @@ mod tests {
...
@@ -996,8 +1006,8 @@ mod tests {
async
fn
test_onboard_disk
()
->
Result
<
()
>
{
async
fn
test_onboard_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
))
?
;
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
))
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
disk_block
=
completed_block
(
disk_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
disk_block
=
completed_block
(
disk_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_disk_block
=
disk_pool
let
immutable_disk_block
=
disk_pool
...
@@ -1032,9 +1042,9 @@ mod tests {
...
@@ -1032,9 +1042,9 @@ mod tests {
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
build_pools
(
8
,
Some
(
8
),
Some
(
8
))
?
;
build_pools
(
8
,
Some
(
8
),
Some
(
8
))
?
;
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
mut
host_blocks
=
Vec
::
new
();
let
mut
host_blocks
=
Vec
::
new
();
...
@@ -1076,4 +1086,39 @@ mod tests {
...
@@ -1076,4 +1086,39 @@ mod tests {
Ok
(())
Ok
(())
}
}
/// Onboards `2 * MAX_TRANSFER_BATCH_SIZE + 1` blocks from disk to device in a single request,
/// i.e. more than fits into two full transfer batches, and checks that every returned block can
/// be found in the device pool by its sequence hash and passes `compare_block_contents`.
#[tokio::test]
async fn test_transfer_batcher() -> Result<()> {
    // Two full batches plus one extra block.
    let num_blocks = 2 * MAX_TRANSFER_BATCH_SIZE + 1;

    let (offload_manager, device_pool, _, disk_pool) =
        build_pools(num_blocks, None, Some(num_blocks))?;

    let device_pool = device_pool.as_ref().unwrap();
    let disk_pool = disk_pool.as_ref().unwrap();

    // Fill the disk pool with completed blocks; block `i` holds the token `i` repeated 4 times.
    let mut disk_blocks = Vec::new();
    for i in 0..num_blocks {
        let block = completed_block(disk_pool, [i as u32; 4]).await?;
        disk_blocks.push(block);
    }

    let immutable_disk_blocks = disk_pool.register_blocks(disk_blocks).await?;

    // A single onboard request covering all of the registered disk blocks.
    let onboarded = offload_manager
        .onboard(immutable_disk_blocks.clone())
        .await?;

    assert_eq!(onboarded.len(), num_blocks);

    for device_block in &onboarded {
        // Each onboarded block must be discoverable in the device pool by its sequence hash.
        let hashes = vec![device_block.sequence_hash()?];
        let matched = device_pool
            .match_sequence_hashes(hashes.as_slice())
            .await?;

        assert_eq!(matched.len(), 1);
        compare_block_contents(&matched[0], device_block)?;
    }

    Ok(())
}
}
}
lib/llm/src/block_manager/offload/pending.rs
View file @
5d5080ba
...
@@ -33,11 +33,12 @@
...
@@ -33,11 +33,12 @@
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//!
//! ## Workflow
//! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::
begin
_transfer`]
//! 1. A transfer request is made by calling [`TransferManager::
enqueue
_transfer`]
//! 2. [`TransferManager::
begin
_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 2. [`TransferManager::
enqueue
_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use
std
::
marker
::
PhantomData
;
use
std
::
pin
::
Pin
;
use
std
::
pin
::
Pin
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
std
::
thread
::
spawn
;
...
@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
...
@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
futures
::{
future
::
join_all
,
stream
::
FuturesUnordered
,
StreamExt
};
use
futures
::{
stream
::
FuturesUnordered
,
StreamExt
};
use
super
::
BlockResult
;
use
super
::
BlockResult
;
...
@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
...
@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
/// The target pool that will receive the registered block.
/// The target pool that will receive the registered block.
target_
registration_
pool
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>
>
,
target_pool
:
Arc
<
BlockPool
<
Target
,
Metadata
>>
,
}
}
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
...
@@ -78,31 +79,35 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...
@@ -78,31 +79,35 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
sources
:
Vec
<
Arc
<
MutableBlock
<
Source
,
Metadata
>>>
,
sources
:
Vec
<
Arc
<
MutableBlock
<
Source
,
Metadata
>>>
,
targets
:
Vec
<
MutableBlock
<
Target
,
Metadata
>>
,
targets
:
Vec
<
MutableBlock
<
Target
,
Metadata
>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
target_
registration_
pool
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>
>
,
target_pool
:
Arc
<
BlockPool
<
Target
,
Metadata
>>
,
)
->
Self
{
)
->
Self
{
assert_eq!
(
sources
.len
(),
targets
.len
());
Self
{
Self
{
sources
,
sources
,
targets
,
targets
,
completion_indicator
,
completion_indicator
,
target_
registration_
pool
,
target_pool
,
}
}
}
}
fn
handle_complete
(
self
)
->
Result
<
()
>
{
fn
handle_complete
(
self
)
->
Result
<
()
>
{
let
Self
{
let
Self
{
targets
,
sources
,
target_registration_pool
,
mut
targets
,
target_pool
,
completion_indicator
,
completion_indicator
,
..
..
}
=
self
;
}
=
self
;
if
let
Some
(
target_registration_pool
)
=
target_registration_pool
.as_ref
()
{
for
(
source
,
target
)
in
sources
.iter
()
.zip
(
targets
.iter_mut
())
{
let
blocks
=
target_registration_pool
.register_blocks_blocking
(
targets
)
?
;
transfer_metadata
(
source
,
target
)
?
;
}
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
completion_indicator
.send
(
Ok
(
blocks
))
?
;
completion_indicator
.send
(
Ok
(
blocks
))
?
;
}
}
}
Ok
(())
Ok
(())
}
}
...
@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
...
@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
Send
+
Sync
Send
+
Sync
{
{
/// Begin a transfer. Blocks if the pending queue is full.
/// Begin a transfer. Blocks if the pending queue is full.
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
;
)
->
Result
<
()
>
;
...
@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
...
@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
{
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_depth
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
)
->
Self
{
let
(
tx
,
mut
rx
)
=
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
max_depth
);
max_concurrent_transfers
,
);
spawn
(
move
||
{
spawn
(
move
||
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
// Wait for the event.
// Wait for the event.
event
.synchronize
()
?
;
event
.synchronize
()
?
;
// Only finalize the transfer after the event is signaled.
// Only finalize the transfer after the event is signaled.
pending_transfer
.handle_complete
()
?
;
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
// The only case where this can fail is if the progress engine is shut down.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
}
}
Ok
::
<
(),
anyhow
::
Error
>
(())
Ok
::
<
(),
anyhow
::
Error
>
(())
});
});
...
@@ -183,18 +196,15 @@ where
...
@@ -183,18 +196,15 @@ where
// Check that the target block is writable.
// Check that the target block is writable.
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
{
{
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
for
(
source
,
target
)
in
pending_transfer
pending_transfer
.sources
.write_to
(
.sources
&
mut
pending_transfer
.targets
,
.iter
()
None
,
.zip
(
pending_transfer
.targets
.iter_mut
())
self
.transfer_ctx
.clone
(),
{
)
?
;
transfer_metadata
(
source
,
target
)
?
;
source
.write_to
(
target
,
None
,
self
.transfer_ctx
.clone
())
?
;
}
// Use a cuda event to record the completion of the transfers.
// Use a cuda event to record the completion of the transfers.
let
event
=
self
let
event
=
self
...
@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
...
@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
}
}
impl
DiskTransferManager
{
impl
DiskTransferManager
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_
size
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_
concurrent_transfers
:
usize
)
->
Self
{
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
...
@@ -230,7 +240,7 @@ impl DiskTransferManager {
...
@@ -230,7 +240,7 @@ impl DiskTransferManager {
tokio
::
select!
{
tokio
::
select!
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
// If we're at max size, block the worker thread on the next() call until we have capacity.
// If we're at max size, block the worker thread on the next() call until we have capacity.
while
pending_transfers
.len
()
>=
max_
size
{
while
pending_transfers
.len
()
>=
max_
concurrent_transfers
{
pending_transfers
.next
()
.await
;
pending_transfers
.next
()
.await
;
}
}
// Once we have capacity, push the new future onto the queue.
// Once we have capacity, push the new future onto the queue.
...
@@ -267,26 +277,26 @@ where
...
@@ -267,26 +277,26 @@ where
// Check that the target block is writable.
// Check that the target block is writable.
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
{
{
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
let
futures
=
pending_transfer
let
future
=
pending_transfer
.sources
.nixl_write_to
(
.sources
&
mut
pending_transfer
.targets
,
.iter
()
None
,
.zip
(
pending_transfer
.targets
.iter_mut
())
self
.transfer_ctx
.clone
(),
.map
(|(
source
,
target
)|
{
)
?
;
transfer_metadata
(
source
,
target
)
.unwrap
();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to
(
target
,
None
,
self
.transfer_ctx
.clone
())
.unwrap
()
})
.collect
::
<
Vec
<
_
>>
();
let
completion_future
=
async
move
{
let
completion_future
=
async
move
{
let
_
=
join_all
(
futures
)
.await
;
let
_
=
future
.await
;
pending_transfer
.handle_complete
()
.unwrap
();
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
// The only case where this can fail is if the progress engine is being shut down.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
};
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
...
@@ -296,3 +306,112 @@ where
...
@@ -296,3 +306,112 @@ where
Ok
(())
Ok
(())
}
}
}
}
/// A transfer manager that enforces a max batch size for transfers.
pub
struct
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
,
_
phantom
:
PhantomData
<
(
Source
,
Target
,
Metadata
)
>
,
}
impl
<
Source
,
Target
,
Metadata
,
Manager
>
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
pub
fn
new
(
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
)
->
Self
{
Self
{
transfer_manager
,
max_transfer_batch_size
,
_
phantom
:
PhantomData
,
}
}
}
#[async_trait]
impl
<
Source
,
Target
,
Metadata
,
Manager
>
TransferManager
<
Source
,
Target
,
Metadata
>
for
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
async
fn
enqueue_transfer
(
&
self
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
// If it's smaller than the max batch size, just enqueue it.
if
pending_transfer
.sources
.len
()
<
self
.max_transfer_batch_size
{
return
self
.transfer_manager
.enqueue_transfer
(
pending_transfer
)
.await
;
}
// Otherwise, we need to split the transfer into multiple smaller transfers.
let
PendingTransfer
{
mut
sources
,
mut
targets
,
completion_indicator
,
target_pool
,
}
=
pending_transfer
;
let
mut
indicators
=
Vec
::
new
();
while
!
sources
.is_empty
()
{
let
sources
=
sources
.drain
(
..
std
::
cmp
::
min
(
self
.max_transfer_batch_size
,
sources
.len
()))
.collect
();
let
targets
=
targets
.drain
(
..
std
::
cmp
::
min
(
self
.max_transfer_batch_size
,
targets
.len
()))
.collect
();
// If we have a completion indicator, we need to create a new one for each sub-transfer.
let
indicator
=
if
completion_indicator
.is_some
()
{
let
(
batch_tx
,
batch_rx
)
=
oneshot
::
channel
();
indicators
.push
(
batch_rx
);
Some
(
batch_tx
)
}
else
{
None
};
let
request
=
PendingTransfer
::
new
(
sources
,
targets
,
indicator
,
target_pool
.clone
());
// Enqueue our reduced transfer. This may block if the queue is full.
self
.transfer_manager
.enqueue_transfer
(
request
)
.await
?
;
}
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
tokio
::
spawn
(
async
move
{
let
mut
results
=
Vec
::
new
();
for
indicator
in
indicators
.into_iter
()
{
// Await each sub-transfer, and append the results to our final results.
let
result
=
match
indicator
.await
.unwrap
()
{
Ok
(
result
)
=>
result
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Error receiving transfer results: {:?}"
,
e
);
completion_indicator
.send
(
Err
(
e
))
.unwrap
();
return
;
}
};
results
.extend
(
result
);
}
// Send the final results to the top-level completion indicator.
completion_indicator
.send
(
Ok
(
results
))
.unwrap
();
});
}
Ok
(())
}
}
lib/llm/src/block_manager/offload/request.rs
View file @
5d5080ba
...
@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
...
@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
storage
::
Storage
;
use
crate
::
block_manager
::
storage
::
Storage
;
#[derive(PartialEq,
Eq,
Ord,
PartialOrd)]
/// Higher priority offloads are done first.
/// If two offloads have the same priority, the one that was requested first is done first.
#[derive(PartialEq,
Eq)]
pub
struct
OffloadRequestKey
{
pub
struct
OffloadRequestKey
{
pub
priority
:
u64
,
pub
priority
:
u64
,
pub
timestamp
:
u64
,
pub
timestamp
:
u64
,
}
}
impl PartialOrd for OffloadRequestKey {
    /// Always returns `Some`: the comparison is delegated to the total order defined by
    /// the `Ord` implementation, so the two can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Ord for OffloadRequestKey {
    /// Orders keys by priority in *descending* order (a key with a higher `priority`
    /// compares as `Less`), and breaks ties by `timestamp` in ascending order
    /// (the smaller timestamp compares as `Less`).
    fn cmp(&self, other: &Self) -> Ordering {
        // Note the swapped operands: the priority comparison is deliberately reversed.
        match other.priority.cmp(&self.priority) {
            Ordering::Equal => self.timestamp.cmp(&other.timestamp),
            unequal => unequal,
        }
    }
}
/// Data needed to offload a block.
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
/// This way, we don't prevent the block from being reused if needed.
...
...
lib/llm/src/block_manager/pool/inactive.rs
View file @
5d5080ba
...
@@ -518,6 +518,10 @@ pub(crate) mod tests {
...
@@ -518,6 +518,10 @@ pub(crate) mod tests {
fn
reset_metadata
(
&
mut
self
)
{
fn
reset_metadata
(
&
mut
self
)
{
self
.priority
=
0
;
self
.priority
=
0
;
}
}
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
{
Some
(
self
.priority
as
u64
)
}
}
}
type
TestPriorityKey
=
PriorityKey
<
TestMetadata
>
;
type
TestPriorityKey
=
PriorityKey
<
TestMetadata
>
;
...
...
lib/llm/src/block_manager/pool/state.rs
View file @
5d5080ba
...
@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let
immutable
=
self
.active
.register
(
mutable
)
?
;
let
immutable
=
self
.active
.register
(
mutable
)
?
;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if
offload
{
if
offload
{
immutable
.enqueue_offload
(
0
)
.await
.unwrap
();
if
let
Some
(
priority
)
=
immutable
.metadata
()
.offload_priority
()
{
immutable
.enqueue_offload
(
priority
)
.await
.unwrap
();
}
}
}
immutable_blocks
.push
(
immutable
);
immutable_blocks
.push
(
immutable
);
...
...
lib/llm/src/block_manager/state.rs
View file @
5d5080ba
...
@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
...
@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
,
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
,
disk_pool
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk_pool
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host_pool
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host_pool
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device_pool
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device_pool
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
local_block_set
:
NixlBlockSet
,
local_block_set
:
NixlBlockSet
,
remote_block_sets
:
RwLock
<
HashMap
<
WorkerID
,
HashMap
<
usize
,
RemoteBlocks
>>>
,
remote_block_sets
:
RwLock
<
HashMap
<
WorkerID
,
HashMap
<
usize
,
RemoteBlocks
>>>
,
...
@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
if
nixl_agent
.is_none
()
{
if
nixl_agent
.is_none
()
{
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
}
else
{
}
else
{
next_block_set_idx
+=
1
;
next_block_set_idx
+=
1
;
tracing
::
debug!
(
"Constructing disk pool."
);
tracing
::
debug!
(
"Constructing disk pool."
);
...
@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
}
}
else
{
}
else
{
tracing
::
debug!
(
"No disk layout provided; will not allocate disk blocks."
);
tracing
::
debug!
(
"No disk layout provided; will not allocate disk blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Create the host block pool if a host layout is provided
// Create the host block pool if a host layout is provided
...
@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
tracing
::
debug!
(
"No host layout provided; will not allocate host blocks."
);
tracing
::
debug!
(
"No host layout provided; will not allocate host blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Create the device block pool if a device layout is provided
// Create the device block pool if a device layout is provided
...
@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
tracing
::
debug!
(
"No device layout provided; will not allocate device blocks."
);
tracing
::
debug!
(
"No device layout provided; will not allocate device blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Finalize the local block set by adding NIXL metadata
// Finalize the local block set by adding NIXL metadata
...
@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
}
}
pub
fn
disk
(
&
self
)
->
Option
<&
BlockPool
<
DiskStorage
,
Metadata
>>
{
pub
fn
disk
(
&
self
)
->
Option
<&
BlockPool
<
DiskStorage
,
Metadata
>>
{
self
.disk_pool
.as_ref
()
.as_ref
()
self
.disk_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
self
.host_pool
.as_ref
()
.as_ref
()
self
.host_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
self
.device_pool
.as_ref
()
.as_ref
()
self
.device_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
...
...
lib/llm/src/block_manager/storage/disk.rs
View file @
5d5080ba
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
use
super
::
*
;
use
super
::
*
;
use
core
::
ffi
::
c_char
;
use
nix
::
fcntl
::{
fallocate
,
FallocateFlags
};
use
nix
::
fcntl
::{
fallocate
,
FallocateFlags
};
use
std
::
ffi
::
CString
;
use
std
::
ffi
::
CString
;
use
std
::
fs
::
File
;
use
std
::
fs
::
File
;
...
@@ -41,7 +42,7 @@ impl DiskStorage {
...
@@ -41,7 +42,7 @@ impl DiskStorage {
let
raw_fd
=
unsafe
{
let
raw_fd
=
unsafe
{
nix
::
libc
::
mkostemp
(
nix
::
libc
::
mkostemp
(
template_bytes
.as_mut_ptr
()
as
*
mut
i8
,
template_bytes
.as_mut_ptr
()
as
*
mut
c_char
,
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache.
// This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned.
// It also introduces the restriction that all accesses must be page-aligned.
...
@@ -80,6 +81,7 @@ impl DiskStorage {
...
@@ -80,6 +81,7 @@ impl DiskStorage {
impl
Drop
for
DiskStorage
{
impl
Drop
for
DiskStorage
{
// TODO: How robust is this actually?
// TODO: How robust is this actually?
fn
drop
(
&
mut
self
)
{
fn
drop
(
&
mut
self
)
{
self
.handles
.release
();
std
::
fs
::
remove_file
(
self
.file_name
.clone
())
.unwrap
();
std
::
fs
::
remove_file
(
self
.file_name
.clone
())
.unwrap
();
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment