Unverified Commit b813befa authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KV Cache Manager block offloading (#1030)

parent 29813508
...@@ -25,6 +25,7 @@ mod state; ...@@ -25,6 +25,7 @@ mod state;
pub mod block; pub mod block;
pub mod events; pub mod events;
pub mod layout; pub mod layout;
pub mod offload;
pub mod pool; pub mod pool;
pub mod storage; pub mod storage;
...@@ -61,6 +62,7 @@ pub type WorkerID = u64; ...@@ -61,6 +62,7 @@ pub type WorkerID = u64;
pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>; pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>;
/// Represents the different cache levels for KV blocks /// Represents the different cache levels for KV blocks
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheLevel { pub enum CacheLevel {
/// Represents KV blocks in GPU memory /// Represents KV blocks in GPU memory
G1, G1,
......
...@@ -24,7 +24,7 @@ use nixl_sys::NixlDescriptor; ...@@ -24,7 +24,7 @@ use nixl_sys::NixlDescriptor;
pub use state::{BlockState, BlockStateInvalid}; pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{ use crate::block_manager::{
state::{KvBlockManagerState as BlockManager, TransferContext}, state::KvBlockManagerState as BlockManager,
storage::{Local, Remote, Storage}, storage::{Local, Remote, Storage},
}; };
use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens}; use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens};
...@@ -100,10 +100,6 @@ pub trait ReadableBlock: BlockDataProvider { ...@@ -100,10 +100,6 @@ pub trait ReadableBlock: BlockDataProvider {
fn storage_type_id(&self) -> std::any::TypeId { fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>() std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>()
} }
fn transfer_context(&self) -> &TransferContext {
unimplemented!()
}
} }
pub trait ReadableBlocks {} pub trait ReadableBlocks {}
...@@ -683,10 +679,27 @@ pub struct ImmutableBlock<S: Storage, M: BlockMetadata> { ...@@ -683,10 +679,27 @@ pub struct ImmutableBlock<S: Storage, M: BlockMetadata> {
block: Arc<MutableBlock<S, M>>, block: Arc<MutableBlock<S, M>>,
} }
impl<S: Storage, M: BlockMetadata> Clone for ImmutableBlock<S, M> {
fn clone(&self) -> Self {
Self {
block: self.block.clone(),
}
}
}
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self { pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self {
Self { block } Self { block }
} }
/// Returns the block manager associated with this block, if there is one.
///
/// `None` means the block is not attached to a manager.
pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
// `manager` is not a field of `ImmutableBlock` itself; it is reached on the
// underlying block through deref.
self.manager.as_ref()
}
/// Borrows the shared `Arc<MutableBlock>` that this immutable handle wraps.
pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
&self.block
}
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> { impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> {
...@@ -743,8 +756,17 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> ...@@ -743,8 +756,17 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
} }
} }
pub mod nixl { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
/// Asks the owning block manager to enqueue this block for offloading with the given `priority`.
///
/// If the block has no manager, nothing is enqueued and `Ok(())` is returned.
/// Errors from the manager's `enqueue_offload_block` are propagated.
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
// TODO: Is it ok to silently fail if the block is not managed?
if let Some(manager) = self.manager() {
manager.enqueue_offload_block(self, priority).await?;
}
Ok(())
}
}
pub mod nixl {
use super::*; use super::*;
use super::view::{BlockKind, Kind, LayerKind}; use super::view::{BlockKind, Kind, LayerKind};
...@@ -1411,6 +1433,15 @@ pub mod nixl { ...@@ -1411,6 +1433,15 @@ pub mod nixl {
} }
} }
/// Test-only helpers for code outside this module that needs a `PrivateToken`.
#[cfg(test)]
pub mod test_utils {
use super::private::PrivateToken;
/// Returns a `PrivateToken` value for use in tests.
pub fn get_private_token() -> PrivateToken {
PrivateToken
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -140,6 +140,8 @@ pub struct RegistrationHandle { ...@@ -140,6 +140,8 @@ pub struct RegistrationHandle {
#[getter(skip)] #[getter(skip)]
release_manager: Arc<dyn EventReleaseManager>, release_manager: Arc<dyn EventReleaseManager>,
token_block: TokenBlock,
} }
impl RegistrationHandle { impl RegistrationHandle {
...@@ -152,6 +154,7 @@ impl RegistrationHandle { ...@@ -152,6 +154,7 @@ impl RegistrationHandle {
sequence_hash: token_block.sequence_hash(), sequence_hash: token_block.sequence_hash(),
parent_sequence_hash: token_block.parent_sequence_hash(), parent_sequence_hash: token_block.parent_sequence_hash(),
release_manager, release_manager,
token_block: token_block.clone(),
} }
} }
} }
......
...@@ -30,6 +30,7 @@ use cudarc::driver::CudaStream; ...@@ -30,6 +30,7 @@ use cudarc::driver::CudaStream;
use std::ops::Range; use std::ops::Range;
pub use crate::block_manager::state::TransferContext;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote}; pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait; pub use async_trait::async_trait;
...@@ -129,15 +130,24 @@ where ...@@ -129,15 +130,24 @@ where
} }
pub trait WriteTo<Target> { pub trait WriteTo<Target> {
fn write_to(&self, dst: &mut Target, notify: Option<String>) -> Result<(), TransferError>; fn write_to(
&self,
dst: &mut Target,
notify: Option<String>,
ctx: &TransferContext,
) -> Result<(), TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
where where
RB: WriteToStrategy<WB> + Local, RB: WriteToStrategy<WB> + Local,
{ {
fn write_to(&self, dst: &mut WB, notify: Option<String>) -> Result<(), TransferError> { fn write_to(
let ctx = self.transfer_context(); &self,
dst: &mut WB,
notify: Option<String>,
ctx: &TransferContext,
) -> Result<(), TransferError> {
match Self::write_to_strategy() { match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst), TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
TransferStrategy::CudaAsyncH2D TransferStrategy::CudaAsyncH2D
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, Notify};
use super::block::{
transfer::WriteTo, BlockError, BlockExt, BlockMetadata, BlockState, ImmutableBlock,
MutableBlock,
};
use super::pool::BlockPoolError;
use super::state::TransferContext;
use super::storage::{Cuda, Storage};
use super::{BlockPool, DeviceStorage, PinnedStorage};
use anyhow::Result;
use cudarc::driver::sys::CUevent_flags;
use std::any::Any;
use std::collections::BTreeSet;
mod pending;
mod request;
use pending::{PendingTransfer, TransferManager};
use request::{OffloadRequest, OffloadRequestKey, OnboardRequest};
/// Queue depth passed to the offload worker's `TransferManager`, i.e. the capacity of
/// the bounded channel of in-flight device-to-host transfers.
const MAX_OFFLOAD_STREAM_DEPTH: usize = 4;
/// The offload manager handles all block transfers between different cache levels.
///
/// Offloads (device to host) are queued in an ordered set and drained by a background
/// worker. Onboards (host to device) are sent to a second background worker over an
/// unbounded channel.
pub struct OffloadManager<Metadata: BlockMetadata> {
// Handles to the device and host pools.
// If either is `None`, both workers return immediately and no transfers happen.
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
/// Priority queue of pending offloads. The worker takes the first (smallest) element.
dtoh_offload_queue: Arc<Mutex<BTreeSet<OffloadRequest<DeviceStorage, Metadata>>>>,
/// Used to notify the offload worker that an item has been added to the priority queue
dtoh_offload_notify: Arc<Notify>,
/// An incrementing counter for offloaded blocks. Within the same priority, blocks with lower tick values are processed first.
tick: Arc<Mutex<u64>>,
/// Queue of pending onboarding requests.
htod_onboard_tx: mpsc::UnboundedSender<OnboardRequest<PinnedStorage, DeviceStorage, Metadata>>,
}
impl<Metadata: BlockMetadata> OffloadManager<Metadata> {
pub fn new(
device: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
host: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
) -> Result<Arc<Self>> {
let dtoh_offload_queue = Arc::new(Mutex::new(BTreeSet::new()));
let dtoh_offload_notify = Arc::new(Notify::new());
let (htod_onboard_tx, htod_onboard_rx) = mpsc::unbounded_channel();
let this = Arc::new(Self {
device,
host,
dtoh_offload_queue,
dtoh_offload_notify,
tick: Arc::new(Mutex::new(0)),
htod_onboard_tx,
});
let this_clone = this.clone();
// The offload and onboard workers must run in separate streams.
// Otherwise, we'd only be doing either an offload or onboard at a time, cutting our effective transfer bandwidth in half.
tokio::spawn(async move { this_clone.offload_worker().await });
let this_clone = this.clone();
tokio::spawn(async move { this_clone.onboard_worker(htod_onboard_rx).await });
Ok(this)
}
/// Makes `target` describe the same logical block as `source`.
///
/// The target is reset, receives a clone of the source's metadata, and has the source's
/// token block applied.
///
/// # Errors
/// Returns an `InvalidState` block error if `source` is not in the `Registered` state,
/// and propagates any error from `apply_token_block`.
async fn update_target_metadata<Source: Storage, Target: Storage>(
    source: &Arc<MutableBlock<Source, Metadata>>,
    target: &mut MutableBlock<Target, Metadata>,
) -> Result<()> {
    // Only registered blocks can be transferred. Upstream checks should make the
    // error branch unreachable in practice.
    let registration = match source.state() {
        BlockState::Registered(handle) => handle,
        _ => {
            return Err(BlockPoolError::BlockError(BlockError::InvalidState(
                "Block is not registered.".to_string(),
            ))
            .into());
        }
    };

    // Start from a clean 'Reset' state, then carry over metadata and tokens.
    target.reset();
    target.update_metadata(source.metadata().clone());
    target.apply_token_block(registration.token_block().clone())?;

    Ok(())
}
/// Background loop that drains the device-to-host offload queue.
///
/// Returns immediately with `Ok(())` if either pool is missing. Otherwise it loops forever:
/// it pops the first queued request, resolves the source block, allocates one host block,
/// enqueues the copy, and hands the pair to a `TransferManager`. When the queue is empty it
/// waits on `dtoh_offload_notify`. Any error propagated with `?` ends the worker.
async fn offload_worker(&self) -> Result<()> {
// Since cuda memcpys in streams are async, this gets a bit tricky.
// We can't just consume the queue normally, otherwise the stream would become very backlogged.
// From the point when a transfer is put into the stream until the transfer corresponding to the block is complete, we need to hold a strong reference to the block.
// If we don't do this, the block may be evicted and overwritten before the transfer is complete.
// To do this, we use a queue to track blocks currently being offloaded. Once the offload is complete (as indicated by a CudaEvent), the reference to the block is dropped.
if self.device.is_none() || self.host.is_none() {
return Ok(());
}
let cuda_ctx = Cuda::device_or_create(0)?;
let transfer_ctx = TransferContext::new(None, cuda_ctx.new_stream()?);
let device = self.device.as_ref().as_ref().unwrap();
let host = self.host.as_ref().as_ref().unwrap();
// We don't want to hold too many strong references to blocks in the device pool, since it would limit our effective KV Cache capacity.
// In this case, we limit it to just enough to ensure that a transfer is always occurring.
let dtoh_pending_offload_manager = TransferManager::new(MAX_OFFLOAD_STREAM_DEPTH);
loop {
// Try to check the offload queue. The lock guard is a temporary, so it is released at the end of this statement.
let request = self.dtoh_offload_queue.lock().await.pop_first();
// If there is a request, process it.
if let Some(request) = request {
// Try to upgrade the block to a strong reference.
let block = match request.block.upgrade() {
Some(block) => Some(block),
// If unable to upgrade, the block may have been moved to the inactive pool.
// Fall back to looking it up in the device pool by sequence hash.
None => device
.match_sequence_hashes(vec![request.sequence_hash].as_slice())
.await?
.pop()
.map(|block| block.mutable_block().clone()),
};
// If we've found the block, offload it to the host. If not, the request is dropped.
if let Some(block) = block {
// Allocate a block from the host pool.
// TODO: The most likely error here is that the host pool is full.
// It's probably not a good idea to keep consuming queue elements in the meantime.
// NOTE: on allocation failure the request is discarded, not re-queued.
let host_blocks = match host.allocate_blocks(1).await {
Ok(blocks) => blocks,
Err(_) => {
continue;
}
};
if let Some(mut host_block) = host_blocks.into_iter().next() {
// Enqueue the offload into the stream.
block.write_to(&mut host_block, None, &transfer_ctx)?;
// Record an event after the transfer has been enqueued; it is synchronized on later by the TransferManager.
// NOTE(review): per the CUDA docs, BLOCKING_SYNC appears to control how the host waits on the event
// (blocking instead of spinning), not how it is recorded. Confirm the intent.
let event = transfer_ctx
.stream()
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
// Update block metadata. Registration with the host pool happens later, in the TransferManager.
OffloadManager::update_target_metadata(&block, &mut host_block).await?;
// Record the pending offload. This may block if too many offloads are already pending.
dtoh_pending_offload_manager
.handle_pending_transfer(PendingTransfer::new(
vec![block],
vec![host_block],
event,
None,
self.host.clone(),
))
.await?;
}
}
} else {
// If the queue is empty, wait to be notified.
self.dtoh_offload_notify.notified().await;
}
}
}
/// Background loop that serves host-to-device onboarding requests.
///
/// Returns immediately with `Ok(())` if either pool is missing. Otherwise, for every
/// request received on `htod_onboard_rx`, it allocates one device block per host block,
/// enqueues the copies on this worker's own stream, copies metadata to each target, records
/// a completion event, and hands everything to a `TransferManager`. The `TransferManager`
/// registers the device blocks and answers the request once the event has completed.
///
/// # Errors
/// A device allocation failure is reported to the requester and the loop continues.
/// Transfer, metadata, event, and hand-off errors are propagated and end the worker.
async fn onboard_worker(
    &self,
    mut htod_onboard_rx: mpsc::UnboundedReceiver<
        OnboardRequest<PinnedStorage, DeviceStorage, Metadata>,
    >,
) -> Result<()> {
    if self.device.is_none() || self.host.is_none() {
        return Ok(());
    }

    let cuda_ctx = Cuda::device_or_create(0)?;
    let transfer_ctx = TransferContext::new(None, cuda_ctx.new_stream()?);

    // For the onboarding manager, we can get away with a much bigger queue, since any
    // onboardings would get triggered by an upcoming prefill.
    let htod_pending_onboard_manager = TransferManager::new(16384);

    let device = self
        .device
        .as_ref()
        .as_ref()
        .expect("device pool presence was checked above");

    while let Some(request) = htod_onboard_rx.recv().await {
        let mut device_blocks = match device.allocate_blocks(request.blocks.len()).await {
            Ok(blocks) => blocks,
            Err(err) => {
                // Report the failure to the requester. If the requester has already gone
                // away (its receiver was dropped), there is nobody to tell, and that must
                // not shut down the worker for everyone else.
                let _ = request.response_tx.send(Err(err));
                continue;
            }
        };

        for (host_block, device_block) in request.blocks.iter().zip(device_blocks.iter_mut()) {
            host_block.write_to(device_block, None, &transfer_ctx)?;
            OffloadManager::update_target_metadata(host_block.mutable_block(), device_block)
                .await?;
        }

        // Record a single event after all copies of this request have been enqueued.
        let event = transfer_ctx
            .stream()
            .record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;

        // Keep strong references to the sources until the transfer has completed.
        let sources = request
            .blocks
            .iter()
            .map(|b| b.mutable_block().clone())
            .collect();

        htod_pending_onboard_manager
            .handle_pending_transfer(PendingTransfer::new(
                sources,
                device_blocks,
                event,
                Some(request.response_tx),
                self.device.clone(),
            ))
            .await?;
    }

    Ok(())
}
/// Queues a registered block for offloading with the given `priority`.
///
/// Any pool may call this, so the storage type is checked at runtime. Only blocks backed
/// by `DeviceStorage` are queued (G1 to G2); for any other storage type this is a no-op
/// that returns `Ok(())`.
///
/// # Errors
/// Returns an `InvalidState` block error if the block is not in the `Registered` state,
/// and propagates an error from reading the block's sequence hash.
pub async fn offload<S: Storage>(
    &self,
    block: &ImmutableBlock<S, Metadata>,
    priority: u64,
) -> core::result::Result<(), BlockPoolError> {
    if !matches!(block.state(), BlockState::Registered(_)) {
        return Err(BlockPoolError::BlockError(BlockError::InvalidState(
            "Block is not registered.".to_string(),
        )));
    }

    // For now, only consider offloads from G1 (device) to G2 (host).
    // TODO: What's the performance penalty of this runtime type-checking?
    let device_block = match (block as &dyn Any)
        .downcast_ref::<ImmutableBlock<DeviceStorage, Metadata>>()
    {
        Some(device_block) => device_block,
        None => return Ok(()),
    };

    // Take the next tick as this request's timestamp. The lock is released at the end
    // of this scope.
    let key = {
        let mut next_tick = self.tick.lock().await;
        let timestamp = *next_tick;
        *next_tick += 1;
        OffloadRequestKey {
            priority,
            timestamp,
        }
    };

    // Only a weak reference is queued, so a pending offload does not keep the block alive.
    let request = OffloadRequest {
        block: Arc::downgrade(device_block.mutable_block()),
        sequence_hash: device_block.sequence_hash()?,
        key,
    };

    self.dtoh_offload_queue.lock().await.insert(request);
    self.dtoh_offload_notify.notify_one();

    Ok(())
}
/// Requests that the given host blocks be copied to the device, and waits for the result.
///
/// The request is sent to the onboard worker together with a oneshot channel; the value
/// received on that channel is returned.
///
/// # Errors
/// Returns an `InvalidState` block error if any block is not `Registered`, and
/// `ProgressEngineShutdown` if the request cannot be sent or no response arrives.
/// Otherwise returns whatever result the worker sent back.
pub async fn onboard(
    &self,
    blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
    let all_registered = blocks
        .iter()
        .all(|block| matches!(block.state(), BlockState::Registered(_)));
    if !all_registered {
        return Err(BlockPoolError::BlockError(BlockError::InvalidState(
            "Block is not registered.".to_string(),
        )));
    }

    let (response_tx, response_rx) = oneshot::channel();
    let request = OnboardRequest::new(blocks, response_tx);
    if self.htod_onboard_tx.send(request).is_err() {
        return Err(BlockPoolError::ProgressEngineShutdown);
    }

    response_rx
        .await
        .unwrap_or_else(|_| Err(BlockPoolError::ProgressEngineShutdown))
}
}
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
use crate::block_manager::block::test_utils::get_private_token;
use crate::block_manager::{
block::{BasicMetadata, BlockDataExt, BlockDataProvider, Blocks},
layout::FullyContiguous,
pool::BlockPool,
storage::{
cuda::CudaAccessible, DeviceAllocator, DeviceStorage, PinnedAllocator, PinnedStorage,
},
DType, LayoutConfig,
};
use nixl_sys::NixlDescriptor;
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind, cudaMemset};
/// Tokens per block in these tests; also used as the layout `page_size`.
const BLOCK_SIZE: usize = 4;
/// Shared, optional device pool, in the form `OffloadManager::new` accepts.
type DevicePool = Arc<Option<BlockPool<DeviceStorage, BasicMetadata>>>;
/// Shared, optional pinned-host pool, in the form `OffloadManager::new` accepts.
type HostPool = Arc<Option<BlockPool<PinnedStorage, BasicMetadata>>>;
/// Builds a device pool, an optional host pool, and an `OffloadManager` over both.
///
/// `device_blocks` is the number of device blocks. `host_blocks` is the number of host
/// blocks, or `None` to create no host pool at all. Both layouts share the same
/// configuration apart from the block count.
fn build_pools(
    device_blocks: usize,
    host_blocks: Option<usize>,
) -> Result<(Arc<OffloadManager<BasicMetadata>>, DevicePool, HostPool)> {
    let device_config = LayoutConfig {
        num_blocks: device_blocks,
        num_layers: 8,
        page_size: BLOCK_SIZE,
        inner_dim: 1024,
        alignment: 1,
        dtype: DType::FP16,
    };

    let device_layout =
        FullyContiguous::allocate(device_config.clone(), &DeviceAllocator::default())?;
    let device_block_list =
        Blocks::<_, BasicMetadata>::new(device_layout, 42, 0)?.into_blocks()?;
    let device_pool = Arc::new(Some(
        BlockPool::builder().blocks(device_block_list).build()?,
    ));

    let host_pool = match host_blocks {
        Some(count) => {
            let host_config = LayoutConfig {
                num_blocks: count,
                ..device_config
            };
            let host_layout =
                FullyContiguous::allocate(host_config, &PinnedAllocator::default())?;
            let host_block_list =
                Blocks::<_, BasicMetadata>::new(host_layout, 42, 0)?.into_blocks()?;
            Arc::new(Some(BlockPool::builder().blocks(host_block_list).build()?))
        }
        None => Arc::new(None),
    };

    let manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?;

    Ok((manager, device_pool, host_pool))
}
/// Allocates a single block from `pool`; the block is in the 'RESET' state.
///
/// Fails if the allocation fails or yields no block.
async fn get_block<S: Storage, Metadata: BlockMetadata>(
    pool: &BlockPool<S, Metadata>,
) -> Result<MutableBlock<S, Metadata>> {
    let mut allocated = pool.allocate_blocks(1).await?.into_iter();
    allocated
        .next()
        .ok_or_else(|| anyhow::anyhow!("Failed to allocate block"))
}
/// Create a block in the 'PARTIAL' state.
///
/// Allocates a block, initializes its sequence with salt 42, and adds the single `token`,
/// leaving the block incomplete.
async fn partial_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>,
token: u32,
) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?;
block.init_sequence(42)?;
block.add_token(token)?;
Ok(block)
}
/// Create a block in the 'COMPLETED' state.
///
/// Allocates a block, initializes its sequence with salt 42, adds exactly `BLOCK_SIZE`
/// tokens, and commits it. The block is not registered with any pool.
async fn completed_block<S: Storage, Metadata: BlockMetadata>(
pool: &BlockPool<S, Metadata>,
tokens: [u32; BLOCK_SIZE],
) -> Result<MutableBlock<S, Metadata>> {
let mut block = get_block(pool).await?;
block.init_sequence(42)?;
for token in tokens {
block.add_token(token)?;
}
block.commit()?;
Ok(block)
}
/// Fills the whole block view of `block` with `value` using `cudaMemset`.
///
/// Used by the tests to give a block recognizable contents before a transfer.
fn populate_cuda_block<S: Storage + CudaAccessible + NixlDescriptor>(
block: &impl BlockDataProvider<StorageType = S>,
value: i32,
) -> Result<()> {
let block_data = block.block_data(get_private_token()).block_view()?;
let block_size = block_data.size();
// SAFETY (assumed, TODO confirm): the view's pointer is valid for `block_size` bytes and
// is CUDA-accessible, as required by the `CudaAccessible` bound on `S`.
unsafe {
cudaMemset(
block_data.as_ptr() as *mut std::ffi::c_void,
value,
block_size,
)
.result()?;
}
Ok(())
}
/// Compare the contents of a device block and a host block.
///
/// The device block is copied into a temporary host buffer with `cudaMemcpy`, and that
/// buffer is compared byte-for-byte against the memory of the host block.
///
/// # Panics
/// Panics if the two blocks differ in size or in contents.
async fn compare_block_contents(
    device_block: &impl BlockDataProvider<StorageType = DeviceStorage>,
    host_block: &impl BlockDataProvider<StorageType = PinnedStorage>,
) -> Result<()> {
    let host_data = host_block.block_data(get_private_token()).block_view()?;
    let device_data = device_block.block_data(get_private_token()).block_view()?;

    let size = host_data.size();
    assert_eq!(size, device_data.size());

    // Staging buffer that receives the device block's bytes.
    let mut device_copy = vec![0u8; size];

    // SAFETY: `device_copy` owns exactly `size` writable bytes, and the device view was
    // just asserted to span `size` bytes as well.
    unsafe {
        cudaMemcpy(
            device_copy.as_mut_ptr() as *mut std::ffi::c_void,
            device_data.as_ptr() as *const std::ffi::c_void,
            size,
            cudaMemcpyKind::cudaMemcpyDeviceToHost,
        )
        .result()?;
    }

    // The slice must be built from the HOST BLOCK's memory. Building it from the staging
    // buffer would compare the buffer with itself and always pass.
    // SAFETY: `host_data` is a view over pinned host memory of `size` bytes that stays
    // alive for the duration of this borrow.
    let host_slice =
        unsafe { std::slice::from_raw_parts(host_data.as_ptr() as *const u8, size) };

    assert_eq!(device_copy, host_slice);

    Ok(())
}
/// Offloading must be rejected with `InvalidState` for blocks that are not registered:
/// reset, partial, and completed-but-unregistered blocks are each checked.
#[tokio::test]
async fn test_offload_invalid_blocks() -> Result<()> {
let (offload_manager, device_pool, _) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
// Check blocks in the 'RESET' state.
let immutable_block = ImmutableBlock::new(Arc::new(get_block(device_pool).await?));
assert!(matches!(
offload_manager.offload(&immutable_block, 0).await,
Err(BlockPoolError::BlockError(BlockError::InvalidState(_)))
));
// Check blocks in the 'PARTIAL' state.
let immutable_block = ImmutableBlock::new(Arc::new(partial_block(device_pool, 0).await?));
assert!(matches!(
offload_manager.offload(&immutable_block, 0).await,
Err(BlockPoolError::BlockError(BlockError::InvalidState(_)))
));
// Check blocks in the 'COMPLETED' state.
let immutable_block = ImmutableBlock::new(Arc::new(
completed_block(device_pool, [0; BLOCK_SIZE]).await?,
));
assert!(matches!(
offload_manager.offload(&immutable_block, 0).await,
Err(BlockPoolError::BlockError(BlockError::InvalidState(_)))
));
Ok(())
}
/// A registered device block that is offloaded should show up in the host pool under the
/// same sequence hash and with the same contents.
#[tokio::test]
async fn test_offload_registered_blocks() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
// Create a block and register it with the device pool.
let block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
.register_blocks(vec![block])
.await?
.into_iter()
.next()
.ok_or(anyhow::anyhow!("Failed to register block"))?;
// Give the block recognizable contents before it is copied.
populate_cuda_block(&immutable_device_block, 42)?;
// Offloads should only go to G2 (for now)
offload_manager.offload(&immutable_device_block, 0).await?;
// Wait for it to be processed.
// TODO: This is a bit of a hack, and may lead to non-deterministic behavior.
// In theory, the offload + memcpy should take much less time than this.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Check that the block exists in the host pool
let host_blocks = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(host_blocks.len(), 1);
assert_eq!(
host_blocks[0].sequence_hash()?,
immutable_device_block.sequence_hash()?
);
compare_block_contents(&immutable_device_block, &host_blocks[0]).await?;
Ok(())
}
/// When every host block is already allocated, an offload is dropped without error;
/// once host blocks are free again, a new offload of the same block succeeds.
#[tokio::test]
async fn test_no_host_blocks_available() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
// Exhaust the host pool so the offload worker cannot allocate a target block.
let host_blocks = host_pool.allocate_blocks(4).await?;
assert_eq!(host_blocks.len(), 4);
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
.register_blocks(vec![device_block])
.await?
.into_iter()
.next()
.unwrap();
offload_manager.offload(&immutable_device_block, 0).await?;
// Wait for offload to be processed.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// The offload should fail gracefully due to a lack of host blocks
let matched_host_blocks = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(matched_host_blocks.len(), 0);
// Wait for blocks to be returned to the pool.
drop(host_blocks);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Try the offload again.
offload_manager.offload(&immutable_device_block, 0).await?;
// Wait for offload to be processed.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// This time, the offload should succeed.
let matched_host_blocks = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(matched_host_blocks.len(), 1);
Ok(())
}
/// Onboarding a registered host block should return a registered device block with the
/// same sequence hash and contents, and that block should be findable in the device pool.
#[tokio::test]
async fn test_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
// Allocate and fill a block on the host.
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool
.register_blocks(vec![host_block])
.await?
.into_iter()
.next()
.unwrap();
populate_cuda_block(&immutable_host_block, 42)?;
// Onboard the block.
let onboarded_blocks = offload_manager
.onboard(vec![immutable_host_block.clone()])
.await?;
assert_eq!(onboarded_blocks.len(), 1);
// Check that the sequence hash is the same.
assert_eq!(
onboarded_blocks[0].sequence_hash()?,
immutable_host_block.sequence_hash()?
);
// Check that the block is registered.
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;
// Wait for the new value to show up in the device pool.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let device_blocks = device_pool
.match_sequence_hashes(vec![onboarded_blocks[0].sequence_hash()?].as_slice())
.await?;
assert_eq!(device_blocks.len(), 1);
assert_eq!(
device_blocks[0].sequence_hash()?,
onboarded_blocks[0].sequence_hash()?
);
// Check that this is the same block.
compare_block_contents(&device_blocks[0], &immutable_host_block).await?;
Ok(())
}
/// Round trip: offload a device block to the host, evict it from the device pool, then
/// onboard it again and check that hash, state, and contents are preserved.
#[tokio::test]
async fn test_offload_onboard() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
.register_blocks(vec![device_block])
.await?
.into_iter()
.next()
.unwrap();
populate_cuda_block(&immutable_device_block, 42)?;
// Offload the block to the host.
offload_manager.offload(&immutable_device_block, 0).await?;
// Wait for the offload to be processed.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Check that the block exists in the host pool.
let immutable_host_block = host_pool
.match_sequence_hashes(vec![immutable_device_block.sequence_hash()?].as_slice())
.await?
.into_iter()
.next()
.unwrap();
compare_block_contents(&immutable_device_block, &immutable_host_block).await?;
// Remove the device block from the pool by dropping it and allocating more blocks.
drop(immutable_device_block);
// Wait for the block to be returned to the pool.
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Allocating every device block forces the previously registered block to be reused.
let device_blocks = device_pool.allocate_blocks(4).await?;
assert_eq!(device_blocks.len(), 4);
drop(device_blocks);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
// Check that the block is not in the device pool.
let device_blocks = device_pool
.match_sequence_hashes(vec![immutable_host_block.sequence_hash()?].as_slice())
.await?;
assert_eq!(device_blocks.len(), 0);
// Onboard the block back to the device pool.
let onboarded_blocks = offload_manager
.onboard(vec![immutable_host_block.clone()])
.await?;
assert_eq!(onboarded_blocks.len(), 1);
assert_eq!(
onboarded_blocks[0].sequence_hash()?,
immutable_host_block.sequence_hash()?
);
assert!(matches!(
onboarded_blocks[0].state(),
BlockState::Registered(_)
));
compare_block_contents(&onboarded_blocks[0], &immutable_host_block).await?;
Ok(())
}
/// With every device block already allocated, onboarding must return
/// `NotEnoughBlocksAvailable` to the caller.
#[tokio::test]
async fn test_onboard_err_handling() -> Result<()> {
let (offload_manager, device_pool, host_pool) = build_pools(4, Some(4))?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let host_pool = host_pool.as_ref().as_ref().unwrap();
let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
let immutable_host_block = host_pool
.register_blocks(vec![host_block])
.await?
.into_iter()
.next()
.unwrap();
// Hold all device blocks so the onboard worker cannot allocate a target.
let device_blocks = device_pool.allocate_blocks(4).await?;
assert_eq!(device_blocks.len(), 4);
let res = offload_manager
.onboard(vec![immutable_host_block.clone()])
.await;
assert!(matches!(
res.err().unwrap(),
BlockPoolError::NotEnoughBlocksAvailable(_, _)
));
Ok(())
}
/// With no host pool configured at all, requesting an offload must still return `Ok`.
#[tokio::test]
async fn test_offload_onboard_no_host_blocks() -> Result<()> {
let (offload_manager, device_pool, _) = build_pools(4, None)?;
let device_pool = device_pool.as_ref().as_ref().unwrap();
let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
let immutable_device_block = device_pool
.register_blocks(vec![device_block])
.await?
.into_iter()
.next()
.unwrap();
offload_manager.offload(&immutable_device_block, 0).await?;
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use std::thread::spawn;
use tokio::sync::mpsc;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
use crate::block_manager::BlockPool;
use anyhow::Result;
use cudarc::driver::CudaEvent;
/// Value delivered through a transfer's completion channel: the blocks registered in the
/// target pool, or a `BlockPoolError`.
type OnboardResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// A single in-flight transfer: its source and target blocks, the CUDA event that marks
/// its completion, and where to deliver the result.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The blocks being copied from. Never read; held only so the sources stay alive
/// for as long as this value exists.
_sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
/// The blocks being copied to.
targets: Vec<MutableBlock<Target, Metadata>>,
/// The Cuda event that indicates the completion of the transfer.
event: CudaEvent,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
/// The target pool that will receive the registered blocks.
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
PendingTransfer<Source, Target, Metadata>
{
pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>,
event: CudaEvent,
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
) -> Self {
Self {
_sources: sources,
targets,
event,
completion_indicator,
target_pool,
}
}
}
/// Hands pending transfers to a dedicated thread that waits for each transfer's event and
/// then registers the target blocks.
pub struct TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// Sending side of the bounded queue consumed by the completion thread.
pending_transfer_q: mpsc::Sender<PendingTransfer<Source, Target, Metadata>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
TransferManager<Source, Target, Metadata>
{
/// Creates a manager whose queue holds at most `max_depth` pending transfers and spawns
/// the OS thread that completes them.
///
/// For each transfer the thread waits on the CUDA event and registers the target blocks
/// in the target pool (if there is one). If a completion channel was supplied, it then
/// sends the registered blocks through it. The thread ends when every sender has been
/// dropped, or when event synchronization or registration fails.
pub fn new(max_depth: usize) -> Self {
    let (tx, mut rx) = mpsc::channel::<PendingTransfer<Source, Target, Metadata>>(max_depth);

    spawn(move || {
        while let Some(pending_transfer) = rx.blocking_recv() {
            // Wait for the event.
            pending_transfer.event.synchronize()?;

            // The source blocks stay inside `pending_transfer` and are released at the
            // end of this iteration, i.e. only after the copy has finished.
            let PendingTransfer {
                targets,
                target_pool,
                completion_indicator,
                ..
            } = pending_transfer;

            if let Some(target_pool) = target_pool.as_ref() {
                // Register the blocks in the new pool only AFTER the transfers have been completed.
                // This way, we maintain the invariant that blocks that are registered in a pool
                // are always available in that pool.
                let blocks = target_pool.register_blocks_blocking(targets)?;

                if let Some(completion_indicator) = completion_indicator {
                    // The requester may have stopped waiting (receiver dropped). That only
                    // affects this one request and must not stop the thread, which would
                    // break every later transfer.
                    let _ = completion_indicator.send(Ok(blocks));
                }
            }
        }
        Ok::<(), anyhow::Error>(())
    });

    Self {
        pending_transfer_q: tx,
    }
}
/// Queues `pending_transfer` for the completion thread.
///
/// Waits while the bounded queue is full. Fails if the completion thread is no longer
/// receiving.
pub async fn handle_pending_transfer(
    &self,
    pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
    self.pending_transfer_q
        .send(pending_transfer)
        .await
        .map_err(Into::into)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Ordering;
use std::sync::Weak;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
/// Sort key for queued offload requests.
///
/// Keys sort by `priority` from high to low, so a higher-priority request
/// compares as *less* and comes first in an ascending ordered collection.
/// Requests with equal priority sort by ascending `timestamp` (assumed to grow
/// with enqueue order, so the earlier request comes first).
#[derive(Debug, PartialEq, Eq)]
pub struct OffloadRequestKey {
    /// Larger values are served first.
    pub priority: u64,
    /// Tie-breaker among equal priorities; smaller values are served first.
    pub timestamp: u64,
}

impl Ord for OffloadRequestKey {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped for `priority` to get descending order; a
        // derived `Ord` would rank the lowest priority first.
        other
            .priority
            .cmp(&self.priority)
            .then_with(|| self.timestamp.cmp(&other.timestamp))
    }
}

impl PartialOrd for OffloadRequestKey {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
///
/// Comparison and ordering of requests look only at `key`; `block` and
/// `sequence_hash` do not take part.
pub struct OffloadRequest<S: Storage, M: BlockMetadata> {
/// Ordering key (priority and timestamp) of this request.
pub key: OffloadRequestKey,
/// Weak handle to the block to offload; upgrading it can fail if the block
/// has been dropped in the meantime.
pub block: Weak<MutableBlock<S, M>>,
/// Sequence hash recorded for the block when the request was created
/// (presumably used to re-validate the block later - confirm against the consumer).
pub sequence_hash: u64,
}
/// Partial ordering is the total ordering from `Ord`, so it never yields `None`.
impl<S: Storage, M: BlockMetadata> PartialOrd for OffloadRequest<S, M> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Requests are ordered solely by their key; how priority and timestamp are
/// ranked is decided by `OffloadRequestKey`'s `Ord` implementation.
impl<S: Storage, M: BlockMetadata> Ord for OffloadRequest<S, M> {
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (&self.key, &other.key);
        Ord::cmp(lhs, rhs)
    }
}
/// Two requests are equal exactly when their keys (priority and timestamp) are
/// equal. The referenced block and the sequence hash are not compared, which
/// keeps equality consistent with the key-only `Ord` implementation.
impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.key, &other.key);
        lhs == rhs
    }
}
// Marker impl: equality compares only the key, whose type derives `Eq` over two
// `u64` fields, so the relation is a full equivalence.
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {}
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> {
pub blocks: Vec<ImmutableBlock<Source, M>>,
pub response_tx:
oneshot::Sender<std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
}
impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, Target, M> {
pub fn new(
blocks: Vec<ImmutableBlock<Source, M>>,
response_tx: oneshot::Sender<Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
) -> Self {
Self {
blocks,
response_tx,
}
}
}
...@@ -147,6 +147,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -147,6 +147,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
continue; continue;
} }
let mut offload = true;
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{ {
assert!(matches!(raw_block.state(), BlockState::Registered(_))); assert!(matches!(raw_block.state(), BlockState::Registered(_)));
...@@ -164,6 +166,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -164,6 +166,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
} }
Err(BlockRegistationError::BlockAlreadyRegistered(_)) => { Err(BlockRegistationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned // Block is already registered, wait for it to be returned
offload = false;
let raw_block = let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await; self.wait_for_returned_block(sequence_hash, return_rx).await;
MutableBlock::new(raw_block, self.return_tx.clone()) MutableBlock::new(raw_block, self.return_tx.clone())
...@@ -176,6 +179,11 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -176,6 +179,11 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self.active.register(mutable)?; let immutable = self.active.register(mutable)?;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if offload {
immutable.enqueue_offload(0).await.unwrap();
}
immutable_blocks.push(immutable); immutable_blocks.push(immutable);
} }
......
...@@ -15,8 +15,12 @@ ...@@ -15,8 +15,12 @@
use super::*; use super::*;
use super::{block::Block, config::NixlOptions}; use super::offload::OffloadManager;
use super::{
block::{Block, ImmutableBlock},
config::NixlOptions,
pool::BlockPoolError,
};
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use std::sync::Arc; use std::sync::Arc;
...@@ -47,11 +51,13 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> { ...@@ -47,11 +51,13 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent: Option<NixlAgent>, nixl_agent: Option<NixlAgent>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>, nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
host_pool: Option<BlockPool<PinnedStorage, Metadata>>, host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Option<BlockPool<DeviceStorage, Metadata>>, device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
local_block_set: NixlBlockSet, local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>, remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
offload_manager: Arc<OffloadManager<Metadata>>,
} }
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...@@ -114,10 +120,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -114,10 +120,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Some(pool), Some(blocks)) (Arc::new(Some(pool)), Some(blocks))
} else { } else {
tracing::debug!("No host layout provided; will not allocate host blocks."); tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None) (Arc::new(None), None)
}; };
// Create the device block pool if a device layout is provided // Create the device block pool if a device layout is provided
...@@ -132,10 +138,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -132,10 +138,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Some(pool), Some(blocks)) (Arc::new(Some(pool)), Some(blocks))
} else { } else {
tracing::debug!("No device layout provided; will not allocate device blocks."); tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None) (Arc::new(None), None)
}; };
// Finalize the local block set by adding NIXL metadata // Finalize the local block set by adding NIXL metadata
...@@ -144,6 +150,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -144,6 +150,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
} }
let offload_manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?;
let state = Arc::new(Self { let state = Arc::new(Self {
worker_id, worker_id,
cancellation_token, cancellation_token,
...@@ -153,6 +161,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -153,6 +161,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool, device_pool,
local_block_set, local_block_set,
remote_block_sets: RwLock::new(HashMap::new()), remote_block_sets: RwLock::new(HashMap::new()),
offload_manager,
}); });
if let Some(mut blocks) = host_blocks { if let Some(mut blocks) = host_blocks {
...@@ -163,6 +172,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -163,6 +172,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state state
.host_pool .host_pool
.as_ref() .as_ref()
.as_ref()
.unwrap() .unwrap()
.add_blocks_blocking(blocks)?; .add_blocks_blocking(blocks)?;
} }
...@@ -175,6 +185,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -175,6 +185,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state state
.device_pool .device_pool
.as_ref() .as_ref()
.as_ref()
.unwrap() .unwrap()
.add_blocks_blocking(blocks)?; .add_blocks_blocking(blocks)?;
} }
...@@ -334,16 +345,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -334,16 +345,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
} }
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> { pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref() self.host_pool.as_ref().as_ref()
} }
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> { pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref() self.device_pool.as_ref().as_ref()
} }
pub fn worker_id(&self) -> WorkerID { pub fn worker_id(&self) -> WorkerID {
self.worker_id self.worker_id
} }
/// Ask the offload manager to offload `block` with the given `priority`.
///
/// Any error returned by the offload manager is propagated to the caller.
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
    &self,
    block: &ImmutableBlock<S, Metadata>,
    priority: u64,
) -> Result<()> {
    let manager = &self.offload_manager;
    manager.offload(block, priority).await?;
    Ok(())
}
/// Onboard pinned-host `blocks` into device storage via the offload manager.
///
/// Returns the offload manager's result unchanged: the device-side blocks on
/// success, or a `BlockPoolError`.
pub async fn onboard_blocks(
    &self,
    blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
    let manager = &self.offload_manager;
    manager.onboard(blocks).await
}
} }
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> { impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> {
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![deny(missing_docs)] // TODO: Add docs.
#![allow(missing_docs)]
//! # Storage Management //! # Storage Management
//! //!
...@@ -121,7 +122,8 @@ pub trait Remote {} ...@@ -121,7 +122,8 @@ pub trait Remote {}
/// Marker trait for [`Storage`] types that can be accessed by the standard /// Marker trait for [`Storage`] types that can be accessed by the standard
/// mechanisms of the system, e.g. `memcpy`, `memset`, etc. /// mechanisms of the system, e.g. `memcpy`, `memset`, etc.
pub trait SystemAccessible: Storage {} pub trait SystemAccessible {}
pub trait CudaAccessible {}
/// Errors that can occur during storage operations /// Errors that can occur during storage operations
#[derive(Debug, Error)] #[derive(Debug, Error)]
...@@ -139,15 +141,15 @@ pub enum StorageError { ...@@ -139,15 +141,15 @@ pub enum StorageError {
#[error("Storage operation failed: {0}")] #[error("Storage operation failed: {0}")]
OperationFailed(String), OperationFailed(String),
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
#[error("Registration key already exists: {0}")] #[error("Registration key already exists: {0}")]
RegistrationKeyExists(String), RegistrationKeyExists(String),
#[error("Handle not found for key: {0}")] #[error("Handle not found for key: {0}")]
HandleNotFound(String), HandleNotFound(String),
#[error("CUDA error: {0}")]
CudaError(#[from] cudarc::driver::DriverError),
#[error("NIXL error: {0}")] #[error("NIXL error: {0}")]
NixlError(#[from] nixl_sys::NixlError), NixlError(#[from] nixl_sys::NixlError),
} }
......
...@@ -114,7 +114,7 @@ impl Cuda { ...@@ -114,7 +114,7 @@ impl Cuda {
/// If the context does not exist, it will return None. /// If the context does not exist, it will return None.
/// ///
/// This will not lazily instantiate a context for a device. Use /// This will not lazily instantiate a context for a device. Use
/// [Cuda::get_or_init_device] /// [Cuda::device_or_create]
pub fn device(device_id: usize) -> Option<Arc<CudaContext>> { pub fn device(device_id: usize) -> Option<Arc<CudaContext>> {
Cuda::instance() Cuda::instance()
.lock() .lock()
...@@ -127,7 +127,7 @@ impl Cuda { ...@@ -127,7 +127,7 @@ impl Cuda {
/// ///
/// This will lazily instantiate a context for a device. Use /// This will lazily instantiate a context for a device. Use
/// [CudaContextManager::device] to get an existing context. /// [CudaContextManager::device] to get an existing context.
pub fn get_or_init_device(device_id: usize) -> Result<Arc<CudaContext>, StorageError> { pub fn device_or_create(device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
Cuda::instance().lock().unwrap().get_context(device_id) Cuda::instance().lock().unwrap().get_context(device_id)
} }
...@@ -159,12 +159,12 @@ impl Cuda { ...@@ -159,12 +159,12 @@ impl Cuda {
} }
// Get a context if it exists, but don't create one // Get a context if it exists, but don't create one
fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> { pub fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> {
self.contexts.get(&device_id).cloned() self.contexts.get(&device_id).cloned()
} }
// Check if a context exists for a device // Check if a context exists for a device
fn has_context(&self, device_id: usize) -> bool { pub fn has_context(&self, device_id: usize) -> bool {
self.contexts.contains_key(&device_id) self.contexts.contains_key(&device_id)
} }
} }
...@@ -186,10 +186,10 @@ impl PinnedStorage { ...@@ -186,10 +186,10 @@ impl PinnedStorage {
/// Create a new pinned storage with the given size /// Create a new pinned storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> { pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
unsafe { unsafe {
ctx.bind_to_thread().map_err(StorageError::CudaError)?; ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED) let ptr = cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
.map_err(StorageError::CudaError)?; .map_err(StorageError::Cuda)?;
let ptr = ptr as *mut u8; let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory"); assert!(!ptr.is_null(), "Failed to allocate pinned memory");
...@@ -283,7 +283,7 @@ pub struct PinnedAllocator { ...@@ -283,7 +283,7 @@ pub struct PinnedAllocator {
impl Default for PinnedAllocator { impl Default for PinnedAllocator {
fn default() -> Self { fn default() -> Self {
Self { Self {
ctx: Cuda::get_or_init_device(0).expect("Failed to create CUDA context"), ctx: Cuda::device_or_create(0).expect("Failed to create CUDA context"),
} }
} }
} }
...@@ -292,7 +292,7 @@ impl PinnedAllocator { ...@@ -292,7 +292,7 @@ impl PinnedAllocator {
/// Create a new pinned allocator /// Create a new pinned allocator
pub fn new() -> Result<Self, StorageError> { pub fn new() -> Result<Self, StorageError> {
Ok(Self { Ok(Self {
ctx: Cuda::get_or_init_device(0)?, ctx: Cuda::device_or_create(0)?,
}) })
} }
} }
...@@ -318,9 +318,8 @@ impl CudaAccessible for DeviceStorage {} ...@@ -318,9 +318,8 @@ impl CudaAccessible for DeviceStorage {}
impl DeviceStorage { impl DeviceStorage {
/// Create a new device storage with the given size /// Create a new device storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> { pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
ctx.bind_to_thread().map_err(StorageError::CudaError)?; ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = let ptr = unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::Cuda)? };
unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::CudaError)? };
Ok(Self { Ok(Self {
ptr, ptr,
...@@ -406,11 +405,10 @@ impl DeviceAllocator { ...@@ -406,11 +405,10 @@ impl DeviceAllocator {
/// Create a new device allocator /// Create a new device allocator
pub fn new(device_id: usize) -> Result<Self, StorageError> { pub fn new(device_id: usize) -> Result<Self, StorageError> {
Ok(Self { Ok(Self {
ctx: Cuda::get_or_init_device(device_id)?, ctx: Cuda::device_or_create(device_id)?,
}) })
} }
/// Get the CUDA context
pub fn ctx(&self) -> &Arc<CudaContext> { pub fn ctx(&self) -> &Arc<CudaContext> {
&self.ctx &self.ctx
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment