Unverified Commit 3d40a692 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: Restructure kv manager block registration (#1093)

parent 7d0c9386
......@@ -21,6 +21,8 @@ pub mod view;
pub use crate::tokens::TokenBlockError;
pub use anyhow::Result;
use nixl_sys::NixlDescriptor;
pub use registry::{GlobalRegistry, RegistrationHandle};
pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{
......@@ -176,7 +178,7 @@ impl<S: Storage, M: BlockMetadata> Block<S, M> {
pub fn sequence_hash(&self) -> Result<SequenceHash, BlockError> {
match self.state() {
BlockState::Complete(state) => Ok(state.token_block().sequence_hash()),
BlockState::Registered(state) => Ok(state.sequence_hash()),
BlockState::Registered(state, _) => Ok(state.sequence_hash()),
_ => Err(BlockError::InvalidState(
"Block is not complete".to_string(),
)),
......@@ -250,14 +252,14 @@ pub(crate) trait PrivateBlockExt {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError>;
) -> Result<Option<PublishHandle>, registry::BlockRegistationError>;
}
impl<S: Storage, M: BlockMetadata> PrivateBlockExt for Block<S, M> {
fn register(
&mut self,
registry: &mut registry::BlockRegistry,
) -> Result<PublishHandle, registry::BlockRegistationError> {
) -> Result<Option<PublishHandle>, registry::BlockRegistationError> {
registry.register_block(&mut self.state)
}
}
......
......@@ -13,9 +13,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Cache Block Registration
//!
//! - This module is responsible for maintaining a registry of all blocks currently within a pool.
//! This consists of two components: A global registry of all blocks, and a per-pool registry of blocks.
//! - The global registry is a mapping of sequences hashes to registration handles. If two blocks in different pools
//! have the same sequence hash, then they will share the same registration handle. The global registry is shared across all pools.
//! - The per-pool registry is a mapping of sequence hashes to block handles. This is used to track which blocks are
//! currently within a specific pool. The block handle is unique across pools, and is used to track the block's lifetime.
//! - When a block is in the registered state, it has a unique block handle and a possibly shared registration handle.
//!
//! ## Workflow
//!
//! 1. When a block is registered into a pool, we create a unique block handle.
//! 2. We then check the global registry to see if the block already exists in any other pool.
//! 3. If it does, we use the existing registration handle. Otherwise, we create a new one.
//! 4. When the block handle is dropped, it means that the block is no longer in the pool.
//! 5. When the registration handle is dropped, it means that the block is no longer in any pool.
use std::{
collections::HashMap,
sync::{Arc, Weak},
sync::{Arc, Mutex, Weak},
};
use super::super::events::{EventManager, EventReleaseManager, PublishHandle};
......@@ -24,6 +42,9 @@ use super::state::BlockState;
use crate::tokens::{BlockHash, SequenceHash, TokenBlock};
use derive_getters::Getters;
use tokio::{runtime::Handle, sync::mpsc};
pub type GlobalRegistry = Arc<Mutex<HashMap<SequenceHash, Weak<RegistrationHandle>>>>;
#[derive(Debug, thiserror::Error)]
pub enum BlockRegistationError {
......@@ -34,27 +55,88 @@ pub enum BlockRegistationError {
InvalidState(String),
}
/// Error returned when an attempt is made to unregister a block that is still active.
#[derive(Debug, thiserror::Error)]
#[error("Failed to unregister block: {0}")]
pub struct UnregisterFailure(SequenceHash);
/// A block entry is a handle to a block that is registered in the pool.
/// On drop, we need to notify the pool that the block has been unregistered.
/// This is different than the registration handle, which is only dropped when the block is no longer in ANY pool.
#[derive(Debug)]
pub struct BlockHandle {
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}
impl BlockHandle {
pub fn new(
sequence_hash: SequenceHash,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
) -> Self {
Self {
sequence_hash,
unregister_tx,
}
}
}
impl Drop for BlockHandle {
fn drop(&mut self) {
let _ = self.unregister_tx.send(self.sequence_hash);
}
}
#[derive()]
pub struct BlockRegistry {
blocks: HashMap<SequenceHash, Weak<RegistrationHandle>>,
blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>>,
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
unregister_tx: mpsc::UnboundedSender<SequenceHash>,
}
impl BlockRegistry {
pub fn new(event_manager: Arc<dyn EventManager>) -> Self {
pub fn new(
event_manager: Arc<dyn EventManager>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
let (unregister_tx, mut unregister_rx) = mpsc::unbounded_channel();
let blocks: Arc<Mutex<HashMap<SequenceHash, Weak<BlockHandle>>>> =
Arc::new(Mutex::new(HashMap::new()));
let blocks_clone = blocks.clone();
let global_registry_clone = global_registry.clone();
async_runtime.spawn(async move {
let blocks = blocks_clone;
let global_registry = global_registry_clone;
while let Some(sequence_hash) = unregister_rx.recv().await {
{
let mut blocks = blocks.lock().unwrap();
if let Some(handle) = blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
blocks.remove(&sequence_hash);
}
}
}
let mut global_registry = global_registry.lock().unwrap();
if let Some(entry) = global_registry.get(&sequence_hash) {
if entry.upgrade().is_none() {
global_registry.remove(&sequence_hash);
}
}
}
});
Self {
blocks: HashMap::new(),
blocks,
event_manager,
global_registry,
unregister_tx,
}
}
pub fn is_registered(&self, sequence_hash: SequenceHash) -> bool {
if let Some(handle) = self.blocks.get(&sequence_hash) {
let blocks = self.blocks.lock().unwrap();
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return true;
}
......@@ -65,7 +147,7 @@ impl BlockRegistry {
pub fn register_block(
&mut self,
block_state: &mut BlockState,
) -> Result<PublishHandle, BlockRegistationError> {
) -> Result<Option<PublishHandle>, BlockRegistationError> {
match block_state {
BlockState::Reset => Err(BlockRegistationError::InvalidState(
"Block is in Reset state".to_string(),
......@@ -76,47 +158,60 @@ impl BlockRegistry {
BlockState::Complete(state) => {
let sequence_hash = state.token_block().sequence_hash();
if let Some(handle) = self.blocks.get(&sequence_hash) {
let mut blocks = self.blocks.lock().unwrap();
// If an identical block already exists in this pool, return an error.
if let Some(handle) = blocks.get(&sequence_hash) {
if let Some(_handle) = handle.upgrade() {
return Err(BlockRegistationError::BlockAlreadyRegistered(sequence_hash));
}
}
// Create the [RegistrationHandle] and [PublishHandle]
let publish_handle =
Self::create_publish_handle(state.token_block(), self.event_manager.clone());
let reg_handle = publish_handle.remove_handle();
let mut publish_handle = None;
let block_handle =
Arc::new(BlockHandle::new(sequence_hash, self.unregister_tx.clone()));
// Insert the [RegistrationHandle] into the registry
self.blocks
.insert(sequence_hash, Arc::downgrade(&reg_handle));
let reg_handle = 'reg_block: {
// Now, check the global registry.
let mut global_registry = self.global_registry.lock().unwrap();
// If an identical block exists in other pool, use the same registration handle.
if let Some(handle) = global_registry.get(&sequence_hash) {
if let Some(handle) = handle.upgrade() {
break 'reg_block handle;
}
}
// Otherwise, create a new registration handle.
publish_handle = Some(Self::create_publish_handle(
state.token_block(),
self.event_manager.clone(),
));
let reg_handle = publish_handle.as_ref().unwrap().remove_handle();
// Insert the registration handle into the global registry.
global_registry.insert(sequence_hash, Arc::downgrade(&reg_handle));
reg_handle
};
blocks.insert(sequence_hash, Arc::downgrade(&block_handle));
// Update the [BlockState] to [BlockState::Registered]
let _ = std::mem::replace(block_state, BlockState::Registered(reg_handle));
let _ = std::mem::replace(
block_state,
BlockState::Registered(reg_handle, block_handle),
);
Ok(publish_handle)
}
BlockState::Registered(registered) => Err(
BlockState::Registered(registered, _) => Err(
BlockRegistationError::BlockAlreadyRegistered(registered.sequence_hash()),
),
}
}
pub fn unregister_block(
&mut self,
sequence_hash: SequenceHash,
) -> Result<(), UnregisterFailure> {
if let Some(handle) = self.blocks.get(&sequence_hash) {
if handle.upgrade().is_none() {
self.blocks.remove(&sequence_hash);
return Ok(());
} else {
return Err(UnregisterFailure(sequence_hash));
}
}
Ok(())
}
fn create_publish_handle(
token_block: &TokenBlock,
event_manager: Arc<dyn EventManager>,
......
......@@ -17,7 +17,7 @@ use std::sync::Arc;
use derive_getters::Getters;
use super::registry::RegistrationHandle;
use super::registry::{BlockHandle, RegistrationHandle};
use super::Result;
use crate::tokens::{PartialTokenBlock, SaltHash, Token, TokenBlock, Tokens};
......@@ -30,7 +30,7 @@ pub enum BlockState {
Reset,
Partial(PartialState),
Complete(CompleteState),
Registered(Arc<RegistrationHandle>),
Registered(Arc<RegistrationHandle>, Arc<BlockHandle>),
}
impl BlockState {
......@@ -109,7 +109,7 @@ impl BlockState {
BlockState::Reset => Some(0),
BlockState::Partial(state) => Some(state.block.len()),
BlockState::Complete(state) => Some(state.token_block.tokens().len()),
BlockState::Registered(_) => None,
BlockState::Registered(_, _) => None,
}
}
......@@ -127,14 +127,14 @@ impl BlockState {
BlockState::Reset => true,
BlockState::Partial(state) => state.block.is_empty(),
BlockState::Complete(_) => false, // Always full
BlockState::Registered(_) => false, // Always full
BlockState::Registered(_, _) => false, // Always full
}
}
/// Returns a reference to the underlying TokenBlock if the state is Complete or Registered.
pub fn tokens(&self) -> Option<&Tokens> {
match self {
BlockState::Reset | BlockState::Registered(_) => None,
BlockState::Reset | BlockState::Registered(_, _) => None,
BlockState::Partial(state) => Some(state.block.tokens()),
BlockState::Complete(state) => Some(state.token_block.tokens()),
}
......@@ -147,12 +147,12 @@ impl BlockState {
/// Returns true if the block is in the complete or registered state
pub fn is_complete(&self) -> bool {
matches!(self, BlockState::Complete(_) | BlockState::Registered(_))
matches!(self, BlockState::Complete(_) | BlockState::Registered(_, _))
}
/// Returns true if the block is in the registered state
pub fn is_registered(&self) -> bool {
matches!(self, BlockState::Registered(_state))
matches!(self, BlockState::Registered(_state, _))
}
}
......
......@@ -334,7 +334,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
priority: u64,
) -> core::result::Result<(), BlockPoolError> {
match block.state() {
BlockState::Registered(_) => {}
BlockState::Registered(_, _) => {}
_ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
......@@ -397,7 +397,7 @@ impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
) -> BlockResult<DeviceStorage, Metadata> {
for block in &blocks {
match block.state() {
BlockState::Registered(_) => {}
BlockState::Registered(_, _) => {}
_ => {
return Err(BlockPoolError::BlockError(BlockError::InvalidState(
"Block is not registered.".to_string(),
......@@ -857,7 +857,7 @@ mod tests {
// Check that the block is registered.
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
BlockState::Registered(_, _)
));
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
......@@ -940,7 +940,7 @@ mod tests {
);
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
BlockState::Registered(_, _)
));
check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
......
......@@ -118,7 +118,7 @@ fn transfer_metadata<Source: Storage, Target: Storage, Metadata: BlockMetadata>(
target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
// Only registered blocks can be transferred. There are upstream checks for this, so this shouldn't ever fail.
if let BlockState::Registered(reg_handle) = source.state() {
if let BlockState::Registered(reg_handle, _) = source.state() {
// Bring the block back to the 'Reset' state.
target.reset();
// Transfer metadata.
......
......@@ -70,6 +70,7 @@ pub use super::block::{ImmutableBlock, MutableBlock};
use super::block::{
nixl::short_type_name, registry::BlockRegistry, Block, BlockError, BlockMetadata,
GlobalRegistry,
};
use super::events::{EventManager, NullEventManager};
use super::storage::Storage;
......@@ -80,6 +81,7 @@ use std::{
collections::{BTreeSet, HashMap, VecDeque},
sync::{Arc, Weak},
};
use tokio::runtime::Handle;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::Result;
......@@ -116,15 +118,27 @@ pub struct BlockPoolArgs<S: Storage, M: BlockMetadata> {
#[builder(default)]
blocks: Vec<Block<S, M>>,
#[builder(default)]
global_registry: GlobalRegistry,
#[builder(default = "Handle::current()")]
async_runtime: Handle,
}
impl<S: Storage, M: BlockMetadata> BlockPoolArgsBuilder<S, M> {
pub fn build(self) -> anyhow::Result<BlockPool<S, M>> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve();
let (event_manager, cancel_token, blocks, global_registry, async_runtime) = args.dissolve();
tracing::info!("building block pool");
let pool = BlockPool::new(event_manager, cancel_token, blocks);
let pool = BlockPool::new(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
Ok(pool)
}
......@@ -200,9 +214,16 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
let (pool, progress_engine) =
Self::with_progress_engine(event_manager, cancel_token, blocks);
let (pool, progress_engine) = Self::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
// pool.runtime.handle().spawn(async move {
// let mut progress_engine = progress_engine;
......@@ -239,12 +260,21 @@ impl<S: Storage, M: BlockMetadata> BlockPool<S, M> {
event_manager: Arc<dyn EventManager>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> (Self, ProgressEngine<S, M>) {
let (priority_tx, priority_rx) = tokio::sync::mpsc::unbounded_channel();
let (ctrl_tx, ctrl_rx) = tokio::sync::mpsc::unbounded_channel();
let progress_engine =
ProgressEngine::<S, M>::new(event_manager, priority_rx, ctrl_rx, cancel_token, blocks);
let progress_engine = ProgressEngine::<S, M>::new(
event_manager,
priority_rx,
ctrl_rx,
cancel_token,
blocks,
global_registry,
async_runtime,
);
(
Self {
......@@ -468,9 +498,15 @@ mod tests {
self,
) -> anyhow::Result<(BlockPool<S, M>, ProgressEngine<S, M>)> {
let args = self.build_internal()?;
let (event_manager, cancel_token, blocks) = args.dissolve();
let (pool, progress_engine) =
BlockPool::with_progress_engine(event_manager, cancel_token, blocks);
let (event_manager, cancel_token, blocks, global_registry, async_runtime) =
args.dissolve();
let (pool, progress_engine) = BlockPool::with_progress_engine(
event_manager,
cancel_token,
blocks,
global_registry,
async_runtime,
);
Ok((pool, progress_engine))
}
......@@ -560,8 +596,14 @@ mod tests {
.into_blocks()
.unwrap();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create the BlockPool and add the blocks
let pool = BlockPool::builder().blocks(blocks).build().unwrap();
let pool = BlockPool::builder()
.blocks(blocks)
.async_runtime(async_runtime.handle().clone())
.build()
.unwrap();
// All blocks should be in the Reset/Empty state
// No blocks should match the expected sequence hash
......
......@@ -138,7 +138,7 @@ impl<S: Storage, M: BlockMetadata> InactiveBlockPool<S, M> {
block.reset();
self.uninitialized_set.push_back(block);
}
BlockState::Registered(state) => {
BlockState::Registered(state, _) => {
let sequence_hash = state.sequence_hash();
self.insert_with_sequence_hash(block, sequence_hash);
}
......@@ -603,6 +603,7 @@ pub(crate) mod tests {
pub fn create_blocks(
tokens: Tokens,
block_size: usize,
async_runtime: Handle,
) -> Vec<Block<NullDeviceStorage, TestMetadata>> {
let (token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
......@@ -615,7 +616,8 @@ pub(crate) mod tests {
let mut blocks = create_block_collection(num_blocks).into_blocks().unwrap();
let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager);
let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
// Iterate through the generated TokenBlocks and the template Blocks,
// setting the state and registering each one.
......@@ -645,6 +647,7 @@ pub(crate) mod tests {
tokens: Tokens,
block_size: usize,
pool: &mut InactiveBlockPool<NullDeviceStorage, TestMetadata>,
async_runtime: Handle,
) -> (Vec<Block<NullDeviceStorage, TestMetadata>>, usize) {
let (mut token_blocks, _partial_token_block) =
tokens.into_sequence(block_size, None).into_parts();
......@@ -657,7 +660,8 @@ pub(crate) mod tests {
let matched_block_count = matched_blocks.len();
let event_manager = NullEventManager::new();
let mut registry = BlockRegistry::new(event_manager);
let mut registry =
BlockRegistry::new(event_manager, GlobalRegistry::default(), async_runtime);
// all matched blocks should be in the complete or registered state
for block in &mut matched_blocks {
......@@ -697,6 +701,8 @@ pub(crate) mod tests {
fn test_block_pool_lifecycle() {
dynamo_runtime::logging::init();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
const PAGE_SIZE: usize = 2;
let mut pool = create_block_pool(10);
......@@ -715,7 +721,12 @@ pub(crate) mod tests {
let tokens = create_token_sequence(&[1, 2, 3, 4]);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool);
let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(),
PAGE_SIZE,
&mut pool,
async_runtime.handle().clone(),
);
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 0);
assert_eq!(pool.available_blocks(), 8);
......@@ -725,7 +736,12 @@ pub(crate) mod tests {
assert_eq!(pool.total_blocks(), 10);
assert_eq!(pool.available_blocks(), 10);
let (blocks, matched_block_count) = acquire_blocks(tokens.clone(), PAGE_SIZE, &mut pool);
let (blocks, matched_block_count) = acquire_blocks(
tokens.clone(),
PAGE_SIZE,
&mut pool,
async_runtime.handle().clone(),
);
assert_eq!(blocks.len(), 2);
assert_eq!(matched_block_count, 2);
assert_eq!(pool.available_blocks(), 8);
......@@ -745,9 +761,11 @@ pub(crate) mod tests {
fn test_basic_sequence_matching() {
let mut pool = InactiveBlockPool::new();
let async_runtime = tokio::runtime::Runtime::new().unwrap();
// Create a sequence of 4 tokens split into blocks of 2
let sequence = create_token_sequence(&[1, 2, 3, 4]);
let blocks = create_blocks(sequence, 2);
let blocks = create_blocks(sequence, 2, async_runtime.handle().clone());
assert_eq!(blocks.len(), 2);
// Match the blocks in sequence
......
......@@ -24,11 +24,13 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
fn new(
event_manager: Arc<dyn EventManager>,
return_tx: tokio::sync::mpsc::UnboundedSender<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
Self {
active: ActiveBlockPool::new(),
inactive: InactiveBlockPool::new(),
registry: BlockRegistry::new(event_manager.clone()),
registry: BlockRegistry::new(event_manager.clone(), global_registry, async_runtime),
return_tx,
event_manager,
}
......@@ -88,7 +90,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
return_rx: &mut tokio::sync::mpsc::UnboundedReceiver<Block<S, M>>,
) -> Block<S, M> {
while let Some(block) = return_rx.recv().await {
if matches!(block.state(), BlockState::Registered(handle) if handle.sequence_hash() == sequence_hash)
if matches!(block.state(), BlockState::Registered(handle, _) if handle.sequence_hash() == sequence_hash)
{
return block;
}
......@@ -151,7 +153,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{
assert!(matches!(raw_block.state(), BlockState::Registered(_)));
assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
MutableBlock::new(raw_block, self.return_tx.clone())
} else {
// Attempt to register the block
......@@ -161,7 +163,10 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
match result {
Ok(handle) => {
// Only create our publish handle if this block is new, and not transfered.
if let Some(handle) = handle {
publish_handles.take_handle(handle);
}
block
}
Err(BlockRegistationError::BlockAlreadyRegistered(_)) => {
......@@ -222,7 +227,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
};
// this assert allows us to skip the error checking on the active pool registration step
assert!(matches!(raw_block.state(), BlockState::Registered(_)));
assert!(matches!(raw_block.state(), BlockState::Registered(_, _)));
let mutable = MutableBlock::new(raw_block, self.return_tx.clone());
......@@ -255,9 +260,12 @@ impl<S: Storage, M: BlockMetadata> ProgressEngine<S, M> {
ctrl_rx: tokio::sync::mpsc::UnboundedReceiver<ControlRequest<S, M>>,
cancel_token: CancellationToken,
blocks: Vec<Block<S, M>>,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Self {
let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel();
let mut state = State::<S, M>::new(event_manager, return_tx);
let mut state =
State::<S, M>::new(event_manager, return_tx, global_registry, async_runtime);
tracing::debug!(count = blocks.len(), "adding blocks to inactive pool");
state.inactive.add_blocks(blocks);
......
......@@ -17,7 +17,7 @@ use super::*;
use super::offload::OffloadManager;
use super::{
block::{Block, ImmutableBlock},
block::{Block, GlobalRegistry, ImmutableBlock},
config::NixlOptions,
};
use cudarc::driver::CudaStream;
......@@ -76,6 +76,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
// Create a map of NIXL backends
let mut nixl_backends: HashMap<String, Arc<nixl_sys::Backend>> = HashMap::new();
let global_registry = GlobalRegistry::default();
// Create a NIXL agent if NIXL is enabled and instantiate requested backends
// TODO: Build a map of NIXL backends to block pools/sets
let nixl_agent = Arc::new(match config.runtime.nixl {
......@@ -123,6 +125,14 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
let mut next_block_set_idx = 0;
let mut local_block_set = block::nixl::NixlBlockSet::new(worker_id);
let async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let (disk_pool, disk_blocks) = if let Some(config) = config.disk_layout {
if nixl_agent.is_none() {
tracing::warn!("NIXL is disabled; will not allocate disk blocks.");
......@@ -138,6 +148,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?;
(Some(Arc::new(pool)), Some(blocks))
}
......@@ -158,6 +170,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
......@@ -177,6 +191,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
next_block_set_idx,
cancellation_token.clone(),
worker_id,
global_registry.clone(),
async_rt_handle.clone(),
)?;
(Some(Arc::new(pool)), Some(blocks))
} else {
......@@ -190,20 +206,12 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
}
let offload_async_rt_handle = match config.runtime.async_runtime {
Some(rt) => rt.handle().clone(),
None => match Handle::try_current() {
Ok(handle) => handle,
Err(e) => anyhow::bail!(e),
},
};
let offload_manager = OffloadManager::new(
disk_pool.clone(),
host_pool.clone(),
device_pool.clone(),
nixl_agent.clone(),
offload_async_rt_handle,
async_rt_handle,
)?;
let state = Arc::new(Self {
......@@ -484,10 +492,14 @@ fn create_block_pool<S: Storage + NixlRegisterableStorage, M: BlockMetadata>(
block_set_idx: usize,
cancellation_token: CancellationToken,
worker_id: WorkerID,
global_registry: GlobalRegistry,
async_runtime: Handle,
) -> Result<(BlockPool<S, M>, Vec<Block<S, M>>)> {
let blocks = block::layout_to_blocks::<_, M>(layout, block_set_idx, worker_id)?;
let pool = BlockPool::<S, M>::builder()
.cancel_token(cancellation_token)
.global_registry(global_registry)
.async_runtime(async_runtime)
.build()?;
Ok((pool, blocks))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment