Unverified Commit 3d40a692 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Restructure kv manager block registration (#1093)

parent 7d0c9386
...@@ -21,6 +21,8 @@ pub mod view; ...@@ -21,6 +21,8 @@ pub mod view;
pub use crate::tokens::TokenBlockError; pub use crate::tokens::TokenBlockError;
pub use anyhow::Result; pub use anyhow::Result;
use nixl_sys::NixlDescriptor; use nixl_sys::NixlDescriptor;
pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid}; pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{ use crate::block_manager::{
...@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> { ...@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> { pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
match self.state() { match self.state() {
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()), BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
BlockState::Registered(state) => Ok(state.sequence_hash()), BlockState::Registered(state, _) => Ok(state.sequence_hash()),
_ => Err(BlockError::InvalidState( _ => Err(BlockError::InvalidState(
"Block is not complete".to_string(), "Block is not complete".to_string(),
)), )),
...@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt { ...@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
fn register( fn register(
&mut self, &mut self,
registry: &mut registry::BlockRegistry, registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError>; ) -> Result<Option<PublishHandle>, registry::BlockRegistationError>;
} }
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> { impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
fn register( fn register(
&mut self, &mut self,
registry: &mut registry::BlockRegistry, registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError> { ) -> Result<Option<PublishHandle>, registry::BlockRegistationError> {
registry.register_block(&mut self.state) registry.register_block(&mut self.state)
} }
} }
......
...@@ -13,9 +13,27 @@ ...@@ -13,9 +13,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! # KV Cache Block Registration
//!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
//!
//! ## Workflow
//!
//! 1. When a block is registered into a pool, we create a unique block handle.
//! 2. We then check the global registry to see if the block already exists in any other pool.
//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one.
//! 4. When the block handle is dropped, it means that the block is no longer in the pool.
//! 5. When the registration handle is dropped, it means that the block is no longer in any pool.
use std::{ use std::{
collections::HashMap, collections::HashMap,
sync::{Arc, Weak}, sync::{Arc, Mutex, Weak},
}; };
use super::super::events::{EventManager, EventReleaseManager, PublishHandle}; use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
...@@ -24,6 +42,9 @@ use super::state::BlockState; ...@@ -24,6 +42,9 @@ use super::state::BlockState;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock}; use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters; use derive_getters::Getters;
use tokio::{runtime::Handle, sync::mpsc};
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum BlockRegistationError { pub enum BlockRegistationError {
...@@ -34,27 +55,88 @@ pub enum BlockRegistationError { ...@@ -34,27 +55,88 @@ pub enum BlockRegistationError {
InvalidState(String), InvalidState(String),
} }
/// Error returned when an attempt is made to unregister a block that is still active. /// A block entry is a handle to a block that is registered in the pool.
#[derive(Debug, thiserror::Error)] /// On drop, we need to notify the pool that the block has been unregistered.
#[error("Failed to unregister block: {0}")] /// This is different than the registration handle, which is only dropped when the block is no longer in ANY pool.
pub struct UnregisterFailure(SequenceHash); #[derive(Debug)]
pub struct BlockHandle {
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}
impl BlockHandle {
pub fn new(
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
) -> Self {
Self {
sequence_hash,
unregister_tx,
}
}
}
impl Drop for BlockHandle {
fn drop(&mut self) {
let _ = self.unregister_tx.send(self.sequence_hash);
}
}
#[derive()]
pub struct BlockRegistry { pub struct BlockRegistry {
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>, blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
} }
impl BlockRegistry { impl BlockRegistry {
pub fn new(event_manager: Arc<dyn EventManager>) -> Self { pub fn new(
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();
let blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>> =
Arc::new(Mutex::new(HashMap::new()));
let blocks_clone = blocks.clone();
let global_registry_clone = global_registry.clone();
async_runtime.spawn(async move {
let blocks = blocks_clone;
let global_registry = global_registry_clone;
while let Some(sequence_hash) = unregister_rx.recv().await {
{
let mut blocks = blocks.lock().unwrap();
if let Some(handle) = blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
blocks.remove(&sequence_hash);
}
}
}
let mut global_registry = global_registry.lock().unwrap();
if let Some(entry) = global_registry.get(&sequence_hash) {
if entry.upgrade().is_none() {
global_registry.remove(&sequence_hash);
}
}
}
});
Self { Self {
blocks: HashMap::new(), blocks,
event_manager, event_manager,
global_registry,
unregister_tx,
} }
} }
pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool { pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
if let Some(handle) = self.blocks.get(&sequence_hash) { let blocks = self.blocks.lock().unwrap();
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() { if let Some(_handle) = handle.upgrade() {
return true; return true;
} }
...@@ -65,7 +147,7 @@ impl BlockRegistry { ...@@ -65,7 +147,7 @@ impl BlockRegistry {
pub fn register_block( pub fn register_block(
&mut self, &mut self,
block_state: &mut BlockState, block_state: &mut BlockState,
) -> Result<PublishHandle, BlockRegistationError> { ) -> Result<Option<PublishHandle>, BlockRegistationError> {
match block_state { match block_state {
BlockState::Reset => Err(BlockRegistationError::InvalidState( BlockState::Reset => Err(BlockRegistationError::InvalidState(
"Block is in Reset state".to_string(), "Block is in Reset state".to_string(),
...@@ -76,47 +158,60 @@ impl BlockRegistry { ...@@ -76,47 +158,60 @@ impl BlockRegistry {
BlockState::Complete(state) => { BlockState::Complete(state) => {
let sequence_hash = state.token_block().sequence_hash(); let sequence_hash = state.token_block().sequence_hash();
if let Some(handle) = self.blocks.get(&sequence_hash) { let mut blocks = self.blocks.lock().unwrap();
// If an identical block already exists in this pool, return an error.
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() { if let Some(_handle) = handle.upgrade() {
return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash)); return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
} }
} }
// Create the [RegistrationHandle] and [PublishHandle] let mut publish_handle = None;
let publish_handle =
Self::create_publish_handle(state.token_block(), self.event_manager.clone()); let block_handle =
let reg_handle = publish_handle.remove_handle(); Arc::new(BlockHandle::new(sequence_hash, self.unregister_tx.clone()));
let reg_handle = 'reg_block: {
// Now, check the global registry.
let mut global_registry = self.global_registry.lock().unwrap();
// If an identical block exists in other pool, use the same registration handle.
if let Some(handle) = global_registry.get(&sequence_hash) {
if let Some(handle) = handle.upgrade() {
break 'reg_block handle;
}
}
// Otherwise, create a new registration handle.
publish_handle = Some(Self::create_publish_handle(
state.token_block(),
self.event_manager.clone(),
));
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
// Insert the registration handle into the global registry.
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle));
// Insert the [RegistrationHandle] into the registry reg_handle
self.blocks };
.insert(sequence_hash, Arc::downgrade(&reg_handle));
blocks.insert(sequence_hash, Arc::downgrade(&block_handle));
// Update the [BlockState] to [BlockState::Registered] // Update the [BlockState] to [BlockState::Registered]
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle)); let _ = std::mem::replace(
block_state,
BlockState::Registered(reg_handle, block_handle),
);
Ok(publish_handle) Ok(publish_handle)
} }
BlockState::Registered(registered) => Err( BlockState::Registered(registered, _) => Err(
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()), BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
), ),
} }
} }
pub fn unregister_block(
&mut self,
sequence_hash: SequenceHash,
) -> Result<(), UnregisterFailure> {
if let Some(handle) = self.blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
self.blocks.remove(&sequence_hash);
return Ok(());
} else {
return Err(UnregisterFailure(sequence_hash));
}
}
Ok(())
}
fn create_publish_handle( fn create_publish_handle(
token_block: &TokenBlock, token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
......
...@@ -17,7 +17,7 @@ use std::sync::Arc; ...@@ -17,7 +17,7 @@ use std::sync::Arc;
use derive_getters::Getters; use derive_getters::Getters;
use super::registry::RegistrationHandle; use super::registry::{BlockHandle, RegistrationHandle};
use super::Result; use super::Result;
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens}; use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};
...@@ -30,7 +30,7 @@ pub enum BlockState { ...@@ -30,7 +30,7 @@ pub enum BlockState {
Reset, Reset,
Partial(PartialState), Partial(PartialState),
Complete(CompleteState), Complete(CompleteState),
Registered(Arc<RegistrationHandle>), Registered(Arc<RegistrationHandle>, Arc<BlockHandle>),
} }
impl BlockState { impl BlockState {
...@@ -109,7 +109,7 @@ impl BlockState { ...@@ -109,7 +109,7 @@ impl BlockState {
BlockState::Reset => Some(0), BlockState::Reset => Some(0),
BlockState::Partial(state) => Some(state.block.len()), BlockState::Partial(state) => Some(state.block.len()),
BlockState::Complete(state) => Some(state.token_block.tokens().len()), BlockState::Complete(state) => Some(state.token_block.tokens().len()),
BlockState::Registered(_) => None, BlockState::Registered(_, _) => None,
} }
} }
...@@ -126,15 +126,15 @@ impl BlockState { ...@@ -126,15 +126,15 @@ impl BlockState {
match self { match self {
BlockState::Reset => true, BlockState::Reset => true,
BlockState::Partial(state) => state.block.is_empty(), BlockState::Partial(state) => state.block.is_empty(),
BlockState::Complete(_) => false, // Always full BlockState::Complete(_) => false, // Always full
BlockState::Registered(_) => false, // Always full BlockState::Registered(_, _) => false, // Always full
} }
} }
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered. /// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub fn tokens(&self) -> Option<&Tokens> { pub fn tokens(&self) -> Option<&Tokens> {
match self { match self {
BlockState::Reset | BlockState::Registered(_) => None, BlockState::Reset | BlockState::Registered(_, _) => None,
BlockState::Partial(state) => Some(state.block.tokens()), BlockState::Partial(state) => Some(state.block.tokens()),
BlockState::Complete(state) => Some(state.token_block.tokens()), BlockState::Complete(state) => Some(state.token_block.tokens()),
} }
...@@ -147,12 +147,12 @@ impl BlockState { ...@@ -147,12 +147,12 @@ impl BlockState {
/// Returns true if the block is in the complete or registered state /// Returns true if the block is in the complete or registered state
pub fn is_complete(&self) -> bool { pub fn is_complete(&self) -> bool {
matches!(self, BlockState::Complete(_) | BlockState::Registered(_)) matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _))
} }
/// Returns true if the block is in the registered state /// Returns true if the block is in the registered state
pub fn is_registered(&self) -> bool { pub fn is_registered(&self) -> bool {
matches!(self, BlockState::Registered(_state)) matches!(self, BlockState::Registered(_state, _))
} }
} }
......
...@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
priority: u64, priority: u64,
) -> core::result::Result<(), BlockPoolError> { ) -> core::result::Result<(), BlockPoolError> {
match block.state() { match block.state() {
BlockState::Registered(_) => {} BlockState::Registered(_, _) => {}
_ => { _ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState( return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(), "Block is not registered.".to_string(),
...@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> { ...@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
) -> BlockResult<DeviceStorage, Metadata> { ) -> BlockResult<DeviceStorage, Metadata> {
for block in &blocks { for block in &blocks {
match block.state() { match block.state() {
BlockState::Registered(_) => {} BlockState::Registered(_, _) => {}
_ => { _ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState( return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(), "Block is not registered.".to_string(),
...@@ -857,7 +857,7 @@ mod tests { ...@@ -857,7 +857,7 @@ mod tests {
// Check that the block is registered. // Check that the block is registered.
assert!(matches!( assert!(matches!(
onboarded_blocks[0].state(), onboarded_blocks[0].state(),
BlockState::Registered(_) BlockState::Registered(_, _)
)); ));
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?; check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
...@@ -940,7 +940,7 @@ mod tests { ...@@ -940,7 +940,7 @@ mod tests {
); );
assert!(matches!( assert!(matches!(
onboarded_blocks[0].state(), onboarded_blocks[0].state(),
BlockState::Registered(_) BlockState::Registered(_, _)
)); ));
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?; check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
......
...@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>( ...@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
target: &mut MutableBlock<Target, Metadata>, target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> { ) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail. // Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() { if let BlockState::Registered(reg_handle, _) = source.state() {
// Bring the block back to the 'Reset' state. // Bring the block back to the 'Reset' state.
target.reset(); target.reset();
// Transfer metadata. // Transfer metadata.
......
...@@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock}; ...@@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock};
use super::block::{ use super::block::{
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata, nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata,
GlobalRegistry,
}; };
use super::events::{EventManager, NullEventManager}; use super::events::{EventManager, NullEventManager};
use super::storage::Storage; use super::storage::Storage;
...@@ -80,6 +81,7 @@ use std::{ ...@@ -80,6 +81,7 @@ use std::{
collections::{BTreeSet, HashMap, VecDeque}, collections::{BTreeSet, HashMap, VecDeque},
sync::{Arc, Weak}, sync::{Arc, Weak},
}; };
use tokio::runtime::Handle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result; use dynamo_runtime::Result;
...@@ -116,15 +118,27 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> { ...@@ -116,15 +118,27 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default)] #[builder(default)]
blocks: Vec<Block<S, M>>, blocks: Vec<Block<S, M>>,
#[builder(default)]
global_registry: GlobalRegistry,
#[builder(default = "Handle::current()")]
async_runtime: Handle,
} }
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> { impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> { pub fn build(self) -> anyhow::Result<BlockPool<S, M>> {
let args = self.build_internal()?; let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve(); let (event_manager, cancel_token, blocks, global_registry, async_runtime) = args.dissolve();
tracing::info!("building block pool"); tracing::info!("building block pool");
let pool = BlockPool::new(event_manager, cancel_token, blocks); let pool = BlockPool::new(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
Ok(pool) Ok(pool)
} }
...@@ -200,9 +214,16 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> { ...@@ -200,9 +214,16 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>, blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self { ) -> Self {
let (pool, progress_engine) = let (pool, progress_engine) = Self::with_progress_engine(
Self::with_progress_engine(event_manager, cancel_token, blocks); event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
// pool.runtime.handle().spawn(async move { // pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine; // let mut progress_engine = progress_engine;
...@@ -239,12 +260,21 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> { ...@@ -239,12 +260,21 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>, blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> (Self, ProgressEngine<S, M>) { ) -> (Self, ProgressEngine<S, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel(); let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel(); let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine = let progress_engine = ProgressEngine::<S, M>::new(
ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token, blocks); event_manager,
priority_rx,
ctrl_rx,
cancel_token,
blocks,
global_registry,
async_runtime,
);
( (
Self { Self {
...@@ -468,9 +498,15 @@ mod tests { ...@@ -468,9 +498,15 @@ mod tests {
self, self,
) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> { ) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
let args = self.build_internal()?; let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve(); let (event_manager, cancel_token, blocks, global_registry, async_runtime) =
let (pool, progress_engine) = args.dissolve();
BlockPool::with_progress_engine(event_manager, cancel_token, blocks); let (pool, progress_engine) = BlockPool::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
Ok((pool, progress_engine)) Ok((pool, progress_engine))
} }
...@@ -560,8 +596,14 @@ mod tests { ...@@ -560,8 +596,14 @@ mod tests {
.into_blocks() .into_blocks()
.unwrap(); .unwrap();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the BlockPool and add the blocks // Create the BlockPool and add the blocks
let pool = BlockPool::builder().blocks(blocks).build().unwrap(); let pool = BlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.build()
.unwrap();
// All blocks should be in the Reset/Empty state // All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash // No blocks should match the expected sequence hash
......
...@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> { ...@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
block.reset(); block.reset();
self.uninitialized_set.push_back(block); self.uninitialized_set.push_back(block);
} }
BlockState::Registered(state) => { BlockState::Registered(state, _) => {
let sequence_hash = state.sequence_hash(); let sequence_hash = state.sequence_hash();
self.insert_with_sequence_hash(block, sequence_hash); self.insert_with_sequence_hash(block, sequence_hash);
} }
...@@ -603,6 +603,7 @@ pub(crate) mod tests { ...@@ -603,6 +603,7 @@ pub(crate) mod tests {
pub fn create_blocks( pub fn create_blocks(
tokens: Tokens, tokens: Tokens,
block_size: usize, block_size: usize,
async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> { ) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
let (token_blocks, _partial_token_block) = let (token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts(); tokens.into_sequence(block_size, None).into_parts();
...@@ -615,7 +616,8 @@ pub(crate) mod tests { ...@@ -615,7 +616,8 @@ pub(crate) mod tests {
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap(); let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
let event_manager = NullEventManager::new(); let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager); let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
// Iterate through the generated TokenBlocks and the template Blocks, // Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one. // setting the state and registering each one.
...@@ -645,6 +647,7 @@ pub(crate) mod tests { ...@@ -645,6 +647,7 @@ pub(crate) mod tests {
tokens: Tokens, tokens: Tokens,
block_size: usize, block_size: usize,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>, pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) { ) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
let (mut token_blocks, _partial_token_block) = let (mut token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts(); tokens.into_sequence(block_size, None).into_parts();
...@@ -657,7 +660,8 @@ pub(crate) mod tests { ...@@ -657,7 +660,8 @@ pub(crate) mod tests {
let matched_block_count = matched_blocks.len(); let matched_block_count = matched_blocks.len();
let event_manager = NullEventManager::new(); let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager); let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
// all matched blocks should be in the complete or registered state // all matched blocks should be in the complete or registered state
for block in &mut matched_blocks { for block in &mut matched_blocks {
...@@ -697,6 +701,8 @@ pub(crate) mod tests { ...@@ -697,6 +701,8 @@ pub(crate) mod tests {
fn test_block_pool_lifecycle() { fn test_block_pool_lifecycle() {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
const PAGE_SIZE: usize = 2; const PAGE_SIZE: usize = 2;
let mut pool = create_block_pool(10); let mut pool = create_block_pool(10);
...@@ -715,7 +721,12 @@ pub(crate) mod tests { ...@@ -715,7 +721,12 @@ pub(crate) mod tests {
let tokens = create_token_sequence(&[1, 2, 3, 4]); let tokens = create_token_sequence(&[1, 2, 3, 4]);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool); let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(),
PAGE_SIZE,
&mut pool,
async_runtime.handle().clone(),
);
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 0); assert_eq!(matched_block_count, 0);
assert_eq!(pool.available_blocks(), 8); assert_eq!(pool.available_blocks(), 8);
...@@ -725,7 +736,12 @@ pub(crate) mod tests { ...@@ -725,7 +736,12 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 10); assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10); assert_eq!(pool.available_blocks(), 10);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool); let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(),
PAGE_SIZE,
&mut pool,
async_runtime.handle().clone(),
);
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 2); assert_eq!(matched_block_count, 2);
assert_eq!(pool.available_blocks(), 8); assert_eq!(pool.available_blocks(), 8);
...@@ -745,9 +761,11 @@ pub(crate) mod tests { ...@@ -745,9 +761,11 @@ pub(crate) mod tests {
fn test_basic_sequence_matching() { fn test_basic_sequence_matching() {
let mut pool = InactiveBlockPool::new(); let mut pool = InactiveBlockPool::new();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create a sequence of 4 tokens split into blocks of 2 // Create a sequence of 4 tokens split into blocks of 2
let sequence = create_token_sequence(&[1, 2, 3, 4]); let sequence = create_token_sequence(&[1, 2, 3, 4]);
let blocks = create_blocks(sequence, 2); let blocks = create_blocks(sequence, 2, async_runtime.handle().clone());
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
// Match the blocks in sequence // Match the blocks in sequence
......
...@@ -24,11 +24,13 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -24,11 +24,13 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
fn new( fn new(
event_manager: Arc<dyn EventManager>, event_manager: Arc<dyn EventManager>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>, return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self { ) -> Self {
Self { Self {
active: ActiveBlockPool::new(), active: ActiveBlockPool::new(),
inactive: InactiveBlockPool::new(), inactive: InactiveBlockPool::new(),
registry: BlockRegistry::new(event_manager.clone()), registry: BlockRegistry::new(event_manager.clone(), global_registry, async_runtime),
return_tx, return_tx,
event_manager, event_manager,
} }
...@@ -88,7 +90,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -88,7 +90,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>, return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Block<S, M> { ) -> Block<S, M> {
while let Some(block) = return_rx.recv().await { while let Some(block) = return_rx.recv().await {
if matches!(block.state(), BlockState::Registered(handle) if handle.sequence_hash() == sequence_hash) if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash)
{ {
return block; return block;
} }
...@@ -151,7 +153,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -151,7 +153,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{ {
assert!(matches!(raw_block.state(), BlockState::Registered(_))); assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
MutableBlock::new(raw_block, self.return_tx.clone()) MutableBlock::new(raw_block, self.return_tx.clone())
} else { } else {
// Attempt to register the block // Attempt to register the block
...@@ -161,7 +163,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -161,7 +163,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
match result { match result {
Ok(handle) => { Ok(handle) => {
publish_handles.take_handle(handle); // Only create our publish handle if this block is new, and not transfered.
if let Some(handle) = handle {
publish_handles.take_handle(handle);
}
block block
} }
Err(BlockRegistationError::BlockAlreadyRegistered(_)) => { Err(BlockRegistationError::BlockAlreadyRegistered(_)) => {
...@@ -222,7 +227,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -222,7 +227,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
}; };
// this assert allows us to skip the error checking on the active pool registration step // this assert allows us to skip the error checking on the active pool registration step
assert!(matches!(raw_block.state(), BlockState::Registered(_))); assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
let mutable = MutableBlock::new(raw_block, self.return_tx.clone()); let mutable = MutableBlock::new(raw_block, self.return_tx.clone());
...@@ -255,9 +260,12 @@ impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> { ...@@ -255,9 +260,12 @@ impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>, ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>, blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self { ) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, M>::new(event_manager, return_tx); let mut state =
State::<S, M>::new(event_manager, return_tx, global_registry, async_runtime);
tracing::debug!(count = blocks.len(), "adding blocks to inactive pool"); tracing::debug!(count = blocks.len(), "adding blocks to inactive pool");
state.inactive.add_blocks(blocks); state.inactive.add_blocks(blocks);
......
...@@ -17,7 +17,7 @@ use super::*; ...@@ -17,7 +17,7 @@ use super::*;
use super::offload::OffloadManager; use super::offload::OffloadManager;
use super::{ use super::{
block::{Block, ImmutableBlock}, block::{Block, GlobalRegistry, ImmutableBlock},
config::NixlOptions, config::NixlOptions,
}; };
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
...@@ -76,6 +76,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -76,6 +76,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
// Create a map of NIXL backends // Create a map of NIXL backends
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new(); let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
let global_registry = GlobalRegistry::default();
// Create a NIXL agent if NIXL is enabled and instantiate requested backends // Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets // TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = Arc::new(match config.runtime.nixl { let nixl_agent = Arc::new(match config.runtime.nixl {
...@@ -123,6 +125,14 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -123,6 +125,14 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let mut next_block_set_idx = 0; let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id); let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout { let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() { if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks."); tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
...@@ -138,6 +148,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -138,6 +148,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx, next_block_set_idx,
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?; )?;
(Some(Arc::new(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} }
...@@ -158,6 +170,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -158,6 +170,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx, next_block_set_idx,
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?; )?;
(Some(Arc::new(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} else { } else {
...@@ -177,6 +191,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -177,6 +191,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx, next_block_set_idx,
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?; )?;
(Some(Arc::new(pool)), Some(blocks)) (Some(Arc::new(pool)), Some(blocks))
} else { } else {
...@@ -190,20 +206,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -190,20 +206,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
} }
let offload_async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let offload_manager = OffloadManager::new( let offload_manager = OffloadManager::new(
disk_pool.clone(), disk_pool.clone(),
host_pool.clone(), host_pool.clone(),
device_pool.clone(), device_pool.clone(),
nixl_agent.clone(), nixl_agent.clone(),
offload_async_rt_handle, async_rt_handle,
)?; )?;
let state = Arc::new(Self { let state = Arc::new(Self {
...@@ -484,10 +492,14 @@ fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>( ...@@ -484,10 +492,14 @@ fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
block_set_idx: usize, block_set_idx: usize,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
worker_id: WorkerID, worker_id: WorkerID,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> { ) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?; let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let pool = BlockPool::<S, M>::builder() let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token) .cancel_token(cancellation_token)
.global_registry(global_registry)
.async_runtime(async_runtime)
.build()?; .build()?;
Ok((pool, blocks)) Ok((pool, blocks))
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment