Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
6d9aac77
Unverified
Commit
6d9aac77
authored
May 24, 2025
by
jthomson04
Committed by
GitHub
May 24, 2025
Browse files
feat: kvbm offload fixes and tests (#1191)
parent
e5845b53
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
351 additions
and
93 deletions
+351
-93
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+17
-10
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+37
-10
lib/llm/src/block_manager/block/transfer/nixl.rs
lib/llm/src/block_manager/block/transfer/nixl.rs
+9
-3
lib/llm/src/block_manager/block/transfer/strategy.rs
lib/llm/src/block_manager/block/transfer/strategy.rs
+15
-15
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+246
-55
lib/llm/src/block_manager/offload/request.rs
lib/llm/src/block_manager/offload/request.rs
+27
-0
No files found.
lib/llm/src/block_manager/block.rs
View file @
6d9aac77
...
...
@@ -38,12 +38,12 @@ use super::{
WorkerID
,
};
use
derive_getters
::
Getters
;
use
std
::{
fmt
::
Debug
,
ops
::{
Deref
,
DerefMut
},
sync
::
Arc
,
};
use
thiserror
::
Error
;
mod
private
{
...
...
@@ -192,8 +192,6 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
self
.manager
=
Some
(
manager
);
}
// TODO(#967) - Enable with TransferEngine
#[allow(dead_code)]
pub
(
crate
)
fn
manager
(
&
self
)
->
Option
<&
Arc
<
BlockManager
<
M
>>>
{
self
.manager
.as_ref
()
}
...
...
@@ -521,13 +519,26 @@ pub trait BlockDataProviderMut: BlockDataProvider {
fn
block_data_mut
(
&
mut
self
,
_
:
private
::
PrivateToken
)
->
&
mut
BlockData
<
Self
::
StorageType
>
;
}
#[derive(Clone,
Debug,
Default,
Eq,
PartialEq,
Ord,
PartialOrd)]
#[derive(Clone,
Debug,
Default,
Eq,
PartialEq,
Ord,
PartialOrd
,
Getters
)]
pub
struct
BasicMetadata
{
#[getter(copy)]
priority
:
u32
,
#[getter(copy)]
returned_tick
:
u64
,
#[getter(copy)]
acquired_tick
:
u64
,
}
impl
BasicMetadata
{
pub
fn
update_priority
(
&
self
,
priority
:
u32
)
->
Self
{
BasicMetadata
{
priority
,
returned_tick
:
self
.returned_tick
,
acquired_tick
:
self
.acquired_tick
,
}
}
}
impl
BlockMetadata
for
BasicMetadata
{
fn
on_acquired
(
&
mut
self
,
tick
:
u64
)
{
self
.acquired_tick
=
tick
;
...
...
@@ -755,11 +766,6 @@ impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
Self
{
block
}
}
pub
fn
manager
(
&
self
)
->
Option
<&
Arc
<
BlockManager
<
M
>>>
{
// Access the underlying Block's manager field directly through deref
self
.manager
.as_ref
()
}
pub
fn
mutable_block
(
&
self
)
->
&
Arc
<
MutableBlock
<
S
,
M
>>
{
&
self
.block
}
...
...
@@ -859,9 +865,10 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
ImmutableBlock
<
S
,
M
>
{
pub
async
fn
enqueue_offload
(
&
self
,
priority
:
u64
)
->
Result
<
()
>
{
// TODO: Is it ok to silently fail if the block is not managed?
if
let
Some
(
manager
)
=
self
.manager
()
{
manager
.enqueue_offload_block
(
self
,
priority
)
.await
?
;
}
else
{
tracing
::
warn!
(
"Block is not managed. Unable to enqueue offload."
);
}
Ok
(())
}
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
6d9aac77
...
...
@@ -28,6 +28,7 @@ use crate::block_manager::storage::{
use
cudarc
::
driver
::
CudaStream
;
use
nixl_sys
::
XferOp
::{
Read
,
Write
};
use
std
::
future
::
Future
;
use
std
::
ops
::
Range
;
...
...
@@ -77,6 +78,21 @@ pub enum TransferError {
Other
(
#[from]
anyhow
::
Error
),
}
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq)]
pub
enum
NixlTransfer
{
Read
,
Write
,
}
impl
NixlTransfer
{
pub
fn
as_xfer_op
(
&
self
)
->
nixl_sys
::
XferOp
{
match
self
{
NixlTransfer
::
Read
=>
Read
,
NixlTransfer
::
Write
=>
Write
,
}
}
}
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq)]
pub
enum
TransferStrategy
{
Memcpy
,
...
...
@@ -85,8 +101,7 @@ pub enum TransferStrategy {
CudaAsyncD2D
,
CudaBlockingH2D
,
CudaBlockingD2H
,
NixlWrite
,
// aka PUT
NixlRead
,
// aka GET
Nixl
(
NixlTransfer
),
Invalid
,
}
...
...
@@ -126,7 +141,7 @@ where
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlRead
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
...
...
@@ -179,8 +194,14 @@ where
}
Ok
(())
}
TransferStrategy
::
NixlWrite
=>
{
std
::
mem
::
drop
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
)
?
);
TransferStrategy
::
Nixl
(
transfer_type
)
=>
{
std
::
mem
::
drop
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
,
transfer_type
,
)
?
);
Ok
(())
}
_
=>
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
...
...
@@ -196,8 +217,14 @@ where
notify
:
Option
<
String
>
,
ctx
:
Arc
<
TransferContext
>
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>
,
TransferError
>
{
if
let
TransferStrategy
::
NixlWrite
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
)
?
)
if
let
TransferStrategy
::
Nixl
(
transfer_type
)
=
RB
::
write_to_strategy
()
{
Ok
(
nixl
::
write_blocks_to
(
self
,
dst
,
ctx
,
notify
,
transfer_type
,
)
?
)
}
else
{
Err
(
TransferError
::
IncompatibleTypes
(
format!
(
"Expected NIXL transfer strategy, got: {:?}"
,
...
...
@@ -626,7 +653,7 @@ mod tests {
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Pinned to ...
...
...
@@ -644,7 +671,7 @@ mod tests {
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Device to ...
...
...
@@ -662,7 +689,7 @@ mod tests {
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Nixl to ... should fail to compile
...
...
lib/llm/src/block_manager/block/transfer/nixl.rs
View file @
6d9aac77
...
...
@@ -16,7 +16,7 @@
use
super
::
*
;
use
anyhow
::
Result
;
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
,
XferOp
};
use
nixl_sys
::{
MemoryRegion
,
NixlDescriptor
,
OptArgs
,
XferDescList
};
use
std
::
future
::{
poll_fn
,
Future
};
use
std
::
task
::
Poll
;
...
...
@@ -89,6 +89,7 @@ pub fn write_blocks_to<Source, Destination>(
dst
:
&
mut
[
Destination
],
ctx
:
Arc
<
TransferContext
>
,
notify
:
Option
<
String
>
,
transfer_type
:
NixlTransfer
,
)
->
Result
<
Box
<
dyn
Future
<
Output
=
()
>
+
Send
+
Sync
+
Unpin
>>
where
Source
:
BlockDataProvider
,
...
...
@@ -127,8 +128,13 @@ where
debug_assert!
(
!
src_dl
.has_overlaps
()
?
&&
!
dst_dl
.has_overlaps
()
?
);
let
xfer_req
=
nixl_agent
.create_xfer_req
(
XferOp
::
Write
,
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
)
?
;
let
xfer_req
=
nixl_agent
.create_xfer_req
(
transfer_type
.as_xfer_op
(),
&
src_dl
,
&
dst_dl
,
&
nixl_agent
.name
(),
None
,
)
?
;
let
mut
xfer_args
=
OptArgs
::
new
()
?
;
...
...
lib/llm/src/block_manager/block/transfer/strategy.rs
View file @
6d9aac77
...
...
@@ -21,35 +21,35 @@ use super::*;
impl
WriteToStrategy
<
DiskStorage
>
for
DiskStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
}
}
impl
WriteToStrategy
<
SystemStorage
>
for
DiskStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Nixl
Write
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
impl
WriteToStrategy
<
PinnedStorage
>
for
DiskStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Nixl
Write
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
impl
WriteToStrategy
<
DeviceStorage
>
for
DiskStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Nixl
Write
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
impl
WriteToStrategy
<
DiskStorage
>
for
SystemStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
}
}
...
...
@@ -77,7 +77,7 @@ impl WriteToStrategy<DeviceStorage> for SystemStorage {
impl
WriteToStrategy
<
DiskStorage
>
for
PinnedStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
}
}
...
...
@@ -105,7 +105,7 @@ impl WriteToStrategy<DeviceStorage> for PinnedStorage {
impl
WriteToStrategy
<
DiskStorage
>
for
DeviceStorage
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
Nixl
Write
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
...
...
@@ -133,7 +133,7 @@ impl WriteToStrategy<DeviceStorage> for DeviceStorage {
impl
<
S
:
Storage
+
Local
>
WriteToStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
write_to_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
}
}
...
...
@@ -170,7 +170,7 @@ where
impl
<
S
:
Storage
+
Local
>
ReadFromStrategy
<
NixlStorage
>
for
S
{
#[inline(always)]
fn
read_from_strategy
()
->
TransferStrategy
{
TransferStrategy
::
NixlRead
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
}
}
...
...
@@ -198,7 +198,7 @@ mod tests {
assert_eq!
(
<
SystemStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Pinned to ...
...
...
@@ -216,7 +216,7 @@ mod tests {
);
assert_eq!
(
<
PinnedStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Device to ...
...
...
@@ -234,7 +234,7 @@ mod tests {
);
assert_eq!
(
<
DeviceStorage
as
WriteToStrategy
<
NixlStorage
>>
::
write_to_strategy
(),
TransferStrategy
::
NixlWrite
TransferStrategy
::
Nixl
(
NixlTransfer
::
Write
)
);
// Nixl to ... should fail to compile
...
...
@@ -276,7 +276,7 @@ mod tests {
assert_eq!
(
<
SystemStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
);
// Pinned to ...
...
...
@@ -297,7 +297,7 @@ mod tests {
assert_eq!
(
<
PinnedStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
);
// Device to ...
...
...
@@ -318,7 +318,7 @@ mod tests {
assert_eq!
(
<
DeviceStorage
as
ReadFromStrategy
<
NixlStorage
>>
::
read_from_strategy
(),
TransferStrategy
::
NixlRead
TransferStrategy
::
Nixl
(
NixlTransfer
::
Read
)
);
// Nixl to ... should fail to compile
...
...
lib/llm/src/block_manager/offload.rs
View file @
6d9aac77
...
...
@@ -259,12 +259,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
}
}
// Allocate a block from the host pool.
// TODO: The most likely error here is that the target pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
let
target_blocks
=
match
target_pool
.allocate_blocks
(
1
)
.await
{
Ok
(
blocks
)
=>
blocks
,
Err
(
_
)
=>
{
tracing
::
warn!
(
"Target pool full. Skipping offload. This should only ever happen with very small pool sizes."
);
continue
;
}
};
...
...
@@ -451,6 +449,10 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
self
.disk_onboard_tx
.send
(
OnboardRequest
::
new
(
disk_blocks
,
tx
))
.map_err
(|
_
|
BlockPoolError
::
ProgressEngineShutdown
)
?
;
}
else
{
return
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
Other
(
anyhow
::
anyhow!
(
"Block type not supported for onboarding."
),
)));
}
match
rx
.await
{
...
...
@@ -466,12 +468,15 @@ mod tests {
use
crate
::
block_manager
::
block
::
test_utils
::
get_private_token
;
use
crate
::
block_manager
::{
block
::{
BasicMetadata
,
BlockDataExt
,
BlockDataProvider
,
BlockExt
,
Blocks
,
MutableBlock
},
block
::{
nixl
::
BlockHandleInfo
,
BasicMetadata
,
BlockDataExt
,
BlockDataProvider
,
BlockExt
,
Blocks
,
MutableBlock
,
},
layout
::{
nixl
::
NixlLayout
,
FullyContiguous
},
pool
::
BlockPool
,
storage
::{
cuda
::
CudaAccessible
,
DeviceAllocator
,
DeviceStorage
,
DiskAllocator
,
DiskStorage
,
PinnedAllocator
,
PinnedStorage
,
StorageType
,
DeviceAllocator
,
DeviceStorage
,
DiskAllocator
,
DiskStorage
,
PinnedAllocator
,
PinnedStorage
,
StorageType
,
},
DType
,
LayoutConfig
,
};
...
...
@@ -480,11 +485,12 @@ mod tests {
use
aligned_vec
::
avec
;
use
cudarc
::
runtime
::
sys
::{
cudaMemcpy
,
cudaMemcpyKind
,
cudaMemset
};
use
std
::
fs
::
File
;
use
std
::
io
::{
Read
,
Seek
,
SeekFrom
};
use
std
::
io
::{
Read
,
Seek
,
SeekFrom
,
Write
};
use
std
::
mem
::
ManuallyDrop
;
use
std
::
os
::
unix
::
io
::
FromRawFd
;
const
BLOCK_SIZE
:
usize
=
4
;
const
NUM_LAYERS
:
usize
=
8
;
type
DevicePool
=
Option
<
Arc
<
BlockPool
<
DeviceStorage
,
BasicMetadata
>>>
;
type
HostPool
=
Option
<
Arc
<
BlockPool
<
PinnedStorage
,
BasicMetadata
>>>
;
...
...
@@ -505,6 +511,7 @@ mod tests {
device_blocks
:
usize
,
host_blocks
:
Option
<
usize
>
,
disk_blocks
:
Option
<
usize
>
,
inner_dim
:
Option
<
usize
>
,
)
->
Result
<
(
Arc
<
OffloadManager
<
BasicMetadata
>>
,
DevicePool
,
...
...
@@ -513,10 +520,10 @@ mod tests {
)
>
{
let
mut
config
=
LayoutConfig
{
num_blocks
:
device_blocks
,
num_layers
:
8
,
num_layers
:
NUM_LAYERS
,
outer_dim
:
1
,
page_size
:
BLOCK_SIZE
,
inner_dim
:
1024
,
inner_dim
:
inner_dim
.unwrap_or
(
1024
)
,
alignment
:
1
,
dtype
:
DType
::
FP16
,
};
...
...
@@ -602,21 +609,39 @@ mod tests {
Ok
(
block
)
}
fn
populate_
cuda_
block
<
S
:
Storage
+
CudaAccessible
+
NixlDescriptor
>
(
fn
populate_block
<
S
:
Storage
+
NixlDescriptor
>
(
block
:
&
impl
BlockDataProvider
<
StorageType
=
S
>
,
value
:
i32
,
value
:
u8
,
)
->
Result
<
()
>
{
let
block_data
=
block
.block_data
(
get_private_token
())
.block_view
()
?
;
let
block_size
=
block_data
.size
();
unsafe
{
cudaMemset
(
block_data
.as_ptr
()
as
*
mut
std
::
ffi
::
c_void
,
value
,
block_size
,
)
.result
()
?
;
let
block_data
=
block
.block_data
(
get_private_token
());
let
block_view
=
block_data
.block_view
()
?
;
let
block_size
=
block_view
.size
();
match
block_data
.storage_type
()
{
StorageType
::
Device
(
_
)
|
StorageType
::
Pinned
=>
unsafe
{
cudaMemset
(
block_view
.as_ptr
()
as
*
mut
std
::
ffi
::
c_void
,
value
as
i32
,
block_size
,
)
.result
()
?
;
},
StorageType
::
Disk
=>
{
let
nixl_desc
=
block_view
.as_nixl_descriptor
();
let
mut
file
:
ManuallyDrop
<
File
>
;
let
data
=
avec!
[[
4096
]
|
value
;
block_size
];
unsafe
{
file
=
ManuallyDrop
::
new
(
File
::
from_raw_fd
(
nixl_desc
.device_id
()
as
i32
));
file
.seek
(
SeekFrom
::
Start
(
nixl_desc
.as_ptr
()
as
u64
))
?
;
}
file
.write_all
(
&
data
)
?
;
file
.sync_all
()
?
;
file
.flush
()
?
;
}
_
=>
panic!
(),
}
Ok
(())
}
...
...
@@ -654,27 +679,31 @@ mod tests {
file
.read_exact
(
&
mut
aligned
)
?
;
contents
=
aligned
.to_vec
();
}
_
=>
{
panic!
();
}
_
=>
anyhow
::
bail!
(
"Unsupported storage type."
),
}
Ok
(
contents
.to_vec
())
}
/// Compare the contents of a device block and a host block.
fn
compare_block_contents
(
fn
check_block_contents
(
block1
:
&
impl
BlockDataProvider
<
StorageType
=
impl
Storage
+
NixlDescriptor
>
,
block2
:
&
impl
BlockDataProvider
<
StorageType
=
impl
Storage
+
NixlDescriptor
>
,
value
:
u8
,
)
->
Result
<
()
>
{
assert_eq!
(
get_block_contents
(
block1
)
?
,
get_block_contents
(
block2
)
?
);
let
contents1
=
get_block_contents
(
block1
)
?
;
let
contents2
=
get_block_contents
(
block2
)
?
;
for
(
c1_value
,
c2_value
)
in
contents1
.iter
()
.zip
(
contents2
.iter
())
{
if
*
c1_value
!=
*
c2_value
||
*
c1_value
!=
value
{
panic!
(
"{} != {} != {}"
,
c1_value
,
c2_value
,
value
);
}
}
Ok
(())
}
#[tokio::test]
async
fn
test_offload_invalid_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
...
...
@@ -707,7 +736,7 @@ mod tests {
#[tokio::test]
async
fn
test_offload_registered_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -722,7 +751,7 @@ mod tests {
.next
()
.ok_or
(
anyhow
::
anyhow!
(
"Failed to register block"
))
?
;
populate_
cuda_
block
(
&
immutable_device_block
,
42
)
?
;
populate_block
(
&
immutable_device_block
,
42
)
?
;
// Offloads should only go to G2 (for now)
offload_manager
.offload
(
&
immutable_device_block
,
0
)
.await
?
;
...
...
@@ -743,14 +772,14 @@ mod tests {
immutable_device_block
.sequence_hash
()
?
);
c
ompare
_block_contents
(
&
immutable_device_block
,
&
host_blocks
[
0
])
?
;
c
heck
_block_contents
(
&
immutable_device_block
,
&
host_blocks
[
0
]
,
42
)
?
;
Ok
(())
}
#[tokio::test]
async
fn
test_no_host_blocks_available
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -798,7 +827,7 @@ mod tests {
#[tokio::test]
async
fn
test_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -812,7 +841,7 @@ mod tests {
.next
()
.unwrap
();
populate_
cuda_
block
(
&
immutable_host_block
,
42
)
?
;
populate_block
(
&
immutable_host_block
,
42
)
?
;
// Onboard the block.
let
onboarded_blocks
=
offload_manager
...
...
@@ -831,7 +860,7 @@ mod tests {
BlockState
::
Registered
(
_
)
));
c
ompare
_block_contents
(
&
onboarded_blocks
[
0
],
&
immutable_host_block
)
?
;
c
heck
_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
// Wait for the new value to show up in the device pool.
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
100
))
.await
;
...
...
@@ -845,14 +874,14 @@ mod tests {
);
// Check that this is the same block.
c
ompare
_block_contents
(
&
device_blocks
[
0
],
&
immutable_host_block
)
?
;
c
heck
_block_contents
(
&
immutable_host_block
,
&
device_blocks
[
0
],
42
)
?
;
Ok
(())
}
#[tokio::test]
async
fn
test_offload_onboard
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -865,7 +894,7 @@ mod tests {
.next
()
.unwrap
();
populate_
cuda_
block
(
&
immutable_device_block
,
42
)
?
;
populate_block
(
&
immutable_device_block
,
42
)
?
;
// Offload the block to the host.
offload_manager
.offload
(
&
immutable_device_block
,
0
)
.await
?
;
...
...
@@ -880,7 +909,7 @@ mod tests {
.next
()
.unwrap
();
c
ompare
_block_contents
(
&
immutable_device_block
,
&
immutable_host_block
)
?
;
c
heck
_block_contents
(
&
immutable_device_block
,
&
immutable_host_block
,
42
)
?
;
// Remove the device block from the pool by dropping it and allocating more blocks.
drop
(
immutable_device_block
);
...
...
@@ -914,14 +943,14 @@ mod tests {
BlockState
::
Registered
(
_
)
));
c
ompare
_block_contents
(
&
onboarded_blocks
[
0
],
&
immutable_host_block
)
?
;
c
heck
_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
Ok
(())
}
#[tokio::test]
async
fn
test_onboard_err_handling
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
)
?
;
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -950,7 +979,7 @@ mod tests {
#[tokio::test]
async
fn
test_offload_onboard_no_host_blocks
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
)
?
;
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
4
,
None
,
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
...
...
@@ -969,7 +998,7 @@ mod tests {
#[tokio::test]
async
fn
test_offload_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
))
?
;
let
(
offload_manager
,
_
,
host_pool
,
disk_pool
)
=
build_pools
(
4
,
Some
(
4
),
Some
(
4
)
,
None
)
?
;
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
...
...
@@ -982,7 +1011,7 @@ mod tests {
.next
()
.unwrap
();
populate_
cuda_
block
(
&
immutable_host_block
,
42
)
?
;
populate_block
(
&
immutable_host_block
,
42
)
?
;
offload_manager
.offload
(
&
immutable_host_block
,
0
)
.await
?
;
...
...
@@ -997,14 +1026,14 @@ mod tests {
immutable_host_block
.sequence_hash
()
?
);
c
ompare
_block_contents
(
&
disk_blocks
[
0
],
&
immutable_host_block
)
?
;
c
heck
_block_contents
(
&
immutable_host_block
,
&
disk_blocks
[
0
],
42
)
?
;
Ok
(())
}
#[tokio::test]
async
fn
test_onboard_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
))
?
;
let
(
offload_manager
,
device_pool
,
_
,
disk_pool
)
=
build_pools
(
4
,
None
,
Some
(
4
)
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
...
...
@@ -1017,10 +1046,14 @@ mod tests {
.next
()
.unwrap
();
populate_block
(
&
immutable_disk_block
,
42
)
?
;
let
device_block
=
offload_manager
.onboard
(
vec!
[
immutable_disk_block
.clone
()])
.await
?
;
check_block_contents
(
&
immutable_disk_block
,
&
device_block
[
0
],
42
)
?
;
assert_eq!
(
device_block
.len
(),
1
);
assert_eq!
(
device_block
[
0
]
.sequence_hash
()
?
,
...
...
@@ -1040,7 +1073,7 @@ mod tests {
#[tokio::test]
async
fn
test_bulk_transfer_disk
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
build_pools
(
8
,
Some
(
8
),
Some
(
8
))
?
;
build_pools
(
8
,
Some
(
8
),
Some
(
8
)
,
None
)
?
;
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
...
...
@@ -1050,7 +1083,7 @@ mod tests {
for
i
in
0
..
8
{
let
block
=
completed_block
(
host_pool
,
[
i
;
4
])
.await
?
;
populate_
cuda_
block
(
&
block
,
i
as
i32
)
?
;
populate_block
(
&
block
,
i
as
u8
)
?
;
host_blocks
.push
(
block
);
}
...
...
@@ -1064,24 +1097,24 @@ mod tests {
let
mut
disk_blocks
=
Vec
::
new
();
for
host_block
in
&
immutable_host_blocks
{
for
(
i
,
host_block
)
in
immutable_host_blocks
.iter
()
.enumerate
()
{
let
blocks
=
disk_pool
.match_sequence_hashes
(
vec!
[
host_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
blocks
.len
(),
1
);
c
ompare
_block_contents
(
&
blocks
[
0
],
host_block
)
?
;
c
heck
_block_contents
(
host_block
,
&
blocks
[
0
],
i
as
u8
)
?
;
disk_blocks
.push
(
blocks
[
0
]
.clone
());
}
let
device_blocks
=
offload_manager
.onboard
(
disk_blocks
.clone
())
.await
?
;
assert_eq!
(
device_blocks
.len
(),
disk_blocks
.len
());
for
disk_block
in
&
disk_blocks
{
for
(
i
,
disk_block
)
in
disk_blocks
.iter
()
.enumerate
()
{
let
blocks
=
device_pool
.match_sequence_hashes
(
vec!
[
disk_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
blocks
.len
(),
1
);
c
ompare
_block_contents
(
&
blocks
[
0
],
disk_block
)
?
;
c
heck
_block_contents
(
disk_block
,
&
blocks
[
0
],
i
as
u8
)
?
;
}
Ok
(())
...
...
@@ -1093,6 +1126,7 @@ mod tests {
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
,
None
,
Some
(
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
),
None
,
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
...
...
@@ -1101,7 +1135,9 @@ mod tests {
let
mut
disk_blocks
=
Vec
::
new
();
for
i
in
0
..
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
{
disk_blocks
.push
(
completed_block
(
disk_pool
,
[
i
as
u32
;
4
])
.await
?
);
let
disk_block
=
completed_block
(
disk_pool
,
[
i
as
u32
;
4
])
.await
?
;
populate_block
(
&
disk_block
,
i
as
u8
)
?
;
disk_blocks
.push
(
disk_block
);
}
let
immutable_disk_blocks
=
disk_pool
.register_blocks
(
disk_blocks
)
.await
?
;
...
...
@@ -1111,14 +1147,169 @@ mod tests {
.await
?
;
assert_eq!
(
device_blocks
.len
(),
2
*
MAX_TRANSFER_BATCH_SIZE
+
1
);
for
device_block
in
&
device_blocks
{
for
(
i
,
device_block
)
in
device_blocks
.iter
()
.enumerate
()
{
let
blocks
=
device_pool
.match_sequence_hashes
(
vec!
[
device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
check_block_contents
(
device_block
,
&
blocks
[
0
],
i
as
u8
)
?
;
assert_eq!
(
blocks
.len
(),
1
);
compare_block_contents
(
&
blocks
[
0
],
device_block
)
?
;
}
Ok
(())
}
#[tokio::test]
async
fn
test_onboard_unsupported_block_type
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
_
,
_
)
=
build_pools
(
1
,
None
,
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
block
=
completed_block
(
device_pool
,
[
0
;
4
])
.await
?
;
let
registered_block
=
device_pool
.register_blocks
(
vec!
[
block
])
.await
?
.into_iter
()
.next
()
.unwrap
();
let
onboarded_blocks
=
offload_manager
.onboard
(
vec!
[
registered_block
])
.await
;
assert
!
(
matches!
(
onboarded_blocks
,
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
Other
(
_
)))
));
Ok
(())
}
#[tokio::test]
async
fn
test_offload_transfer_metadata
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
mut
device_block
=
completed_block
(
device_pool
,
[
0
;
4
])
.await
?
;
populate_block
(
&
device_block
,
42
)
?
;
let
new_metadata
=
device_block
.metadata
()
.update_priority
(
1
);
device_block
.update_metadata
(
new_metadata
);
let
immutable_device_block
=
device_pool
.register_blocks
(
vec!
[
device_block
])
.await
?
.into_iter
()
.next
()
.unwrap
();
offload_manager
.offload
(
&
immutable_device_block
,
0
)
.await
?
;
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
100
))
.await
;
let
host_blocks
=
host_pool
.match_sequence_hashes
(
vec!
[
immutable_device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
host_blocks
.len
(),
1
);
check_block_contents
(
&
immutable_device_block
,
&
host_blocks
[
0
],
42
)
?
;
assert_eq!
(
host_blocks
[
0
]
.metadata
()
.priority
(),
1
);
Ok
(())
}
#[tokio::test]
async
fn
test_onboard_duplicate
()
->
Result
<
()
>
{
let
(
offload_manager
,
device_pool
,
host_pool
,
_
)
=
build_pools
(
4
,
Some
(
4
),
None
,
None
)
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
;
4
])
.await
?
;
let
immutable_device_block
=
device_pool
.register_blocks
(
vec!
[
device_block
])
.await
?
.into_iter
()
.next
()
.unwrap
();
populate_block
(
&
immutable_device_block
,
42
)
?
;
offload_manager
.offload
(
&
immutable_device_block
,
0
)
.await
?
;
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
100
))
.await
;
let
host_blocks
=
host_pool
.match_sequence_hashes
(
vec!
[
immutable_device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
host_blocks
.len
(),
1
);
let
onboarded_blocks
=
offload_manager
.onboard
(
vec!
[
host_blocks
[
0
]
.clone
()])
.await
?
;
assert_eq!
(
onboarded_blocks
.len
(),
1
);
check_block_contents
(
&
host_blocks
[
0
],
&
onboarded_blocks
[
0
],
42
)
?
;
// This should be the same block that we put on the device.
// The block that was copied should be discarded by the block pool.
assert_eq!
(
onboarded_blocks
[
0
]
.block_idx
(),
immutable_device_block
.block_idx
()
);
Ok
(())
}
#[tokio::test]
async
fn
test_transfer_big_blocks
()
->
Result
<
()
>
{
// Try a block size of 32 MB.
let
inner_dim
=
2_u
size
.pow
(
20
)
*
32
/
NUM_LAYERS
/
BLOCK_SIZE
;
let
(
offload_manager
,
device_pool
,
host_pool
,
disk_pool
)
=
build_pools
(
2
,
Some
(
2
),
Some
(
2
),
Some
(
inner_dim
))
?
;
let
device_pool
=
device_pool
.as_ref
()
.unwrap
();
let
host_pool
=
host_pool
.as_ref
()
.unwrap
();
let
disk_pool
=
disk_pool
.as_ref
()
.unwrap
();
let
device_block
=
completed_block
(
device_pool
,
[
0
;
4
])
.await
?
;
populate_block
(
&
device_block
,
42
)
?
;
let
immutable_device_block
=
device_pool
.register_blocks
(
vec!
[
device_block
])
.await
?
.into_iter
()
.next
()
.unwrap
();
// Offload to host.
offload_manager
.offload
(
&
immutable_device_block
,
0
)
.await
?
;
// Wait for the offload to be processed.
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
100
))
.await
;
let
host_blocks
=
host_pool
.match_sequence_hashes
(
vec!
[
immutable_device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
host_blocks
.len
(),
1
);
check_block_contents
(
&
immutable_device_block
,
&
host_blocks
[
0
],
42
)
?
;
// Offload to disk
offload_manager
.offload
(
&
host_blocks
[
0
],
0
)
.await
?
;
// Wait for the offload to be processed.
tokio
::
time
::
sleep
(
std
::
time
::
Duration
::
from_millis
(
500
))
.await
;
let
disk_blocks
=
disk_pool
.match_sequence_hashes
(
vec!
[
immutable_device_block
.sequence_hash
()
?
]
.as_slice
())
.await
?
;
assert_eq!
(
disk_blocks
.len
(),
1
);
check_block_contents
(
&
host_blocks
[
0
],
&
disk_blocks
[
0
],
42
)
?
;
// Onboard to device.
let
device_blocks
=
offload_manager
.onboard
(
disk_blocks
.clone
())
.await
?
;
assert_eq!
(
device_blocks
.len
(),
1
);
check_block_contents
(
&
disk_blocks
[
0
],
&
device_blocks
[
0
],
42
)
?
;
Ok
(())
}
}
lib/llm/src/block_manager/offload/request.rs
View file @
6d9aac77
...
...
@@ -96,3 +96,30 @@ impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source,
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_offload_request_key_ordering
()
{
let
key1
=
OffloadRequestKey
{
priority
:
1
,
timestamp
:
1
,
};
let
key2
=
OffloadRequestKey
{
priority
:
2
,
timestamp
:
2
,
};
assert
!
(
key2
<
key1
);
let
key3
=
OffloadRequestKey
{
priority
:
2
,
timestamp
:
3
,
};
assert
!
(
key2
<
key3
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment