Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
5d5080ba
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "c3bfbd2040032d9b5f6e52fa8e1bf7e327533e98"
Unverified
Commit
5d5080ba
authored
May 22, 2025
by
jthomson04
Committed by
GitHub
May 22, 2025
Browse files
feat: Various KVBM improvements (#1134)
parent
d3b0cae1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
429 additions
and
244 deletions
+429
-244
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+10
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+8
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+23
-11
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+66
-99
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+118
-73
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+163
-44
lib/llm/src/block_manager/offload/request.rs
lib/llm/src/block_manager/offload/request.rs
+18
-1
lib/llm/src/block_manager/pool/inactive.rs
lib/llm/src/block_manager/pool/inactive.rs
+4
-0
lib/llm/src/block_manager/pool/state.rs
lib/llm/src/block_manager/pool/state.rs
+3
-2
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+13
-13
lib/llm/src/block_manager/storage/disk.rs
lib/llm/src/block_manager/storage/disk.rs
+3
-1
No files found.
lib/llm/src/block_manager.rs
View file @
5d5080ba
...
@@ -192,11 +192,21 @@ mod tests {
...
@@ -192,11 +192,21 @@ mod tests {
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
fn
create_reference_block_manager
()
->
ReferenceBlockManager
{
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
let
worker_id
=
WORKER_ID
.fetch_add
(
1
,
Ordering
::
SeqCst
);
// Check if we're already in a Tokio runtime context
let
async_runtime
=
if
tokio
::
runtime
::
Handle
::
try_current
()
.is_ok
()
{
None
// If we're already in a runtime, don't create a new one
}
else
{
// Only create a new runtime if not already in one
Some
(
Arc
::
new
(
tokio
::
runtime
::
Runtime
::
new
()
.unwrap
()))
};
let
config
=
KvBlockManagerConfig
::
builder
()
let
config
=
KvBlockManagerConfig
::
builder
()
.runtime
(
.runtime
(
KvManagerRuntimeConfig
::
builder
()
KvManagerRuntimeConfig
::
builder
()
.worker_id
(
worker_id
)
.worker_id
(
worker_id
)
.enable_nixl
()
.enable_nixl
()
.async_runtime
(
async_runtime
)
.build
()
.build
()
.unwrap
(),
.unwrap
(),
)
)
...
...
lib/llm/src/block_manager/block.rs
View file @
5d5080ba
...
@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
...
@@ -82,6 +82,10 @@ pub trait BlockMetadata: Default + std::fmt::Debug + Clone + Ord + Send + Sync +
/// Resets the metadata to the default value
/// Resets the metadata to the default value
/// If called, the [BlockMetadata::is_reset()] should return true
/// If called, the [BlockMetadata::is_reset()] should return true
fn
reset_metadata
(
&
mut
self
);
fn
reset_metadata
(
&
mut
self
);
/// The offload priority of the block. Higher priority blocks are offloaded first.
/// If the block should not be offloaded, return None.
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
;
}
}
/// Marker trait for types that are mutable blocks
/// Marker trait for types that are mutable blocks
...
@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
...
@@ -536,6 +540,10 @@ impl BlockMetadata for BasicMetadata {
fn
reset_metadata
(
&
mut
self
)
{
fn
reset_metadata
(
&
mut
self
)
{
self
.priority
=
0
;
self
.priority
=
0
;
}
}
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
{
Some
(
self
.priority
as
u64
)
}
}
}
/// Collection that holds shared storage and layout
/// Collection that holds shared storage and layout
#[derive(Debug)]
#[derive(Debug)]
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
5d5080ba
...
@@ -133,7 +133,7 @@ where
...
@@ -133,7 +133,7 @@ where
pub
trait
WriteTo
<
Target
>
{
pub
trait
WriteTo
<
Target
>
{
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
Target
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
;
)
->
Result
<
(),
TransferError
>
;
...
@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
...
@@ -143,31 +143,44 @@ pub trait WriteTo<Target> {
/// Returns a future that will complete when the transfer is complete.
/// Returns a future that will complete when the transfer is complete.
fn
nixl_write_to
(
fn
nixl_write_to
(
&
self
,
&
self
,
dst
:
&
mut
Target
,
dst
:
&
mut
Vec
<
Target
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
;
}
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
RB
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
Vec
<
Arc
<
RB
>>
where
where
RB
:
WriteToStrategy
<
WB
>
+
Local
,
RB
:
WriteToStrategy
<
WB
>
+
Local
,
{
{
fn
write_to
(
fn
write_to
(
&
self
,
&
self
,
dst
:
&
mut
WB
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
(),
TransferError
>
{
)
->
Result
<
(),
TransferError
>
{
match
Self
::
write_to_strategy
()
{
match
RB
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
memcpy
::
copy_block
(
self
,
dst
),
TransferStrategy
::
Memcpy
=>
{
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
memcpy
::
copy_block
(
src
.as_ref
(),
dst
)
?
;
}
Ok
(())
}
TransferStrategy
::
CudaAsyncH2D
TransferStrategy
::
CudaAsyncH2D
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2H
|
TransferStrategy
::
CudaAsyncD2D
=>
{
|
TransferStrategy
::
CudaAsyncD2D
=>
{
cuda
::
copy_block
(
self
,
dst
,
ctx
.stream
()
.as_ref
(),
RB
::
write_to_strategy
())
for
(
src
,
dst
)
in
self
.iter
()
.zip
(
dst
.iter_mut
())
{
cuda
::
copy_block
(
src
.as_ref
(),
dst
,
ctx
.stream
()
.as_ref
(),
RB
::
write_to_strategy
(),
)
?
;
}
Ok
(())
}
}
TransferStrategy
::
NixlWrite
=>
{
TransferStrategy
::
NixlWrite
=>
{
std
::
mem
::
drop
(
nixl
::
write_block_to
(
self
,
dst
,
ctx
,
notify
)
?
);
std
::
mem
::
drop
(
nixl
::
write_block
s
_to
(
self
,
dst
,
ctx
,
notify
)
?
);
Ok
(())
Ok
(())
}
}
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
...
@@ -175,17 +188,16 @@ where
...
@@ -175,17 +188,16 @@ where
RB
::
write_to_strategy
()
RB
::
write_to_strategy
()
))),
))),
}
}
// dispatch_copy_to(self, dst, self.transfer_context())
}
}
fn
nixl_write_to
(
fn
nixl_write_to
(
&
self
,
&
self
,
dst
:
&
mut
WB
,
dst
:
&
mut
Vec
<
WB
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
if
let
TransferStrategy
::
NixlWrite
=
RB
::
write_to_strategy
()
{
if
let
TransferStrategy
::
NixlWrite
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_block_to
(
self
,
dst
,
ctx
,
notify
)
?
)
Ok
(
nixl
::
write_block
s
_to
(
self
,
dst
,
ctx
,
notify
)
?
)
}
else
{
}
else
{
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Expected NIXL transfer strategy, got: {:?}"
,
"Expected NIXL transfer strategy, got: {:?}"
,
...
...
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
5d5080ba
...
@@ -18,16 +18,14 @@ use super::*;
...
@@ -18,16 +18,14 @@ use super::*;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
ops
::
Range
;
use
std
::
task
::
Poll
;
use
std
::
task
::
Poll
;
/// Copy a block from a source to a destination using CUDA memcpy
fn
append_xfer_request
<
Source
,
Destination
>
(
pub
fn
write_block_to
<
'a
,
Source
,
Destination
>
(
src
:
&
Arc
<
Source
>
,
src
:
&
'a
Source
,
dst
:
&
mut
Destination
,
dst
:
&
'a
mut
Destination
,
src_dl
:
&
mut
XferDescList
,
ctx
:
Arc
<
TransferContext
>
,
dst_dl
:
&
mut
XferDescList
,
notify
:
Option
<
String
>
,
)
->
Result
<
()
>
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
where
Source
:
BlockDataProvider
,
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
Destination
:
BlockDataProviderMut
,
...
@@ -36,17 +34,6 @@ where
...
@@ -36,17 +34,6 @@ where
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
if
src_data
.is_fully_contiguous
()
&&
dst_data
.is_fully_contiguous
()
{
// Keep the arc to use in the returned future.
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
src_desc
=
src_data
.block_view
()
?
.as_nixl_descriptor
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
let
dst_desc
=
dst_data
.block_view_mut
()
?
.as_nixl_descriptor_mut
();
...
@@ -64,45 +51,42 @@ where
...
@@ -64,45 +51,42 @@ where
)
?
;
)
?
;
}
}
let
xfer_req
=
nixl_agent
Ok
(())
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
)
.unwrap
();
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
xfer_args
.set_has_notification
(
true
)
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
// Return a future that completes when the transfer is complete.
// TODO: How efficient is this? Can we do better?
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
let
nixl_agent
=
nixl_agent_arc
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
}
else
{
Poll
::
Pending
}
})))
}
else
{
}
else
{
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
assert_eq!
(
src_data
.num_layers
(),
dst_data
.num_layers
());
write_layers_to
(
0
..
src_data
.num_layers
(),
src
,
dst
,
ctx
,
notify
)
for
layer_idx
in
0
..
src_data
.num_layers
()
{
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
}
}
Ok
(())
}
}
}
}
/// Copy a range of layers from a source to a destination using CUDA memcpy
/// Copy a block from a source to a destination using CUDA memcpy
pub
fn
write_layers_to
<
'a
,
Source
,
Destination
>
(
pub
fn
write_blocks_to
<
Source
,
Destination
>
(
layer_range
:
Range
<
usize
>
,
src
:
&
[
Arc
<
Source
>
],
src
:
&
'a
Source
,
dst
:
&
mut
[
Destination
],
dst
:
&
'a
mut
Destination
,
ctx
:
Arc
<
TransferContext
>
,
ctx
:
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
notify
:
Option
<
String
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
...
@@ -110,8 +94,10 @@ where
...
@@ -110,8 +94,10 @@ where
Source
:
BlockDataProvider
,
Source
:
BlockDataProvider
,
Destination
:
BlockDataProviderMut
,
Destination
:
BlockDataProviderMut
,
{
{
let
src_data
=
src
.block_data
(
private
::
PrivateToken
);
if
src
.is_empty
()
||
dst
.is_empty
()
{
let
dst_data
=
dst
.block_data_mut
(
private
::
PrivateToken
);
return
Ok
(
Box
::
new
(
std
::
future
::
ready
(())));
}
assert_eq!
(
src
.len
(),
dst
.len
());
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent_arc
=
ctx
.as_ref
()
.nixl_agent
();
let
nixl_agent
=
nixl_agent_arc
let
nixl_agent
=
nixl_agent_arc
...
@@ -119,44 +105,31 @@ where
...
@@ -119,44 +105,31 @@ where
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
.expect
(
"NIXL agent not found"
);
let
remote_worker_id
=
dst_data
.worker_id
.to_string
();
let
src_mem_type
=
src
let
mut
src_dl
=
XferDescList
::
new
(
src_data
.storage_type
()
.nixl_mem_type
())
?
;
.first
()
let
mut
dst_dl
=
XferDescList
::
new
(
dst_data
.storage_type
()
.nixl_mem_type
())
?
;
.unwrap
()
.block_data
(
private
::
PrivateToken
)
// #[cfg(debug_assertions)]
.storage_type
()
// {
.nixl_mem_type
();
// let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
let
dst_mem_type
=
dst
// Destination::StorageType,
.first
()
// >>::write_to_strategy();
.unwrap
()
// assert_eq!(strategy, expected_strategy);
.block_data
(
private
::
PrivateToken
)
// }
.storage_type
()
.nixl_mem_type
();
for
layer_idx
in
layer_range
{
for
outer_idx
in
0
..
src_data
.num_outer_dims
()
{
let
mut
src_dl
=
XferDescList
::
new
(
src_mem_type
)
?
;
let
src_view
=
src_data
.layer_view
(
layer_idx
,
outer_idx
)
?
;
let
mut
dst_dl
=
XferDescList
::
new
(
dst_mem_type
)
?
;
let
mut
dst_view
=
dst_data
.layer_view_mut
(
layer_idx
,
outer_idx
)
?
;
for
(
src
,
dst
)
in
src
.iter
()
.zip
(
dst
.iter_mut
())
{
debug_assert_eq!
(
src_view
.size
(),
dst_view
.size
());
append_xfer_request
(
src
,
dst
,
&
mut
src_dl
,
&
mut
dst_dl
)
?
;
let
src_desc
=
src_view
.as_nixl_descriptor
();
let
dst_desc
=
dst_view
.as_nixl_descriptor_mut
();
unsafe
{
src_dl
.add_desc
(
src_desc
.as_ptr
()
as
usize
,
src_desc
.size
(),
src_desc
.device_id
(),
)
?
;
dst_dl
.add_desc
(
dst_desc
.as_ptr
()
as
usize
,
dst_desc
.size
(),
dst_desc
.device_id
(),
)
?
;
}
}
}
}
debug_assert!
(
!
src_dl
.has_overlaps
()
?
&&
!
dst_dl
.has_overlaps
()
?
);
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
if
let
Some
(
notify
)
=
notify
{
if
let
Some
(
notify
)
=
notify
{
...
@@ -164,14 +137,6 @@ where
...
@@ -164,14 +137,6 @@ where
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
xfer_args
.set_notification_message
(
notify
.as_bytes
())
?
;
}
}
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
remote_worker_id
,
Some
(
&
xfer_args
),
)
?
;
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
let
_
=
nixl_agent
.post_xfer_req
(
&
xfer_req
,
Some
(
&
xfer_args
))
?
;
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
Ok
(
Box
::
new
(
poll_fn
(
move
|
_
cx
|
{
...
@@ -179,6 +144,8 @@ where
...
@@ -179,6 +144,8 @@ where
.as_ref
()
.as_ref
()
.as_ref
()
.as_ref
()
.expect
(
"NIXL agent not found"
);
.expect
(
"NIXL agent not found"
);
// The nixl agent returns true if the transfer is still in progress.
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
if
!
nixl_agent
.get_xfer_status
(
&
xfer_req
)
.unwrap
()
{
Poll
::
Ready
(())
Poll
::
Ready
(())
}
else
{
}
else
{
...
...
lib/llm/src/block_manager/offload.rs
View file @
5d5080ba
...
@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
...
@@ -65,18 +65,20 @@ use std::collections::BTreeSet;
mod
pending
;
mod
pending
;
pub
mod
request
;
pub
mod
request
;
use
pending
::{
CudaTransferManager
,
DiskTransferManager
,
PendingTransfer
,
TransferManager
};
use
pending
::{
CudaTransferManager
,
DiskTransferManager
,
PendingTransfer
,
TransferBatcher
,
TransferManager
,
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
use
request
::{
BlockResult
,
OffloadRequest
,
OffloadRequestKey
,
OnboardRequest
};
// TODO: This should be dynamic
const
MAX_CONCURRENT_TRANSFERS
:
usize
=
4
;
const
MAX_
OFFLOAD_STREAM_DEPTH
:
usize
=
4
;
const
MAX_
TRANSFER_BATCH_SIZE
:
usize
=
16
;
/// The offload manager handles all block transfers between different cache levels.
/// The offload manager handles all block transfers between different cache levels.
pub
struct
OffloadManager
<
Metadata
:
BlockMetadata
>
{
pub
struct
OffloadManager
<
Metadata
:
BlockMetadata
>
{
// Handles to the device, host, and disk pools.
// Handles to the device, host, and disk pools.
disk
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
/// Queue of offloading requests.
/// Queue of offloading requests.
device_offload_tx
:
mpsc
::
UnboundedSender
<
OffloadRequest
<
DeviceStorage
,
Metadata
>>
,
device_offload_tx
:
mpsc
::
UnboundedSender
<
OffloadRequest
<
DeviceStorage
,
Metadata
>>
,
...
@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
...
@@ -92,9 +94,9 @@ pub struct OffloadManager<Metadata: BlockMetadata> {
impl
<
Metadata
:
BlockMetadata
>
OffloadManager
<
Metadata
>
{
impl
<
Metadata
:
BlockMetadata
>
OffloadManager
<
Metadata
>
{
pub
fn
new
(
pub
fn
new
(
disk
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
async_rt_handle
:
Handle
,
async_rt_handle
:
Handle
,
)
->
Result
<
Arc
<
Self
>>
{
)
->
Result
<
Arc
<
Self
>>
{
...
@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -129,17 +131,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
host_clone
=
this
.host
.clone
();
let
host_clone
=
this
.host
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
offload_worker
(
let
res
=
OffloadManager
::
offload_worker
(
device_clone
,
device_clone
,
host_clone
,
host_clone
,
device_offload_rx
,
device_offload_rx
,
Arc
::
new
(
Cuda
Transfer
Manag
er
::
new
(
Arc
::
new
(
Transfer
Batch
er
::
new
(
device_offload_transfer_ctx
,
CudaTransferManager
::
new
(
device_offload_transfer_ctx
,
MAX_CONCURRENT_TRANSFERS
),
MAX_
OFFLOAD_STREAM_DEPTH
,
MAX_
TRANSFER_BATCH_SIZE
,
)),
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
});
});
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
let
transfer_ctx
=
Arc
::
new
(
TransferContext
::
new
(
...
@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -152,17 +154,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
disk_clone
=
this
.disk
.clone
();
let
disk_clone
=
this
.disk
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
offload_worker
(
let
res
=
OffloadManager
::
offload_worker
(
host_clone
,
host_clone
,
disk_clone
,
disk_clone
,
host_offload_rx
,
host_offload_rx
,
Arc
::
new
(
Disk
Transfer
Manag
er
::
new
(
Arc
::
new
(
Transfer
Batch
er
::
new
(
transfer_ctx_clone
,
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
)
,
MAX_
OFFLOAD_STREAM_DEPTH
,
MAX_
TRANSFER_BATCH_SIZE
,
)),
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Offload worker terminated: {:?}"
,
res
);
});
});
// Host -> Device onboarding
// Host -> Device onboarding
...
@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -170,14 +172,17 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
onboard_worker
(
let
res
=
OffloadManager
::
onboard_worker
(
host_clone
,
host_clone
,
device_clone
,
device_clone
,
host_onboard_rx
,
host_onboard_rx
,
Arc
::
new
(
CudaTransferManager
::
new
(
transfer_ctx_clone
,
16384
)),
Arc
::
new
(
TransferBatcher
::
new
(
CudaTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
MAX_TRANSFER_BATCH_SIZE
,
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
});
});
// Disk -> Device onboarding
// Disk -> Device onboarding
...
@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -185,31 +190,34 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
let
device_clone
=
this
.device
.clone
();
let
device_clone
=
this
.device
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
let
transfer_ctx_clone
=
transfer_ctx
.clone
();
async_rt_handle
.spawn
(
async
move
{
async_rt_handle
.spawn
(
async
move
{
OffloadManager
::
onboard_worker
(
let
res
=
OffloadManager
::
onboard_worker
(
disk_clone
,
disk_clone
,
device_clone
,
device_clone
,
disk_onboard_rx
,
disk_onboard_rx
,
Arc
::
new
(
DiskTransferManager
::
new
(
transfer_ctx_clone
,
16384
)),
Arc
::
new
(
TransferBatcher
::
new
(
DiskTransferManager
::
new
(
transfer_ctx_clone
,
MAX_CONCURRENT_TRANSFERS
),
MAX_TRANSFER_BATCH_SIZE
,
)),
)
)
.await
.await
;
.unwrap
()
tracing
::
warn!
(
"Onboard worker terminated: {:?}"
,
res
);
});
});
Ok
(
this_clone
)
Ok
(
this_clone
)
}
}
async
fn
offload_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
async
fn
offload_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
source_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Source
,
Metadata
>>>
,
source_pool
:
Option
<
Arc
<
BlockPool
<
Source
,
Metadata
>>>
,
target_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
mut
offload_rx
:
mpsc
::
UnboundedReceiver
<
OffloadRequest
<
Source
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
_arc
.is_none
()
||
target_pool
_arc
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
}
}
let
source_pool
=
source_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
source_pool
=
source_pool
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
let
mut
queue
=
BTreeSet
::
new
();
let
mut
queue
=
BTreeSet
::
new
();
...
@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -252,7 +260,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
// Allocate a block from the host pool.
// Allocate a block from the host pool.
// TODO: The most likely error here is that the
hos
t pool is full.
// TODO: The most likely error here is that the
targe
t pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let
target_blocks
=
match
target_pool
.allocate_blocks
(
1
)
.await
{
let
target_blocks
=
match
target_pool
.allocate_blocks
(
1
)
.await
{
Ok
(
blocks
)
=>
blocks
,
Ok
(
blocks
)
=>
blocks
,
...
@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -263,11 +271,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
if
let
Some
(
target_block
)
=
target_blocks
.into_iter
()
.next
()
{
if
let
Some
(
target_block
)
=
target_blocks
.into_iter
()
.next
()
{
transfer_manager
transfer_manager
.
begin
_transfer
(
PendingTransfer
::
new
(
.
enqueue
_transfer
(
PendingTransfer
::
new
(
vec!
[
block
],
vec!
[
block
],
vec!
[
target_block
],
vec!
[
target_block
],
None
,
None
,
target_pool
_arc
.clone
(),
target_pool
.clone
(),
))
))
.await
?
;
.await
?
;
}
}
...
@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -282,16 +290,16 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
async
fn
onboard_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
async
fn
onboard_worker
<
Source
:
Storage
,
Target
:
Storage
>
(
source_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Source
,
Metadata
>>>
,
source_pool
:
Option
<
Arc
<
BlockPool
<
Source
,
Metadata
>>>
,
target_pool
_arc
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>>
,
target_pool
:
Option
<
Arc
<
BlockPool
<
Target
,
Metadata
>>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
mut
onboard_rx
:
mpsc
::
UnboundedReceiver
<
OnboardRequest
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
transfer_manager
:
Arc
<
dyn
TransferManager
<
Source
,
Target
,
Metadata
>>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
if
source_pool
_arc
.is_none
()
||
target_pool
_arc
.is_none
()
{
if
source_pool
.is_none
()
||
target_pool
.is_none
()
{
return
Ok
(());
return
Ok
(());
}
}
let
target_pool
=
target_pool
_arc
.as_ref
()
.as_ref
()
.unwrap
();
let
target_pool
=
target_pool
.as_ref
()
.unwrap
();
// Loop on incoming requests
// Loop on incoming requests
while
let
Some
(
request
)
=
onboard_rx
.recv
()
.await
{
while
let
Some
(
request
)
=
onboard_rx
.recv
()
.await
{
...
@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -311,11 +319,11 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
.collect
();
.collect
();
transfer_manager
transfer_manager
.
begin
_transfer
(
PendingTransfer
::
new
(
.
enqueue
_transfer
(
PendingTransfer
::
new
(
sources
,
sources
,
target_blocks
,
target_blocks
,
Some
(
request
.response_tx
),
Some
(
request
.response_tx
),
target_pool
_arc
.clone
(),
target_pool
.clone
(),
))
))
.await
?
;
.await
?
;
}
}
...
@@ -478,9 +486,9 @@ mod tests {
...
@@ -478,9 +486,9 @@ mod tests {
const
BLOCK_SIZE
:
usize
=
4
;
const
BLOCK_SIZE
:
usize
=
4
;
type
DevicePool
=
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
BasicMetadata
>>>
;
type
DevicePool
=
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
BasicMetadata
>>>
;
type
HostPool
=
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
BasicMetadata
>>>
;
type
HostPool
=
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
BasicMetadata
>>>
;
type
DiskPool
=
Arc
<
Option
<
BlockPool
<
DiskStorage
,
BasicMetadata
>>>
;
type
DiskPool
=
Option
<
Arc
<
BlockPool
<
DiskStorage
,
BasicMetadata
>>>
;
lazy_static
::
lazy_static!
{
lazy_static
::
lazy_static!
{
static
ref
NIXL_AGENT
:
Arc
<
Option
<
NixlAgent
>>
=
{
static
ref
NIXL_AGENT
:
Arc
<
Option
<
NixlAgent
>>
=
{
...
@@ -521,16 +529,18 @@ mod tests {
...
@@ -521,16 +529,18 @@ mod tests {
device
.nixl_register
(
agent
,
None
)
?
;
device
.nixl_register
(
agent
,
None
)
?
;
let
device_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
device
,
42
,
0
)
?
.into_blocks
()
?
;
let
device_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
device
,
42
,
0
)
?
.into_blocks
()
?
;
let
device_pool
=
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
device_blocks
)
.build
()
?
));
let
device_pool
=
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
device_blocks
)
.build
()
?
,
));
let
host_pool
=
if
let
Some
(
host_blocks
)
=
host_blocks
{
let
host_pool
=
if
let
Some
(
host_blocks
)
=
host_blocks
{
config
.num_blocks
=
host_blocks
;
config
.num_blocks
=
host_blocks
;
let
mut
host
=
FullyContiguous
::
allocate
(
config
.clone
(),
&
PinnedAllocator
::
default
())
?
;
let
mut
host
=
FullyContiguous
::
allocate
(
config
.clone
(),
&
PinnedAllocator
::
default
())
?
;
host
.nixl_register
(
agent
,
None
)
?
;
host
.nixl_register
(
agent
,
None
)
?
;
let
host_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
host
,
42
,
0
)
?
.into_blocks
()
?
;
let
host_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
host
,
42
,
0
)
?
.into_blocks
()
?
;
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
host_blocks
)
.build
()
?
))
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
host_blocks
)
.build
()
?
))
}
else
{
}
else
{
Arc
::
new
(
None
)
None
};
};
let
disk_pool
=
if
let
Some
(
disk_blocks
)
=
disk_blocks
{
let
disk_pool
=
if
let
Some
(
disk_blocks
)
=
disk_blocks
{
...
@@ -538,9 +548,9 @@ mod tests {
...
@@ -538,9 +548,9 @@ mod tests {
let
mut
disk
=
FullyContiguous
::
allocate
(
config
,
&
DiskAllocator
)
?
;
let
mut
disk
=
FullyContiguous
::
allocate
(
config
,
&
DiskAllocator
)
?
;
disk
.nixl_register
(
agent
,
None
)
?
;
disk
.nixl_register
(
agent
,
None
)
?
;
let
disk_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
disk
,
42
,
0
)
?
.into_blocks
()
?
;
let
disk_blocks
=
Blocks
::
<
_
,
BasicMetadata
>
::
new
(
disk
,
42
,
0
)
?
.into_blocks
()
?
;
Arc
::
new
(
Some
(
BlockPool
::
builder
()
.blocks
(
disk_blocks
)
.build
()
?
))
Some
(
Arc
::
new
(
BlockPool
::
builder
()
.blocks
(
disk_blocks
)
.build
()
?
))
}
else
{
}
else
{
Arc
::
new
(
None
)
None
};
};
let
async_rt_handle
=
Handle
::
current
();
let
async_rt_handle
=
Handle
::
current
();
...
@@ -558,7 +568,7 @@ mod tests {
...
@@ -558,7 +568,7 @@ mod tests {
/// Create a block in the 'RESET' state.
/// Create a block in the 'RESET' state.
async
fn
get_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
get_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
pool
.allocate_blocks
(
1
)
pool
.allocate_blocks
(
1
)
.await
?
.await
?
...
@@ -569,7 +579,7 @@ mod tests {
...
@@ -569,7 +579,7 @@ mod tests {
/// Create a block in the 'PARTIAL' state.
/// Create a block in the 'PARTIAL' state.
async
fn
partial_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
partial_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
token
:
u32
,
token
:
u32
,
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
let
mut
block
=
get_block
(
pool
)
.await
?
;
let
mut
block
=
get_block
(
pool
)
.await
?
;
...
@@ -580,7 +590,7 @@ mod tests {
...
@@ -580,7 +590,7 @@ mod tests {
/// Create a block in the 'COMPLETED' state.
/// Create a block in the 'COMPLETED' state.
async
fn
completed_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
async
fn
completed_block
<
S
:
Storage
,
Metadata
:
BlockMetadata
>
(
pool
:
&
BlockPool
<
S
,
Metadata
>
,
pool
:
&
Arc
<
BlockPool
<
S
,
Metadata
>
>
,
tokens
:
[
u32
;
BLOCK_SIZE
],
tokens
:
[
u32
;
BLOCK_SIZE
],
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
)
->
Result
<
MutableBlock
<
S
,
Metadata
>>
{
let
mut
block
=
get_block
(
pool
)
.await
?
;
let
mut
block
=
get_block
(
pool
)
.await
?
;
...
@@ -666,7 +676,7 @@ mod tests {
...
@@ -666,7 +676,7 @@ mod tests {
async
fn
test_offload_invalid_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_invalid_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
// Check blocks in the 'RESET' state.
// Check blocks in the 'RESET' state.
let
immutable_block
=
ImmutableBlock
::
new
(
Arc
::
new
(
get_block
(
device_pool
)
.await
?
));
let
immutable_block
=
ImmutableBlock
::
new
(
Arc
::
new
(
get_block
(
device_pool
)
.await
?
));
...
@@ -699,8 +709,8 @@ mod tests {
...
@@ -699,8 +709,8 @@ mod tests {
async
fn
test_offload_registered_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_registered_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
// Create a block and register it with the offload manager
// Create a block and register it with the offload manager
let
block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
...
@@ -742,8 +752,8 @@ mod tests {
...
@@ -742,8 +752,8 @@ mod tests {
async
fn
test_no_host_blocks_available
()
->
Result
<
()
>
{
async
fn
test_no_host_blocks_available
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
host_blocks
=
host_pool
.allocate_blocks
(
4
)
.await
?
;
let
host_blocks
=
host_pool
.allocate_blocks
(
4
)
.await
?
;
assert_eq!
(
host_blocks
.len
(),
4
);
assert_eq!
(
host_blocks
.len
(),
4
);
...
@@ -790,8 +800,8 @@ mod tests {
...
@@ -790,8 +800,8 @@ mod tests {
async
fn
test_onboard
()
->
Result
<
()
>
{
async
fn
test_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
// Allocate and fill a block on the host.
// Allocate and fill a block on the host.
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
...
@@ -844,8 +854,8 @@ mod tests {
...
@@ -844,8 +854,8 @@ mod tests {
async
fn
test_offload_onboard
()
->
Result
<
()
>
{
async
fn
test_offload_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_device_block
=
device_pool
let
immutable_device_block
=
device_pool
...
@@ -913,8 +923,8 @@ mod tests {
...
@@ -913,8 +923,8 @@ mod tests {
async
fn
test_onboard_err_handling
()
->
Result
<
()
>
{
async
fn
test_onboard_err_handling
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_host_block
=
host_pool
let
immutable_host_block
=
host_pool
...
@@ -942,7 +952,7 @@ mod tests {
...
@@ -942,7 +952,7 @@ mod tests {
async
fn
test_offload_onboard_no_host_blocks
()
->
Result
<
()
>
{
async
fn
test_offload_onboard_no_host_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
device_block
=
completed_block
(
device_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_device_block
=
device_pool
let
immutable_device_block
=
device_pool
...
@@ -961,8 +971,8 @@ mod tests {
...
@@ -961,8 +971,8 @@ mod tests {
async
fn
test_offload_disk
()
->
Result
<
()
>
{
async
fn
test_offload_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
))
?
;
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
))
?
;
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
host_block
=
completed_block
(
host_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_host_block
=
host_pool
let
immutable_host_block
=
host_pool
...
@@ -996,8 +1006,8 @@ mod tests {
...
@@ -996,8 +1006,8 @@ mod tests {
async
fn
test_onboard_disk
()
->
Result
<
()
>
{
async
fn
test_onboard_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
))
?
;
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
))
?
;
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
disk_block
=
completed_block
(
disk_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
disk_block
=
completed_block
(
disk_pool
,
[
0
,
1
,
2
,
3
])
.await
?
;
let
immutable_disk_block
=
disk_pool
let
immutable_disk_block
=
disk_pool
...
@@ -1032,9 +1042,9 @@ mod tests {
...
@@ -1032,9 +1042,9 @@ mod tests {
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
build_pools
(
8
,
Some
(
8
),
Some
(
8
))
?
;
build_pools
(
8
,
Some
(
8
),
Some
(
8
))
?
;
let
disk_pool
=
disk_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.
as_ref
()
.
unwrap
();
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
mut
host_blocks
=
Vec
::
new
();
let
mut
host_blocks
=
Vec
::
new
();
...
@@ -1076,4 +1086,39 @@ mod tests {
...
@@ -1076,4 +1086,39 @@ mod tests {
Ok
(())
Ok
(())
}
}
#[tokio::test]
async
fn
test_transfer_batcher
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
,
None
,
Some
(
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
),
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
mut
disk_blocks
=
Vec
::
new
();
for
i
in
0
..
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
{
disk_blocks
.push
(
completed_block
(
disk_pool
,
[
i
as
u32
;
4
])
.await
?
);
}
let
immutable_disk_blocks
=
disk_pool
.register_blocks
(
disk_blocks
)
.await
?
;
let
device_blocks
=
offload_manager
.onboard
(
immutable_disk_blocks
.clone
())
.await
?
;
assert_eq!
(
device_blocks
.len
(),
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
);
for
device_block
in
&
device_blocks
{
let
blocks
=
device_pool
.match_sequence_hashes
(
vec!
[
device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
blocks
.len
(),
1
);
compare_block_contents
(
&
blocks
[
0
],
device_block
)
?
;
}
Ok
(())
}
}
}
lib/llm/src/block_manager/offload/pending.rs
View file @
5d5080ba
...
@@ -33,11 +33,12 @@
...
@@ -33,11 +33,12 @@
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//!
//! ## Workflow
//! ## Workflow
//! 1. A transfer request is made by calling [`TransferManager::
begin
_transfer`]
//! 1. A transfer request is made by calling [`TransferManager::
enqueue
_transfer`]
//! 2. [`TransferManager::
begin
_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 2. [`TransferManager::
enqueue
_transfer`] performs the transfer, and enqueues relevant data into a bounded channel.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 3. A worker thread (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
//! 4. After a transfer is complete, the worker thread registers the blocks with the target pool, and returns the registered blocks to the caller.
use
std
::
marker
::
PhantomData
;
use
std
::
pin
::
Pin
;
use
std
::
pin
::
Pin
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
std
::
thread
::
spawn
;
...
@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
...
@@ -55,7 +56,7 @@ use crate::block_manager::BlockPool;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
cudarc
::
driver
::{
sys
::
CUevent_flags
,
CudaEvent
};
use
futures
::{
future
::
join_all
,
stream
::
FuturesUnordered
,
StreamExt
};
use
futures
::{
stream
::
FuturesUnordered
,
StreamExt
};
use
super
::
BlockResult
;
use
super
::
BlockResult
;
...
@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
...
@@ -68,7 +69,7 @@ pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMeta
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
/// The target pool that will receive the registered block.
/// The target pool that will receive the registered block.
target_
registration_
pool
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>
>
,
target_pool
:
Arc
<
BlockPool
<
Target
,
Metadata
>>
,
}
}
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
...
@@ -78,30 +79,34 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
...
@@ -78,30 +79,34 @@ impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
sources
:
Vec
<
Arc
<
MutableBlock
<
Source
,
Metadata
>>>
,
sources
:
Vec
<
Arc
<
MutableBlock
<
Source
,
Metadata
>>>
,
targets
:
Vec
<
MutableBlock
<
Target
,
Metadata
>>
,
targets
:
Vec
<
MutableBlock
<
Target
,
Metadata
>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
completion_indicator
:
Option
<
oneshot
::
Sender
<
BlockResult
<
Target
,
Metadata
>>>
,
target_
registration_
pool
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>
>
,
target_pool
:
Arc
<
BlockPool
<
Target
,
Metadata
>>
,
)
->
Self
{
)
->
Self
{
assert_eq!
(
sources
.len
(),
targets
.len
());
Self
{
Self
{
sources
,
sources
,
targets
,
targets
,
completion_indicator
,
completion_indicator
,
target_
registration_
pool
,
target_pool
,
}
}
}
}
fn
handle_complete
(
self
)
->
Result
<
()
>
{
fn
handle_complete
(
self
)
->
Result
<
()
>
{
let
Self
{
let
Self
{
targets
,
sources
,
target_registration_pool
,
mut
targets
,
target_pool
,
completion_indicator
,
completion_indicator
,
..
..
}
=
self
;
}
=
self
;
if
let
Some
(
target_registration_pool
)
=
target_registration_pool
.as_ref
()
{
for
(
source
,
target
)
in
sources
.iter
()
.zip
(
targets
.iter_mut
())
{
let
blocks
=
target_registration_pool
.register_blocks_blocking
(
targets
)
?
;
transfer_metadata
(
source
,
target
)
?
;
}
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
completion_indicator
.send
(
Ok
(
blocks
))
?
;
completion_indicator
.send
(
Ok
(
blocks
))
?
;
}
}
}
Ok
(())
Ok
(())
...
@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
...
@@ -134,7 +139,7 @@ pub trait TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetad
Send
+
Sync
Send
+
Sync
{
{
/// Begin a transfer. Blocks if the pending queue is full.
/// Begin a transfer. Blocks if the pending queue is full.
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
;
)
->
Result
<
()
>
;
...
@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
...
@@ -148,16 +153,24 @@ pub struct CudaTransferManager<Source: Storage, Target: Storage, Metadata: Block
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
CudaTransferManager
<
Source
,
Target
,
Metadata
>
{
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_depth
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_concurrent_transfers
:
usize
)
->
Self
{
let
(
tx
,
mut
rx
)
=
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
mpsc
::
channel
::
<
(
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
CudaEvent
)
>
(
max_depth
);
max_concurrent_transfers
,
);
spawn
(
move
||
{
spawn
(
move
||
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
while
let
Some
((
pending_transfer
,
event
))
=
rx
.blocking_recv
()
{
// Wait for the event.
// Wait for the event.
event
.synchronize
()
?
;
event
.synchronize
()
?
;
// Only finalize the transfer after the event is signaled.
// Only finalize the transfer after the event is signaled.
pending_transfer
.handle_complete
()
?
;
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
// The only case where this can fail is if the progress engine is shutdown.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
}
}
Ok
::
<
(),
anyhow
::
Error
>
(())
Ok
::
<
(),
anyhow
::
Error
>
(())
});
});
...
@@ -183,18 +196,15 @@ where
...
@@ -183,18 +196,15 @@ where
// Check that the target block is writable.
// Check that the target block is writable.
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
{
{
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
for
(
source
,
target
)
in
pending_transfer
pending_transfer
.sources
.write_to
(
.sources
&
mut
pending_transfer
.targets
,
.iter
()
None
,
.zip
(
pending_transfer
.targets
.iter_mut
())
self
.transfer_ctx
.clone
(),
{
)
?
;
transfer_metadata
(
source
,
target
)
?
;
source
.write_to
(
target
,
None
,
self
.transfer_ctx
.clone
())
?
;
}
// Use a cuda event to record the completion of the transfers.
// Use a cuda event to record the completion of the transfers.
let
event
=
self
let
event
=
self
...
@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
...
@@ -218,7 +228,7 @@ pub struct DiskTransferManager {
}
}
impl
DiskTransferManager
{
impl
DiskTransferManager
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_
size
:
usize
)
->
Self
{
pub
fn
new
(
transfer_ctx
:
Arc
<
TransferContext
>
,
max_
concurrent_transfers
:
usize
)
->
Self
{
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
let
(
futures_tx
,
mut
futures_rx
)
=
mpsc
::
channel
(
1
);
tokio
::
spawn
(
async
move
{
tokio
::
spawn
(
async
move
{
...
@@ -230,7 +240,7 @@ impl DiskTransferManager {
...
@@ -230,7 +240,7 @@ impl DiskTransferManager {
tokio
::
select!
{
tokio
::
select!
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
Some
(
future
)
=
futures_rx
.recv
()
=>
{
// If we're at max size, block the worker thread on the next() call until we have capacity.
// If we're at max size, block the worker thread on the next() call until we have capacity.
while
pending_transfers
.len
()
>=
max_
size
{
while
pending_transfers
.len
()
>=
max_
concurrent_transfers
{
pending_transfers
.next
()
.await
;
pending_transfers
.next
()
.await
;
}
}
// Once we have capacity, push the new future onto the queue.
// Once we have capacity, push the new future onto the queue.
...
@@ -267,26 +277,26 @@ where
...
@@ -267,26 +277,26 @@ where
// Check that the target block is writable.
// Check that the target block is writable.
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
MutableBlock
<
Target
,
Metadata
>
:
WritableBlock
<
StorageType
=
Target
>
,
{
{
async
fn
begin
_transfer
(
async
fn
enqueue
_transfer
(
&
self
,
&
self
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
mut
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
let
futures
=
pending_transfer
let
future
=
pending_transfer
.sources
.nixl_write_to
(
.sources
&
mut
pending_transfer
.targets
,
.iter
()
None
,
.zip
(
pending_transfer
.targets
.iter_mut
())
self
.transfer_ctx
.clone
(),
.map
(|(
source
,
target
)|
{
)
?
;
transfer_metadata
(
source
,
target
)
.unwrap
();
// Initiate the transfer, and get a future indicating completion.
source
.nixl_write_to
(
target
,
None
,
self
.transfer_ctx
.clone
())
.unwrap
()
})
.collect
::
<
Vec
<
_
>>
();
let
completion_future
=
async
move
{
let
completion_future
=
async
move
{
let
_
=
join_all
(
futures
)
.await
;
let
_
=
future
.await
;
pending_transfer
.handle_complete
()
.unwrap
();
match
pending_transfer
.handle_complete
()
{
Ok
(
_
)
=>
{}
Err
(
e
)
=>
{
// The only case where this can fail is if the progress engine is being shutdown.
// This is not a problem, so we can just ignore it.
tracing
::
warn!
(
"Error handling transfer completion: {:?}"
,
e
);
}
}
};
};
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
// Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
...
@@ -296,3 +306,112 @@ where
...
@@ -296,3 +306,112 @@ where
Ok
(())
Ok
(())
}
}
}
}
/// A transfer manager that enforces a max batch size for transfers.
pub
struct
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
,
_
phantom
:
PhantomData
<
(
Source
,
Target
,
Metadata
)
>
,
}
impl
<
Source
,
Target
,
Metadata
,
Manager
>
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
pub
fn
new
(
transfer_manager
:
Manager
,
max_transfer_batch_size
:
usize
)
->
Self
{
Self
{
transfer_manager
,
max_transfer_batch_size
,
_
phantom
:
PhantomData
,
}
}
}
#[async_trait]
impl
<
Source
,
Target
,
Metadata
,
Manager
>
TransferManager
<
Source
,
Target
,
Metadata
>
for
TransferBatcher
<
Source
,
Target
,
Metadata
,
Manager
>
where
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
,
Manager
:
TransferManager
<
Source
,
Target
,
Metadata
>
,
{
async
fn
enqueue_transfer
(
&
self
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
// If it's smaller than the max batch size, just enqueue it.
if
pending_transfer
.sources
.len
()
<
self
.max_transfer_batch_size
{
return
self
.transfer_manager
.enqueue_transfer
(
pending_transfer
)
.await
;
}
// Otherwise, we need to split the transfer into multiple smaller transfers.
let
PendingTransfer
{
mut
sources
,
mut
targets
,
completion_indicator
,
target_pool
,
}
=
pending_transfer
;
let
mut
indicators
=
Vec
::
new
();
while
!
sources
.is_empty
()
{
let
sources
=
sources
.drain
(
..
std
::
cmp
::
min
(
self
.max_transfer_batch_size
,
sources
.len
()))
.collect
();
let
targets
=
targets
.drain
(
..
std
::
cmp
::
min
(
self
.max_transfer_batch_size
,
targets
.len
()))
.collect
();
// If we have a completion indicator, we need to create a new one for each sub-transfer.
let
indicator
=
if
completion_indicator
.is_some
()
{
let
(
batch_tx
,
batch_rx
)
=
oneshot
::
channel
();
indicators
.push
(
batch_rx
);
Some
(
batch_tx
)
}
else
{
None
};
let
request
=
PendingTransfer
::
new
(
sources
,
targets
,
indicator
,
target_pool
.clone
());
// Enqueue our reduced transfer. This may block if the queue is full.
self
.transfer_manager
.enqueue_transfer
(
request
)
.await
?
;
}
if
let
Some
(
completion_indicator
)
=
completion_indicator
{
tokio
::
spawn
(
async
move
{
let
mut
results
=
Vec
::
new
();
for
indicator
in
indicators
.into_iter
()
{
// Await each sub-transfer, and append the results to our final results.
let
result
=
match
indicator
.await
.unwrap
()
{
Ok
(
result
)
=>
result
,
Err
(
e
)
=>
{
tracing
::
error!
(
"Error receiving transfer results: {:?}"
,
e
);
completion_indicator
.send
(
Err
(
e
))
.unwrap
();
return
;
}
};
results
.extend
(
result
);
}
// Send the final results to the top-level completion indicator.
completion_indicator
.send
(
Ok
(
results
))
.unwrap
();
});
}
Ok
(())
}
}
lib/llm/src/block_manager/offload/request.rs
View file @
5d5080ba
...
@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
...
@@ -20,12 +20,29 @@ use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
storage
::
Storage
;
use
crate
::
block_manager
::
storage
::
Storage
;
#[derive(PartialEq,
Eq,
Ord,
PartialOrd)]
/// Higher priority offloads are done first.
/// If two offloads have the same priority, the one that was requested first is done first.
#[derive(PartialEq,
Eq)]
pub
struct
OffloadRequestKey
{
pub
struct
OffloadRequestKey
{
pub
priority
:
u64
,
pub
priority
:
u64
,
pub
timestamp
:
u64
,
pub
timestamp
:
u64
,
}
}
impl
PartialOrd
for
OffloadRequestKey
{
fn
partial_cmp
(
&
self
,
other
:
&
Self
)
->
Option
<
Ordering
>
{
Some
(
self
.cmp
(
other
))
}
}
impl
Ord
for
OffloadRequestKey
{
fn
cmp
(
&
self
,
other
:
&
Self
)
->
Ordering
{
other
.priority
.cmp
(
&
self
.priority
)
.then
(
self
.timestamp
.cmp
(
&
other
.timestamp
))
}
}
/// Data needed to offload a block.
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
/// This way, we don't prevent the block from being reused if needed.
...
...
lib/llm/src/block_manager/pool/inactive.rs
View file @
5d5080ba
...
@@ -518,6 +518,10 @@ pub(crate) mod tests {
...
@@ -518,6 +518,10 @@ pub(crate) mod tests {
fn
reset_metadata
(
&
mut
self
)
{
fn
reset_metadata
(
&
mut
self
)
{
self
.priority
=
0
;
self
.priority
=
0
;
}
}
fn
offload_priority
(
&
self
)
->
Option
<
u64
>
{
Some
(
self
.priority
as
u64
)
}
}
}
type
TestPriorityKey
=
PriorityKey
<
TestMetadata
>
;
type
TestPriorityKey
=
PriorityKey
<
TestMetadata
>
;
...
...
lib/llm/src/block_manager/pool/state.rs
View file @
5d5080ba
...
@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -179,9 +179,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let
immutable
=
self
.active
.register
(
mutable
)
?
;
let
immutable
=
self
.active
.register
(
mutable
)
?
;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if
offload
{
if
offload
{
immutable
.enqueue_offload
(
0
)
.await
.unwrap
();
if
let
Some
(
priority
)
=
immutable
.metadata
()
.offload_priority
()
{
immutable
.enqueue_offload
(
priority
)
.await
.unwrap
();
}
}
}
immutable_blocks
.push
(
immutable
);
immutable_blocks
.push
(
immutable
);
...
...
lib/llm/src/block_manager/state.rs
View file @
5d5080ba
...
@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
...
@@ -51,9 +51,9 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_agent
:
Arc
<
Option
<
NixlAgent
>>
,
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
,
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
,
disk_pool
:
Arc
<
Option
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
disk_pool
:
Option
<
Arc
<
BlockPool
<
DiskStorage
,
Metadata
>>>
,
host_pool
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
host_pool
:
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
Metadata
>>>
,
device_pool
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
device_pool
:
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
Metadata
>>>
,
local_block_set
:
NixlBlockSet
,
local_block_set
:
NixlBlockSet
,
remote_block_sets
:
RwLock
<
HashMap
<
WorkerID
,
HashMap
<
usize
,
RemoteBlocks
>>>
,
remote_block_sets
:
RwLock
<
HashMap
<
WorkerID
,
HashMap
<
usize
,
RemoteBlocks
>>>
,
...
@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -126,7 +126,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
if
nixl_agent
.is_none
()
{
if
nixl_agent
.is_none
()
{
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
}
else
{
}
else
{
next_block_set_idx
+=
1
;
next_block_set_idx
+=
1
;
tracing
::
debug!
(
"Constructing disk pool."
);
tracing
::
debug!
(
"Constructing disk pool."
);
...
@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -139,11 +139,11 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
}
}
else
{
}
else
{
tracing
::
debug!
(
"No disk layout provided; will not allocate disk blocks."
);
tracing
::
debug!
(
"No disk layout provided; will not allocate disk blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Create the host block pool if a host layout is provided
// Create the host block pool if a host layout is provided
...
@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -159,10 +159,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
tracing
::
debug!
(
"No host layout provided; will not allocate host blocks."
);
tracing
::
debug!
(
"No host layout provided; will not allocate host blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Create the device block pool if a device layout is provided
// Create the device block pool if a device layout is provided
...
@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -178,10 +178,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
)
?
;
)
?
;
(
Arc
::
new
(
Some
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
tracing
::
debug!
(
"No device layout provided; will not allocate device blocks."
);
tracing
::
debug!
(
"No device layout provided; will not allocate device blocks."
);
(
Arc
::
new
(
None
)
,
None
)
(
None
,
None
)
};
};
// Finalize the local block set by adding NIXL metadata
// Finalize the local block set by adding NIXL metadata
...
@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -414,15 +414,15 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
}
}
pub
fn
disk
(
&
self
)
->
Option
<&
BlockPool
<
DiskStorage
,
Metadata
>>
{
pub
fn
disk
(
&
self
)
->
Option
<&
BlockPool
<
DiskStorage
,
Metadata
>>
{
self
.disk_pool
.as_ref
()
.as_ref
()
self
.disk_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
self
.host_pool
.as_ref
()
.as_ref
()
self
.host_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
self
.device_pool
.as_ref
()
.as_ref
()
self
.device_pool
.as_ref
()
.
map
(|
pool
|
pool
.
as_ref
()
)
}
}
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
...
...
lib/llm/src/block_manager/storage/disk.rs
View file @
5d5080ba
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
use
super
::
*
;
use
super
::
*
;
use
core
::
ffi
::
c_char
;
use
nix
::
fcntl
::{
fallocate
,
FallocateFlags
};
use
nix
::
fcntl
::{
fallocate
,
FallocateFlags
};
use
std
::
ffi
::
CString
;
use
std
::
ffi
::
CString
;
use
std
::
fs
::
File
;
use
std
::
fs
::
File
;
...
@@ -41,7 +42,7 @@ impl DiskStorage {
...
@@ -41,7 +42,7 @@ impl DiskStorage {
let
raw_fd
=
unsafe
{
let
raw_fd
=
unsafe
{
nix
::
libc
::
mkostemp
(
nix
::
libc
::
mkostemp
(
template_bytes
.as_mut_ptr
()
as
*
mut
i8
,
template_bytes
.as_mut_ptr
()
as
*
mut
c_char
,
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// For maximum performance, GPU DirectStorage requires O_DIRECT.
// This allows transfers to bypass the kernel page cache.
// This allows transfers to bypass the kernel page cache.
// It also introduces the restriction that all accesses must be page-aligned.
// It also introduces the restriction that all accesses must be page-aligned.
...
@@ -80,6 +81,7 @@ impl DiskStorage {
...
@@ -80,6 +81,7 @@ impl DiskStorage {
impl
Drop
for
DiskStorage
{
impl
Drop
for
DiskStorage
{
// TODO: How robust is this actually?
// TODO: How robust is this actually?
fn
drop
(
&
mut
self
)
{
fn
drop
(
&
mut
self
)
{
self
.handles
.release
();
std
::
fs
::
remove_file
(
self
.file_name
.clone
())
.unwrap
();
std
::
fs
::
remove_file
(
self
.file_name
.clone
())
.unwrap
();
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment