Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
3d40a692
Unverified
Commit
3d40a692
authored
May 29, 2025
by
jthomson04
Committed by
GitHub
May 29, 2025
Browse files
feat: Restructure kv manager block registration (#1093)
parent
7d0c9386
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
261 additions
and
84 deletions
+261
-84
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+5
-3
lib/llm/src/block_manager/block/registry.rs
lib/llm/src/block_manager/block/registry.rs
+131
-36
lib/llm/src/block_manager/block/state.rs
lib/llm/src/block_manager/block/state.rs
+8
-8
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+4
-4
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+1
-1
lib/llm/src/block_manager/pool.rs
lib/llm/src/block_manager/pool.rs
+52
-10
lib/llm/src/block_manager/pool/inactive.rs
lib/llm/src/block_manager/pool/inactive.rs
+24
-6
lib/llm/src/block_manager/pool/state.rs
lib/llm/src/block_manager/pool/state.rs
+14
-6
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+22
-10
No files found.
lib/llm/src/block_manager/block.rs
View file @
3d40a692
...
@@ -21,6 +21,8 @@ pub mod view;
...
@@ -21,6 +21,8 @@ pub mod view;
pub
use
crate
::
tokens
::
TokenBlockError
;
pub
use
crate
::
tokens
::
TokenBlockError
;
pub
use
anyhow
::
Result
;
pub
use
anyhow
::
Result
;
use
nixl_sys
::
NixlDescriptor
;
use
nixl_sys
::
NixlDescriptor
;
pub
use
registry
::{
GlobalRegistry
,
RegistrationHandle
};
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
use
crate
::
block_manager
::{
use
crate
::
block_manager
::{
...
@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
...
@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
pub
fn
sequence_hash
(
&
self
)
->
Result
<
SequenceHash
,
BlockError
>
{
pub
fn
sequence_hash
(
&
self
)
->
Result
<
SequenceHash
,
BlockError
>
{
match
self
.state
()
{
match
self
.state
()
{
BlockState
::
Complete
(
state
)
=>
Ok
(
state
.token_block
()
.sequence_hash
()),
BlockState
::
Complete
(
state
)
=>
Ok
(
state
.token_block
()
.sequence_hash
()),
BlockState
::
Registered
(
state
)
=>
Ok
(
state
.sequence_hash
()),
BlockState
::
Registered
(
state
,
_
)
=>
Ok
(
state
.sequence_hash
()),
_
=>
Err
(
BlockError
::
InvalidState
(
_
=>
Err
(
BlockError
::
InvalidState
(
"Block is not complete"
.to_string
(),
"Block is not complete"
.to_string
(),
)),
)),
...
@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
...
@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
fn
register
(
fn
register
(
&
mut
self
,
&
mut
self
,
registry
:
&
mut
registry
::
BlockRegistry
,
registry
:
&
mut
registry
::
BlockRegistry
,
)
->
Result
<
PublishHandle
,
registry
::
BlockRegistationError
>
;
)
->
Result
<
Option
<
PublishHandle
>
,
registry
::
BlockRegistationError
>
;
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
PrivateBlockExt
for
Block
<
S
,
M
>
{
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
PrivateBlockExt
for
Block
<
S
,
M
>
{
fn
register
(
fn
register
(
&
mut
self
,
&
mut
self
,
registry
:
&
mut
registry
::
BlockRegistry
,
registry
:
&
mut
registry
::
BlockRegistry
,
)
->
Result
<
PublishHandle
,
registry
::
BlockRegistationError
>
{
)
->
Result
<
Option
<
PublishHandle
>
,
registry
::
BlockRegistationError
>
{
registry
.register_block
(
&
mut
self
.state
)
registry
.register_block
(
&
mut
self
.state
)
}
}
}
}
...
...
lib/llm/src/block_manager/block/registry.rs
View file @
3d40a692
...
@@ -13,9 +13,27 @@
...
@@ -13,9 +13,27 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//! # KV Cache Block Registration
//!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequence hashes to registration handles. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
//!
//! ## Workflow
//!
//! 1. When a block is registered into a pool, we create a unique block handle.
//! 2. We then check the global registry to see if the block already exists in any other pool.
//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one.
//! 4. When the block handle is dropped, it means that the block is no longer in the pool.
//! 5. When the registration handle is dropped, it means that the block is no longer in any pool.
use
std
::{
use
std
::{
collections
::
HashMap
,
collections
::
HashMap
,
sync
::{
Arc
,
Weak
},
sync
::{
Arc
,
Mutex
,
Weak
},
};
};
use
super
::
super
::
events
::{
EventManager
,
EventReleaseManager
,
PublishHandle
};
use
super
::
super
::
events
::{
EventManager
,
EventReleaseManager
,
PublishHandle
};
...
@@ -24,6 +42,9 @@ use super::state::BlockState;
...
@@ -24,6 +42,9 @@ use super::state::BlockState;
use
crate
::
tokens
::{
BlockHash
,
SequenceHash
,
TokenBlock
};
use
crate
::
tokens
::{
BlockHash
,
SequenceHash
,
TokenBlock
};
use
derive_getters
::
Getters
;
use
derive_getters
::
Getters
;
use
tokio
::{
runtime
::
Handle
,
sync
::
mpsc
};
pub
type
GlobalRegistry
=
Arc
<
Mutex
<
HashMap
<
SequenceHash
,
Weak
<
RegistrationHandle
>>>>
;
#[derive(Debug,
thiserror::Error)]
#[derive(Debug,
thiserror::Error)]
pub
enum
BlockRegistationError
{
pub
enum
BlockRegistationError
{
...
@@ -34,27 +55,88 @@ pub enum BlockRegistationError {
...
@@ -34,27 +55,88 @@ pub enum BlockRegistationError {
InvalidState
(
String
),
InvalidState
(
String
),
}
}
/// Error returned when an attempt is made to unregister a block that is still active.
/// A block entry is a handle to a block that is registered in the pool.
#[derive(Debug,
thiserror::Error)]
/// On drop, we need to notify the pool that the block has been unregistered.
#[error(
"Failed to unregister block: {0}"
)]
/// This is different from the registration handle, which is only dropped when the block is no longer in ANY pool.
pub
struct
UnregisterFailure
(
SequenceHash
);
/// Per-pool handle for a registered block; its `Drop` impl reports the
/// sequence hash over `unregister_tx`.
#[derive(Debug)]
pub struct BlockHandle {
    /// Sequence hash of the block this handle stands for.
    sequence_hash: SequenceHash,
    /// Sending half of the channel used to report that this handle was dropped.
    unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}
impl
BlockHandle
{
pub
fn
new
(
sequence_hash
:
SequenceHash
,
unregister_tx
:
mpsc
::
UnboundedSender
<
SequenceHash
>
,
)
->
Self
{
Self
{
sequence_hash
,
unregister_tx
,
}
}
}
impl Drop for BlockHandle {
    /// Sends this handle's sequence hash over the unregister channel.
    ///
    /// The send result is intentionally discarded: a failed send is not
    /// treated as an error here.
    fn drop(&mut self) {
        let hash = self.sequence_hash;
        self.unregister_tx.send(hash).ok();
    }
}
#[derive()]
pub
struct
BlockRegistry
{
pub
struct
BlockRegistry
{
blocks
:
HashMap
<
SequenceHash
,
Weak
<
Registration
Handle
>>
,
blocks
:
Arc
<
Mutex
<
HashMap
<
SequenceHash
,
Weak
<
Block
Handle
>>
>>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
global_registry
:
GlobalRegistry
,
unregister_tx
:
mpsc
::
UnboundedSender
<
SequenceHash
>
,
}
}
impl
BlockRegistry
{
impl
BlockRegistry
{
pub
fn
new
(
event_manager
:
Arc
<
dyn
EventManager
>
)
->
Self
{
pub
fn
new
(
event_manager
:
Arc
<
dyn
EventManager
>
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
Self
{
let
(
unregister_tx
,
mut
unregister_rx
)
=
mpsc
::
unbounded_channel
();
let
blocks
:
Arc
<
Mutex
<
HashMap
<
SequenceHash
,
Weak
<
BlockHandle
>>>>
=
Arc
::
new
(
Mutex
::
new
(
HashMap
::
new
()));
let
blocks_clone
=
blocks
.clone
();
let
global_registry_clone
=
global_registry
.clone
();
async_runtime
.spawn
(
async
move
{
let
blocks
=
blocks_clone
;
let
global_registry
=
global_registry_clone
;
while
let
Some
(
sequence_hash
)
=
unregister_rx
.recv
()
.await
{
{
let
mut
blocks
=
blocks
.lock
()
.unwrap
();
if
let
Some
(
handle
)
=
blocks
.get
(
&
sequence_hash
)
{
if
handle
.upgrade
()
.is_none
()
{
blocks
.remove
(
&
sequence_hash
);
}
}
}
let
mut
global_registry
=
global_registry
.lock
()
.unwrap
();
if
let
Some
(
entry
)
=
global_registry
.get
(
&
sequence_hash
)
{
if
entry
.upgrade
()
.is_none
()
{
global_registry
.remove
(
&
sequence_hash
);
}
}
}
});
Self
{
Self
{
blocks
:
HashMap
::
new
()
,
blocks
,
event_manager
,
event_manager
,
global_registry
,
unregister_tx
,
}
}
}
}
pub
fn
is_registered
(
&
self
,
sequence_hash
:
SequenceHash
)
->
bool
{
pub
fn
is_registered
(
&
self
,
sequence_hash
:
SequenceHash
)
->
bool
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
let
blocks
=
self
.blocks
.lock
()
.unwrap
();
if
let
Some
(
handle
)
=
blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
true
;
return
true
;
}
}
...
@@ -65,7 +147,7 @@ impl BlockRegistry {
...
@@ -65,7 +147,7 @@ impl BlockRegistry {
pub
fn
register_block
(
pub
fn
register_block
(
&
mut
self
,
&
mut
self
,
block_state
:
&
mut
BlockState
,
block_state
:
&
mut
BlockState
,
)
->
Result
<
PublishHandle
,
BlockRegistationError
>
{
)
->
Result
<
Option
<
PublishHandle
>
,
BlockRegistationError
>
{
match
block_state
{
match
block_state
{
BlockState
::
Reset
=>
Err
(
BlockRegistationError
::
InvalidState
(
BlockState
::
Reset
=>
Err
(
BlockRegistationError
::
InvalidState
(
"Block is in Reset state"
.to_string
(),
"Block is in Reset state"
.to_string
(),
...
@@ -76,47 +158,60 @@ impl BlockRegistry {
...
@@ -76,47 +158,60 @@ impl BlockRegistry {
BlockState
::
Complete
(
state
)
=>
{
BlockState
::
Complete
(
state
)
=>
{
let
sequence_hash
=
state
.token_block
()
.sequence_hash
();
let
sequence_hash
=
state
.token_block
()
.sequence_hash
();
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
let
mut
blocks
=
self
.blocks
.lock
()
.unwrap
();
// If an identical block already exists in this pool, return an error.
if
let
Some
(
handle
)
=
blocks
.get
(
&
sequence_hash
)
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
if
let
Some
(
_
handle
)
=
handle
.upgrade
()
{
return
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
sequence_hash
));
return
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
sequence_hash
));
}
}
}
}
// Create the [RegistrationHandle] and [PublishHandle]
let
mut
publish_handle
=
None
;
let
publish_handle
=
Self
::
create_publish_handle
(
state
.token_block
(),
self
.event_manager
.clone
());
let
block_handle
=
let
reg_handle
=
publish_handle
.remove_handle
();
Arc
::
new
(
BlockHandle
::
new
(
sequence_hash
,
self
.unregister_tx
.clone
()));
let
reg_handle
=
'reg_block
:
{
// Now, check the global registry.
let
mut
global_registry
=
self
.global_registry
.lock
()
.unwrap
();
// If an identical block exists in another pool, use the same registration handle.
if
let
Some
(
handle
)
=
global_registry
.get
(
&
sequence_hash
)
{
if
let
Some
(
handle
)
=
handle
.upgrade
()
{
break
'reg_block
handle
;
}
}
// Otherwise, create a new registration handle.
publish_handle
=
Some
(
Self
::
create_publish_handle
(
state
.token_block
(),
self
.event_manager
.clone
(),
));
let
reg_handle
=
publish_handle
.as_ref
()
.unwrap
()
.remove_handle
();
// Insert the registration handle into the global registry.
global_registry
.insert
(
sequence_hash
,
Arc
::
downgrade
(
&
reg_handle
));
// Insert the [RegistrationHandle] into the registry
reg_handle
self
.blocks
};
.insert
(
sequence_hash
,
Arc
::
downgrade
(
&
reg_handle
));
blocks
.insert
(
sequence_hash
,
Arc
::
downgrade
(
&
block_handle
));
// Update the [BlockState] to [BlockState::Registered]
// Update the [BlockState] to [BlockState::Registered]
let
_
=
std
::
mem
::
replace
(
block_state
,
BlockState
::
Registered
(
reg_handle
));
let
_
=
std
::
mem
::
replace
(
block_state
,
BlockState
::
Registered
(
reg_handle
,
block_handle
),
);
Ok
(
publish_handle
)
Ok
(
publish_handle
)
}
}
BlockState
::
Registered
(
registered
)
=>
Err
(
BlockState
::
Registered
(
registered
,
_
)
=>
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
registered
.sequence_hash
()),
BlockRegistationError
::
BlockAlreadyRegistered
(
registered
.sequence_hash
()),
),
),
}
}
}
}
pub
fn
unregister_block
(
&
mut
self
,
sequence_hash
:
SequenceHash
,
)
->
Result
<
(),
UnregisterFailure
>
{
if
let
Some
(
handle
)
=
self
.blocks
.get
(
&
sequence_hash
)
{
if
handle
.upgrade
()
.is_none
()
{
self
.blocks
.remove
(
&
sequence_hash
);
return
Ok
(());
}
else
{
return
Err
(
UnregisterFailure
(
sequence_hash
));
}
}
Ok
(())
}
fn
create_publish_handle
(
fn
create_publish_handle
(
token_block
:
&
TokenBlock
,
token_block
:
&
TokenBlock
,
event_manager
:
Arc
<
dyn
EventManager
>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
...
...
lib/llm/src/block_manager/block/state.rs
View file @
3d40a692
...
@@ -17,7 +17,7 @@ use std::sync::Arc;
...
@@ -17,7 +17,7 @@ use std::sync::Arc;
use
derive_getters
::
Getters
;
use
derive_getters
::
Getters
;
use
super
::
registry
::
RegistrationHandle
;
use
super
::
registry
::
{
BlockHandle
,
RegistrationHandle
}
;
use
super
::
Result
;
use
super
::
Result
;
use
crate
::
tokens
::{
PartialTokenBlock
,
SaltHash
,
Token
,
TokenBlock
,
Tokens
};
use
crate
::
tokens
::{
PartialTokenBlock
,
SaltHash
,
Token
,
TokenBlock
,
Tokens
};
...
@@ -30,7 +30,7 @@ pub enum BlockState {
...
@@ -30,7 +30,7 @@ pub enum BlockState {
Reset
,
Reset
,
Partial
(
PartialState
),
Partial
(
PartialState
),
Complete
(
CompleteState
),
Complete
(
CompleteState
),
Registered
(
Arc
<
RegistrationHandle
>
),
Registered
(
Arc
<
RegistrationHandle
>
,
Arc
<
BlockHandle
>
),
}
}
impl
BlockState
{
impl
BlockState
{
...
@@ -109,7 +109,7 @@ impl BlockState {
...
@@ -109,7 +109,7 @@ impl BlockState {
BlockState
::
Reset
=>
Some
(
0
),
BlockState
::
Reset
=>
Some
(
0
),
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.len
()),
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.len
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()
.len
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()
.len
()),
BlockState
::
Registered
(
_
)
=>
None
,
BlockState
::
Registered
(
_
,
_
)
=>
None
,
}
}
}
}
...
@@ -126,15 +126,15 @@ impl BlockState {
...
@@ -126,15 +126,15 @@ impl BlockState {
match
self
{
match
self
{
BlockState
::
Reset
=>
true
,
BlockState
::
Reset
=>
true
,
BlockState
::
Partial
(
state
)
=>
state
.block
.is_empty
(),
BlockState
::
Partial
(
state
)
=>
state
.block
.is_empty
(),
BlockState
::
Complete
(
_
)
=>
false
,
// Always full
BlockState
::
Complete
(
_
)
=>
false
,
// Always full
BlockState
::
Registered
(
_
)
=>
false
,
// Always full
BlockState
::
Registered
(
_
,
_
)
=>
false
,
// Always full
}
}
}
}
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
{
pub
fn
tokens
(
&
self
)
->
Option
<&
Tokens
>
{
match
self
{
match
self
{
BlockState
::
Reset
|
BlockState
::
Registered
(
_
)
=>
None
,
BlockState
::
Reset
|
BlockState
::
Registered
(
_
,
_
)
=>
None
,
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.tokens
()),
BlockState
::
Partial
(
state
)
=>
Some
(
state
.block
.tokens
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()),
BlockState
::
Complete
(
state
)
=>
Some
(
state
.token_block
.tokens
()),
}
}
...
@@ -147,12 +147,12 @@ impl BlockState {
...
@@ -147,12 +147,12 @@ impl BlockState {
/// Returns true if the block is in the complete or registered state
/// Returns true if the block is in the complete or registered state
pub
fn
is_complete
(
&
self
)
->
bool
{
pub
fn
is_complete
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Complete
(
_
)
|
BlockState
::
Registered
(
_
))
matches!
(
self
,
BlockState
::
Complete
(
_
)
|
BlockState
::
Registered
(
_
,
_
))
}
}
/// Returns true if the block is in the registered state
/// Returns true if the block is in the registered state
pub
fn
is_registered
(
&
self
)
->
bool
{
pub
fn
is_registered
(
&
self
)
->
bool
{
matches!
(
self
,
BlockState
::
Registered
(
_
state
))
matches!
(
self
,
BlockState
::
Registered
(
_
state
,
_
))
}
}
}
}
...
...
lib/llm/src/block_manager/offload.rs
View file @
3d40a692
...
@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
priority
:
u64
,
priority
:
u64
,
)
->
core
::
result
::
Result
<
(),
BlockPoolError
>
{
)
->
core
::
result
::
Result
<
(),
BlockPoolError
>
{
match
block
.state
()
{
match
block
.state
()
{
BlockState
::
Registered
(
_
)
=>
{}
BlockState
::
Registered
(
_
,
_
)
=>
{}
_
=>
{
_
=>
{
return
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
InvalidState
(
return
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
InvalidState
(
"Block is not registered."
.to_string
(),
"Block is not registered."
.to_string
(),
...
@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
...
@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
)
->
BlockResult
<
DeviceStorage
,
Metadata
>
{
)
->
BlockResult
<
DeviceStorage
,
Metadata
>
{
for
block
in
&
blocks
{
for
block
in
&
blocks
{
match
block
.state
()
{
match
block
.state
()
{
BlockState
::
Registered
(
_
)
=>
{}
BlockState
::
Registered
(
_
,
_
)
=>
{}
_
=>
{
_
=>
{
return
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
InvalidState
(
return
Err
(
BlockPoolError
::
BlockError
(
BlockError
::
InvalidState
(
"Block is not registered."
.to_string
(),
"Block is not registered."
.to_string
(),
...
@@ -857,7 +857,7 @@ mod tests {
...
@@ -857,7 +857,7 @@ mod tests {
// Check that the block is registered.
// Check that the block is registered.
assert
!
(
matches!
(
assert
!
(
matches!
(
onboarded_blocks
[
0
]
.state
(),
onboarded_blocks
[
0
]
.state
(),
BlockState
::
Registered
(
_
)
BlockState
::
Registered
(
_
,
_
)
));
));
check_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
check_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
...
@@ -940,7 +940,7 @@ mod tests {
...
@@ -940,7 +940,7 @@ mod tests {
);
);
assert
!
(
matches!
(
assert
!
(
matches!
(
onboarded_blocks
[
0
]
.state
(),
onboarded_blocks
[
0
]
.state
(),
BlockState
::
Registered
(
_
)
BlockState
::
Registered
(
_
,
_
)
));
));
check_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
check_block_contents
(
&
immutable_host_block
,
&
onboarded_blocks
[
0
],
42
)
?
;
...
...
lib/llm/src/block_manager/offload/pending.rs
View file @
3d40a692
...
@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
...
@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
target
:
&
mut
MutableBlock
<
Target
,
Metadata
>
,
target
:
&
mut
MutableBlock
<
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
)
->
Result
<
()
>
{
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if
let
BlockState
::
Registered
(
reg_handle
)
=
source
.state
()
{
if
let
BlockState
::
Registered
(
reg_handle
,
_
)
=
source
.state
()
{
// Bring the block back to the 'Reset' state.
// Bring the block back to the 'Reset' state.
target
.reset
();
target
.reset
();
// Transfer metadata.
// Transfer metadata.
...
...
lib/llm/src/block_manager/pool.rs
View file @
3d40a692
...
@@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock};
...
@@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock};
use
super
::
block
::{
use
super
::
block
::{
nixl
::
short_type_name
,
registry
::
BlockRegistry
,
Block
,
BlockError
,
BlockMetadata
,
nixl
::
short_type_name
,
registry
::
BlockRegistry
,
Block
,
BlockError
,
BlockMetadata
,
GlobalRegistry
,
};
};
use
super
::
events
::{
EventManager
,
NullEventManager
};
use
super
::
events
::{
EventManager
,
NullEventManager
};
use
super
::
storage
::
Storage
;
use
super
::
storage
::
Storage
;
...
@@ -80,6 +81,7 @@ use std::{
...
@@ -80,6 +81,7 @@ use std::{
collections
::{
BTreeSet
,
HashMap
,
VecDeque
},
collections
::{
BTreeSet
,
HashMap
,
VecDeque
},
sync
::{
Arc
,
Weak
},
sync
::{
Arc
,
Weak
},
};
};
use
tokio
::
runtime
::
Handle
;
use
tokio_util
::
sync
::
CancellationToken
;
use
tokio_util
::
sync
::
CancellationToken
;
use
dynamo_runtime
::
Result
;
use
dynamo_runtime
::
Result
;
...
@@ -116,15 +118,27 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
...
@@ -116,15 +118,27 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default)]
#[builder(default)]
blocks
:
Vec
<
Block
<
S
,
M
>>
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
#[builder(default)]
global_registry
:
GlobalRegistry
,
#[builder(default
=
"Handle::current()"
)]
async_runtime
:
Handle
,
}
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
BlockPoolArgsBuilder
<
S
,
M
>
{
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
BlockPoolArgsBuilder
<
S
,
M
>
{
pub
fn
build
(
self
)
->
anyhow
::
Result
<
BlockPool
<
S
,
M
>>
{
pub
fn
build
(
self
)
->
anyhow
::
Result
<
BlockPool
<
S
,
M
>>
{
let
args
=
self
.build_internal
()
?
;
let
args
=
self
.build_internal
()
?
;
let
(
event_manager
,
cancel_token
,
blocks
)
=
args
.dissolve
();
let
(
event_manager
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
)
=
args
.dissolve
();
tracing
::
info!
(
"building block pool"
);
tracing
::
info!
(
"building block pool"
);
let
pool
=
BlockPool
::
new
(
event_manager
,
cancel_token
,
blocks
);
let
pool
=
BlockPool
::
new
(
event_manager
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
,
);
Ok
(
pool
)
Ok
(
pool
)
}
}
...
@@ -200,9 +214,16 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
...
@@ -200,9 +214,16 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager
:
Arc
<
dyn
EventManager
>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
cancel_token
:
CancellationToken
,
cancel_token
:
CancellationToken
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
Self
{
)
->
Self
{
let
(
pool
,
progress_engine
)
=
let
(
pool
,
progress_engine
)
=
Self
::
with_progress_engine
(
Self
::
with_progress_engine
(
event_manager
,
cancel_token
,
blocks
);
event_manager
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
,
);
// pool.runtime.handle().spawn(async move {
// pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine;
// let mut progress_engine = progress_engine;
...
@@ -239,12 +260,21 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
...
@@ -239,12 +260,21 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager
:
Arc
<
dyn
EventManager
>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
cancel_token
:
CancellationToken
,
cancel_token
:
CancellationToken
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
(
Self
,
ProgressEngine
<
S
,
M
>
)
{
)
->
(
Self
,
ProgressEngine
<
S
,
M
>
)
{
let
(
priority_tx
,
priority_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
priority_tx
,
priority_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
ctrl_tx
,
ctrl_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
ctrl_tx
,
ctrl_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
progress_engine
=
let
progress_engine
=
ProgressEngine
::
<
S
,
M
>
::
new
(
ProgressEngine
::
<
S
,
M
>
::
new
(
event_manager
,
priority_rx
,
ctrl_rx
,
cancel_token
,
blocks
);
event_manager
,
priority_rx
,
ctrl_rx
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
,
);
(
(
Self
{
Self
{
...
@@ -468,9 +498,15 @@ mod tests {
...
@@ -468,9 +498,15 @@ mod tests {
self
,
self
,
)
->
anyhow
::
Result
<
(
BlockPool
<
S
,
M
>
,
ProgressEngine
<
S
,
M
>
)
>
{
)
->
anyhow
::
Result
<
(
BlockPool
<
S
,
M
>
,
ProgressEngine
<
S
,
M
>
)
>
{
let
args
=
self
.build_internal
()
?
;
let
args
=
self
.build_internal
()
?
;
let
(
event_manager
,
cancel_token
,
blocks
)
=
args
.dissolve
();
let
(
event_manager
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
)
=
let
(
pool
,
progress_engine
)
=
args
.dissolve
();
BlockPool
::
with_progress_engine
(
event_manager
,
cancel_token
,
blocks
);
let
(
pool
,
progress_engine
)
=
BlockPool
::
with_progress_engine
(
event_manager
,
cancel_token
,
blocks
,
global_registry
,
async_runtime
,
);
Ok
((
pool
,
progress_engine
))
Ok
((
pool
,
progress_engine
))
}
}
...
@@ -560,8 +596,14 @@ mod tests {
...
@@ -560,8 +596,14 @@ mod tests {
.into_blocks
()
.into_blocks
()
.unwrap
();
.unwrap
();
let
async_runtime
=
tokio
::
runtime
::
Runtime
::
new
()
.unwrap
();
// Create the BlockPool and add the blocks
// Create the BlockPool and add the blocks
let
pool
=
BlockPool
::
builder
()
.blocks
(
blocks
)
.build
()
.unwrap
();
let
pool
=
BlockPool
::
builder
()
.blocks
(
blocks
)
.async_runtime
(
async_runtime
.handle
()
.clone
())
.build
()
.unwrap
();
// All blocks should be in the Reset/Empty state
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
// No blocks should match the expected sequence hash
...
...
lib/llm/src/block_manager/pool/inactive.rs
View file @
3d40a692
...
@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
...
@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
block
.reset
();
block
.reset
();
self
.uninitialized_set
.push_back
(
block
);
self
.uninitialized_set
.push_back
(
block
);
}
}
BlockState
::
Registered
(
state
)
=>
{
BlockState
::
Registered
(
state
,
_
)
=>
{
let
sequence_hash
=
state
.sequence_hash
();
let
sequence_hash
=
state
.sequence_hash
();
self
.insert_with_sequence_hash
(
block
,
sequence_hash
);
self
.insert_with_sequence_hash
(
block
,
sequence_hash
);
}
}
...
@@ -603,6 +603,7 @@ pub(crate) mod tests {
...
@@ -603,6 +603,7 @@ pub(crate) mod tests {
pub
fn
create_blocks
(
pub
fn
create_blocks
(
tokens
:
Tokens
,
tokens
:
Tokens
,
block_size
:
usize
,
block_size
:
usize
,
async_runtime
:
Handle
,
)
->
Vec
<
Block
<
NullDeviceStorage
,
TestMetadata
>>
{
)
->
Vec
<
Block
<
NullDeviceStorage
,
TestMetadata
>>
{
let
(
token_blocks
,
_
partial_token_block
)
=
let
(
token_blocks
,
_
partial_token_block
)
=
tokens
.into_sequence
(
block_size
,
None
)
.into_parts
();
tokens
.into_sequence
(
block_size
,
None
)
.into_parts
();
...
@@ -615,7 +616,8 @@ pub(crate) mod tests {
...
@@ -615,7 +616,8 @@ pub(crate) mod tests {
let
mut
blocks
=
create_block_collection
(
num_blocks
)
.into_blocks
()
.unwrap
();
let
mut
blocks
=
create_block_collection
(
num_blocks
)
.into_blocks
()
.unwrap
();
let
event_manager
=
NullEventManager
::
new
();
let
event_manager
=
NullEventManager
::
new
();
let
mut
registry
=
BlockRegistry
::
new
(
event_manager
);
let
mut
registry
=
BlockRegistry
::
new
(
event_manager
,
GlobalRegistry
::
default
(),
async_runtime
);
// Iterate through the generated TokenBlocks and the template Blocks,
// Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one.
// setting the state and registering each one.
...
@@ -645,6 +647,7 @@ pub(crate) mod tests {
...
@@ -645,6 +647,7 @@ pub(crate) mod tests {
tokens
:
Tokens
,
tokens
:
Tokens
,
block_size
:
usize
,
block_size
:
usize
,
pool
:
&
mut
InactiveBlockPool
<
NullDeviceStorage
,
TestMetadata
>
,
pool
:
&
mut
InactiveBlockPool
<
NullDeviceStorage
,
TestMetadata
>
,
async_runtime
:
Handle
,
)
->
(
Vec
<
Block
<
NullDeviceStorage
,
TestMetadata
>>
,
usize
)
{
)
->
(
Vec
<
Block
<
NullDeviceStorage
,
TestMetadata
>>
,
usize
)
{
let
(
mut
token_blocks
,
_
partial_token_block
)
=
let
(
mut
token_blocks
,
_
partial_token_block
)
=
tokens
.into_sequence
(
block_size
,
None
)
.into_parts
();
tokens
.into_sequence
(
block_size
,
None
)
.into_parts
();
...
@@ -657,7 +660,8 @@ pub(crate) mod tests {
...
@@ -657,7 +660,8 @@ pub(crate) mod tests {
let
matched_block_count
=
matched_blocks
.len
();
let
matched_block_count
=
matched_blocks
.len
();
let
event_manager
=
NullEventManager
::
new
();
let
event_manager
=
NullEventManager
::
new
();
let
mut
registry
=
BlockRegistry
::
new
(
event_manager
);
let
mut
registry
=
BlockRegistry
::
new
(
event_manager
,
GlobalRegistry
::
default
(),
async_runtime
);
// all matched blocks should be in the complete or registered state
// all matched blocks should be in the complete or registered state
for
block
in
&
mut
matched_blocks
{
for
block
in
&
mut
matched_blocks
{
...
@@ -697,6 +701,8 @@ pub(crate) mod tests {
...
@@ -697,6 +701,8 @@ pub(crate) mod tests {
fn
test_block_pool_lifecycle
()
{
fn
test_block_pool_lifecycle
()
{
dynamo_runtime
::
logging
::
init
();
dynamo_runtime
::
logging
::
init
();
let
async_runtime
=
tokio
::
runtime
::
Runtime
::
new
()
.unwrap
();
const
PAGE_SIZE
:
usize
=
2
;
const
PAGE_SIZE
:
usize
=
2
;
let
mut
pool
=
create_block_pool
(
10
);
let
mut
pool
=
create_block_pool
(
10
);
...
@@ -715,7 +721,12 @@ pub(crate) mod tests {
...
@@ -715,7 +721,12 @@ pub(crate) mod tests {
let
tokens
=
create_token_sequence
(
&
[
1
,
2
,
3
,
4
]);
let
tokens
=
create_token_sequence
(
&
[
1
,
2
,
3
,
4
]);
let
(
blocks
,
matched_block_count
)
=
acquire_blocks
(
tokens
.clone
(),
PAGE_SIZE
,
&
mut
pool
);
let
(
blocks
,
matched_block_count
)
=
acquire_blocks
(
tokens
.clone
(),
PAGE_SIZE
,
&
mut
pool
,
async_runtime
.handle
()
.clone
(),
);
assert_eq!
(
blocks
.len
(),
2
);
assert_eq!
(
blocks
.len
(),
2
);
assert_eq!
(
matched_block_count
,
0
);
assert_eq!
(
matched_block_count
,
0
);
assert_eq!
(
pool
.available_blocks
(),
8
);
assert_eq!
(
pool
.available_blocks
(),
8
);
...
@@ -725,7 +736,12 @@ pub(crate) mod tests {
...
@@ -725,7 +736,12 @@ pub(crate) mod tests {
assert_eq!
(
pool
.total_blocks
(),
10
);
assert_eq!
(
pool
.total_blocks
(),
10
);
assert_eq!
(
pool
.available_blocks
(),
10
);
assert_eq!
(
pool
.available_blocks
(),
10
);
let
(
blocks
,
matched_block_count
)
=
acquire_blocks
(
tokens
.clone
(),
PAGE_SIZE
,
&
mut
pool
);
let
(
blocks
,
matched_block_count
)
=
acquire_blocks
(
tokens
.clone
(),
PAGE_SIZE
,
&
mut
pool
,
async_runtime
.handle
()
.clone
(),
);
assert_eq!
(
blocks
.len
(),
2
);
assert_eq!
(
blocks
.len
(),
2
);
assert_eq!
(
matched_block_count
,
2
);
assert_eq!
(
matched_block_count
,
2
);
assert_eq!
(
pool
.available_blocks
(),
8
);
assert_eq!
(
pool
.available_blocks
(),
8
);
...
@@ -745,9 +761,11 @@ pub(crate) mod tests {
...
@@ -745,9 +761,11 @@ pub(crate) mod tests {
fn
test_basic_sequence_matching
()
{
fn
test_basic_sequence_matching
()
{
let
mut
pool
=
InactiveBlockPool
::
new
();
let
mut
pool
=
InactiveBlockPool
::
new
();
let
async_runtime
=
tokio
::
runtime
::
Runtime
::
new
()
.unwrap
();
// Create a sequence of 4 tokens split into blocks of 2
// Create a sequence of 4 tokens split into blocks of 2
let
sequence
=
create_token_sequence
(
&
[
1
,
2
,
3
,
4
]);
let
sequence
=
create_token_sequence
(
&
[
1
,
2
,
3
,
4
]);
let
blocks
=
create_blocks
(
sequence
,
2
);
let
blocks
=
create_blocks
(
sequence
,
2
,
async_runtime
.handle
()
.clone
()
);
assert_eq!
(
blocks
.len
(),
2
);
assert_eq!
(
blocks
.len
(),
2
);
// Match the blocks in sequence
// Match the blocks in sequence
...
...
lib/llm/src/block_manager/pool/state.rs
View file @
3d40a692
...
@@ -24,11 +24,13 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -24,11 +24,13 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
fn
new
(
fn
new
(
event_manager
:
Arc
<
dyn
EventManager
>
,
event_manager
:
Arc
<
dyn
EventManager
>
,
return_tx
:
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Block
<
S
,
M
>>
,
return_tx
:
tokio
::
sync
::
mpsc
::
UnboundedSender
<
Block
<
S
,
M
>>
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
Self
{
)
->
Self
{
Self
{
Self
{
active
:
ActiveBlockPool
::
new
(),
active
:
ActiveBlockPool
::
new
(),
inactive
:
InactiveBlockPool
::
new
(),
inactive
:
InactiveBlockPool
::
new
(),
registry
:
BlockRegistry
::
new
(
event_manager
.clone
()),
registry
:
BlockRegistry
::
new
(
event_manager
.clone
()
,
global_registry
,
async_runtime
),
return_tx
,
return_tx
,
event_manager
,
event_manager
,
}
}
...
@@ -88,7 +90,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -88,7 +90,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
return_rx
:
&
mut
tokio
::
sync
::
mpsc
::
UnboundedReceiver
<
Block
<
S
,
M
>>
,
return_rx
:
&
mut
tokio
::
sync
::
mpsc
::
UnboundedReceiver
<
Block
<
S
,
M
>>
,
)
->
Block
<
S
,
M
>
{
)
->
Block
<
S
,
M
>
{
while
let
Some
(
block
)
=
return_rx
.recv
()
.await
{
while
let
Some
(
block
)
=
return_rx
.recv
()
.await
{
if
matches!
(
block
.state
(),
BlockState
::
Registered
(
handle
)
if
handle
.sequence_hash
()
==
sequence_hash
)
if
matches!
(
block
.state
(),
BlockState
::
Registered
(
handle
,
_
)
if
handle
.sequence_hash
()
==
sequence_hash
)
{
{
return
block
;
return
block
;
}
}
...
@@ -151,7 +153,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -151,7 +153,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let
mutable
=
if
let
Some
(
raw_block
)
=
self
.inactive
.match_sequence_hash
(
sequence_hash
)
let
mutable
=
if
let
Some
(
raw_block
)
=
self
.inactive
.match_sequence_hash
(
sequence_hash
)
{
{
assert
!
(
matches!
(
raw_block
.state
(),
BlockState
::
Registered
(
_
)));
assert
!
(
matches!
(
raw_block
.state
(),
BlockState
::
Registered
(
_
,
_
)));
MutableBlock
::
new
(
raw_block
,
self
.return_tx
.clone
())
MutableBlock
::
new
(
raw_block
,
self
.return_tx
.clone
())
}
else
{
}
else
{
// Attempt to register the block
// Attempt to register the block
...
@@ -161,7 +163,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -161,7 +163,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
match
result
{
match
result
{
Ok
(
handle
)
=>
{
Ok
(
handle
)
=>
{
publish_handles
.take_handle
(
handle
);
// Only create our publish handle if this block is new, and not transfered.
if
let
Some
(
handle
)
=
handle
{
publish_handles
.take_handle
(
handle
);
}
block
block
}
}
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
_
))
=>
{
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
_
))
=>
{
...
@@ -222,7 +227,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
...
@@ -222,7 +227,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
};
};
// this assert allows us to skip the error checking on the active pool registration step
// this assert allows us to skip the error checking on the active pool registration step
assert
!
(
matches!
(
raw_block
.state
(),
BlockState
::
Registered
(
_
)));
assert
!
(
matches!
(
raw_block
.state
(),
BlockState
::
Registered
(
_
,
_
)));
let
mutable
=
MutableBlock
::
new
(
raw_block
,
self
.return_tx
.clone
());
let
mutable
=
MutableBlock
::
new
(
raw_block
,
self
.return_tx
.clone
());
...
@@ -255,9 +260,12 @@ impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
...
@@ -255,9 +260,12 @@ impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
ctrl_rx
:
tokio
::
sync
::
mpsc
::
UnboundedReceiver
<
ControlRequest
<
S
,
M
>>
,
ctrl_rx
:
tokio
::
sync
::
mpsc
::
UnboundedReceiver
<
ControlRequest
<
S
,
M
>>
,
cancel_token
:
CancellationToken
,
cancel_token
:
CancellationToken
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
blocks
:
Vec
<
Block
<
S
,
M
>>
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
Self
{
)
->
Self
{
let
(
return_tx
,
return_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
(
return_tx
,
return_rx
)
=
tokio
::
sync
::
mpsc
::
unbounded_channel
();
let
mut
state
=
State
::
<
S
,
M
>
::
new
(
event_manager
,
return_tx
);
let
mut
state
=
State
::
<
S
,
M
>
::
new
(
event_manager
,
return_tx
,
global_registry
,
async_runtime
);
tracing
::
debug!
(
count
=
blocks
.len
(),
"adding blocks to inactive pool"
);
tracing
::
debug!
(
count
=
blocks
.len
(),
"adding blocks to inactive pool"
);
state
.inactive
.add_blocks
(
blocks
);
state
.inactive
.add_blocks
(
blocks
);
...
...
lib/llm/src/block_manager/state.rs
View file @
3d40a692
...
@@ -17,7 +17,7 @@ use super::*;
...
@@ -17,7 +17,7 @@ use super::*;
use
super
::
offload
::
OffloadManager
;
use
super
::
offload
::
OffloadManager
;
use
super
::{
use
super
::{
block
::{
Block
,
ImmutableBlock
},
block
::{
Block
,
GlobalRegistry
,
ImmutableBlock
},
config
::
NixlOptions
,
config
::
NixlOptions
,
};
};
use
cudarc
::
driver
::
CudaStream
;
use
cudarc
::
driver
::
CudaStream
;
...
@@ -76,6 +76,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -76,6 +76,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
// Create a map of NIXL backends
// Create a map of NIXL backends
let
mut
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
=
HashMap
::
new
();
let
mut
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
=
HashMap
::
new
();
let
global_registry
=
GlobalRegistry
::
default
();
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
// TODO: Build a map of NIXL backends to block pools/sets
let
nixl_agent
=
Arc
::
new
(
match
config
.runtime.nixl
{
let
nixl_agent
=
Arc
::
new
(
match
config
.runtime.nixl
{
...
@@ -123,6 +125,14 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -123,6 +125,14 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let
mut
next_block_set_idx
=
0
;
let
mut
next_block_set_idx
=
0
;
let
mut
local_block_set
=
block
::
nixl
::
NixlBlockSet
::
new
(
worker_id
);
let
mut
local_block_set
=
block
::
nixl
::
NixlBlockSet
::
new
(
worker_id
);
let
async_rt_handle
=
match
config
.runtime.async_runtime
{
Some
(
rt
)
=>
rt
.handle
()
.clone
(),
None
=>
match
Handle
::
try_current
()
{
Ok
(
handle
)
=>
handle
,
Err
(
e
)
=>
anyhow
::
bail!
(
e
),
},
};
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
let
(
disk_pool
,
disk_blocks
)
=
if
let
Some
(
config
)
=
config
.disk_layout
{
if
nixl_agent
.is_none
()
{
if
nixl_agent
.is_none
()
{
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
tracing
::
warn!
(
"NIXL is disabled; will not allocate disk blocks."
);
...
@@ -138,6 +148,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -138,6 +148,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx
,
next_block_set_idx
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
global_registry
.clone
(),
async_rt_handle
.clone
(),
)
?
;
)
?
;
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
}
...
@@ -158,6 +170,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -158,6 +170,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx
,
next_block_set_idx
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
global_registry
.clone
(),
async_rt_handle
.clone
(),
)
?
;
)
?
;
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
...
@@ -177,6 +191,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -177,6 +191,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx
,
next_block_set_idx
,
cancellation_token
.clone
(),
cancellation_token
.clone
(),
worker_id
,
worker_id
,
global_registry
.clone
(),
async_rt_handle
.clone
(),
)
?
;
)
?
;
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
(
Some
(
Arc
::
new
(
pool
)),
Some
(
blocks
))
}
else
{
}
else
{
...
@@ -190,20 +206,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...
@@ -190,20 +206,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set
.set_nixl_metadata
(
nixl_agent
.get_local_md
()
?
);
local_block_set
.set_nixl_metadata
(
nixl_agent
.get_local_md
()
?
);
}
}
let
offload_async_rt_handle
=
match
config
.runtime.async_runtime
{
Some
(
rt
)
=>
rt
.handle
()
.clone
(),
None
=>
match
Handle
::
try_current
()
{
Ok
(
handle
)
=>
handle
,
Err
(
e
)
=>
anyhow
::
bail!
(
e
),
},
};
let
offload_manager
=
OffloadManager
::
new
(
let
offload_manager
=
OffloadManager
::
new
(
disk_pool
.clone
(),
disk_pool
.clone
(),
host_pool
.clone
(),
host_pool
.clone
(),
device_pool
.clone
(),
device_pool
.clone
(),
nixl_agent
.clone
(),
nixl_agent
.clone
(),
offload_
async_rt_handle
,
async_rt_handle
,
)
?
;
)
?
;
let
state
=
Arc
::
new
(
Self
{
let
state
=
Arc
::
new
(
Self
{
...
@@ -484,10 +492,14 @@ fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
...
@@ -484,10 +492,14 @@ fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
block_set_idx
:
usize
,
block_set_idx
:
usize
,
cancellation_token
:
CancellationToken
,
cancellation_token
:
CancellationToken
,
worker_id
:
WorkerID
,
worker_id
:
WorkerID
,
global_registry
:
GlobalRegistry
,
async_runtime
:
Handle
,
)
->
Result
<
(
BlockPool
<
S
,
M
>
,
Vec
<
Block
<
S
,
M
>>
)
>
{
)
->
Result
<
(
BlockPool
<
S
,
M
>
,
Vec
<
Block
<
S
,
M
>>
)
>
{
let
blocks
=
block
::
layout_to_blocks
::
<
_
,
M
>
(
layout
,
block_set_idx
,
worker_id
)
?
;
let
blocks
=
block
::
layout_to_blocks
::
<
_
,
M
>
(
layout
,
block_set_idx
,
worker_id
)
?
;
let
pool
=
BlockPool
::
<
S
,
M
>
::
builder
()
let
pool
=
BlockPool
::
<
S
,
M
>
::
builder
()
.cancel_token
(
cancellation_token
)
.cancel_token
(
cancellation_token
)
.global_registry
(
global_registry
)
.async_runtime
(
async_runtime
)
.build
()
?
;
.build
()
?
;
Ok
((
pool
,
blocks
))
Ok
((
pool
,
blocks
))
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment