Unverified Commit b813befa authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat: KV Cache Manager block offloading (#1030)

parent 29813508
...@@ -25,6 +25,7 @@ mod state; ...@@ -25,6 +25,7 @@ mod state;
pub mod block; pub mod block;
pub mod events; pub mod events;
pub mod layout; pub mod layout;
pub mod offload;
pub mod pool; pub mod pool;
pub mod storage; pub mod storage;
...@@ -61,6 +62,7 @@ pub type WorkerID = u64; ...@@ -61,6 +62,7 @@ pub type WorkerID = u64;
pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>; pub type ReferenceBlockManager = KvBlockManager<BasicMetadata>;
/// Represents the different cache levels for KV blocks /// Represents the different cache levels for KV blocks
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum CacheLevel { pub enum CacheLevel {
/// Represents KV blocks in GPU memory /// Represents KV blocks in GPU memory
G1, G1,
......
...@@ -24,7 +24,7 @@ use nixl_sys::NixlDescriptor; ...@@ -24,7 +24,7 @@ use nixl_sys::NixlDescriptor;
pub use state::{BlockState, BlockStateInvalid}; pub use state::{BlockState, BlockStateInvalid};
use crate::block_manager::{ use crate::block_manager::{
state::{KvBlockManagerState as BlockManager, TransferContext}, state::KvBlockManagerState as BlockManager,
storage::{Local, Remote, Storage}, storage::{Local, Remote, Storage},
}; };
use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens}; use crate::tokens::{SaltHash, SequenceHash, Token, TokenBlock, Tokens};
...@@ -100,10 +100,6 @@ pub trait ReadableBlock: BlockDataProvider { ...@@ -100,10 +100,6 @@ pub trait ReadableBlock: BlockDataProvider {
fn storage_type_id(&self) -> std::any::TypeId { fn storage_type_id(&self) -> std::any::TypeId {
std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>() std::any::TypeId::of::<<Self as ReadableBlock>::StorageType>()
} }
fn transfer_context(&self) -> &TransferContext {
unimplemented!()
}
} }
pub trait ReadableBlocks {} pub trait ReadableBlocks {}
...@@ -683,10 +679,27 @@ pub struct ImmutableBlock<S: Storage, M: BlockMetadata> { ...@@ -683,10 +679,27 @@ pub struct ImmutableBlock<S: Storage, M: BlockMetadata> {
block: Arc<MutableBlock<S, M>>, block: Arc<MutableBlock<S, M>>,
} }
impl<S: Storage, M: BlockMetadata> Clone for ImmutableBlock<S, M> {
fn clone(&self) -> Self {
Self {
block: self.block.clone(),
}
}
}
impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self { pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self {
Self { block } Self { block }
} }
pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
// Access the underlying Block's manager field directly through deref
self.manager.as_ref()
}
pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
&self.block
}
} }
impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> { impl<S: Storage + NixlDescriptor, M: BlockMetadata> ReadableBlock for ImmutableBlock<S, M> {
...@@ -743,8 +756,17 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>> ...@@ -743,8 +756,17 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
} }
} }
pub mod nixl { impl<S: Storage, M: BlockMetadata> ImmutableBlock<S, M> {
pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
// TODO: Is it ok to silently fail if the block is not managed?
if let Some(manager) = self.manager() {
manager.enqueue_offload_block(self, priority).await?;
}
Ok(())
}
}
pub mod nixl {
use super::*; use super::*;
use super::view::{BlockKind, Kind, LayerKind}; use super::view::{BlockKind, Kind, LayerKind};
...@@ -1411,6 +1433,15 @@ pub mod nixl { ...@@ -1411,6 +1433,15 @@ pub mod nixl {
} }
} }
#[cfg(test)]
pub mod test_utils {
use super::private::PrivateToken;
pub fn get_private_token() -> PrivateToken {
PrivateToken
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -140,6 +140,8 @@ pub struct RegistrationHandle { ...@@ -140,6 +140,8 @@ pub struct RegistrationHandle {
#[getter(skip)] #[getter(skip)]
release_manager: Arc<dyn EventReleaseManager>, release_manager: Arc<dyn EventReleaseManager>,
token_block: TokenBlock,
} }
impl RegistrationHandle { impl RegistrationHandle {
...@@ -152,6 +154,7 @@ impl RegistrationHandle { ...@@ -152,6 +154,7 @@ impl RegistrationHandle {
sequence_hash: token_block.sequence_hash(), sequence_hash: token_block.sequence_hash(),
parent_sequence_hash: token_block.parent_sequence_hash(), parent_sequence_hash: token_block.parent_sequence_hash(),
release_manager, release_manager,
token_block: token_block.clone(),
} }
} }
} }
......
...@@ -30,6 +30,7 @@ use cudarc::driver::CudaStream; ...@@ -30,6 +30,7 @@ use cudarc::driver::CudaStream;
use std::ops::Range; use std::ops::Range;
pub use crate::block_manager::state::TransferContext;
pub use crate::block_manager::storage::{CudaAccessible, Local, Remote}; pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait; pub use async_trait::async_trait;
...@@ -129,15 +130,24 @@ where ...@@ -129,15 +130,24 @@ where
} }
pub trait WriteTo<Target> { pub trait WriteTo<Target> {
fn write_to(&self, dst: &mut Target, notify: Option<String>) -> Result<(), TransferError>; fn write_to(
&self,
dst: &mut Target,
notify: Option<String>,
ctx: &TransferContext,
) -> Result<(), TransferError>;
} }
impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB impl<RB: ReadableBlock, WB: WritableBlock> WriteTo<WB> for RB
where where
RB: WriteToStrategy<WB> + Local, RB: WriteToStrategy<WB> + Local,
{ {
fn write_to(&self, dst: &mut WB, notify: Option<String>) -> Result<(), TransferError> { fn write_to(
let ctx = self.transfer_context(); &self,
dst: &mut WB,
notify: Option<String>,
ctx: &TransferContext,
) -> Result<(), TransferError> {
match Self::write_to_strategy() { match Self::write_to_strategy() {
TransferStrategy::Memcpy => memcpy::copy_block(self, dst), TransferStrategy::Memcpy => memcpy::copy_block(self, dst),
TransferStrategy::CudaAsyncH2D TransferStrategy::CudaAsyncH2D
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use std::thread::spawn;
use tokio::sync::mpsc;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
use crate::block_manager::BlockPool;
use anyhow::Result;
use cudarc::driver::CudaEvent;
type OnboardResult<Target, Metadata> =
Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// Manage a set of pending transfers.
pub struct PendingTransfer<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
/// The block being copied from.
_sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
/// The block being copied to.
targets: Vec<MutableBlock<Target, Metadata>>,
/// The Cuda event that indicates the completion of the transfer.
event: CudaEvent,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
/// The target pool that will receive the registered block.
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
PendingTransfer<Source, Target, Metadata>
{
pub fn new(
sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
targets: Vec<MutableBlock<Target, Metadata>>,
event: CudaEvent,
completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
) -> Self {
Self {
_sources: sources,
targets,
event,
completion_indicator,
target_pool,
}
}
}
pub struct TransferManager<Source: Storage, Target: Storage, Metadata: BlockMetadata> {
pending_transfer_q: mpsc::Sender<PendingTransfer<Source, Target, Metadata>>,
}
impl<Source: Storage, Target: Storage, Metadata: BlockMetadata>
TransferManager<Source, Target, Metadata>
{
pub fn new(max_depth: usize) -> Self {
let (tx, mut rx) = mpsc::channel::<PendingTransfer<Source, Target, Metadata>>(max_depth);
spawn(move || {
while let Some(pending_transfer) = rx.blocking_recv() {
// Wait for the event.
pending_transfer.event.synchronize()?;
let PendingTransfer {
targets,
target_pool,
..
} = pending_transfer;
if let Some(target_pool) = target_pool.as_ref() {
// Register the blocks in the new pool only AFTER the transfers have been completed.
// This way, we maintain the invariant that blocks that are registered in a pool
// are always available in that pool.
let blocks = target_pool.register_blocks_blocking(targets)?;
if let Some(completion_indicator) = pending_transfer.completion_indicator {
completion_indicator.send(Ok(blocks))?;
}
}
}
Ok::<(), anyhow::Error>(())
});
Self {
pending_transfer_q: tx,
}
}
pub async fn handle_pending_transfer(
&self,
pending_transfer: PendingTransfer<Source, Target, Metadata>,
) -> Result<()> {
self.pending_transfer_q.send(pending_transfer).await?;
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Ordering;
use std::sync::Weak;
use crate::block_manager::block::{BlockMetadata, ImmutableBlock, MutableBlock};
use crate::block_manager::pool::BlockPoolError;
use crate::block_manager::storage::Storage;
#[derive(PartialEq, Eq, Ord, PartialOrd)]
pub struct OffloadRequestKey {
pub priority: u64,
pub timestamp: u64,
}
/// Data needed to offload a block.
/// While the block is in the offload queue, we hold a weak reference to it.
/// This way, we don't prevent the block from being reused if needed.
pub struct OffloadRequest<S: Storage, M: BlockMetadata> {
pub key: OffloadRequestKey,
pub block: Weak<MutableBlock<S, M>>,
pub sequence_hash: u64,
}
impl<S: Storage, M: BlockMetadata> PartialOrd for OffloadRequest<S, M> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// Order offload requests by priority, high to low.
impl<S: Storage, M: BlockMetadata> Ord for OffloadRequest<S, M> {
fn cmp(&self, other: &Self) -> Ordering {
self.key.cmp(&other.key)
}
}
/// Equality is based on sequence hash, priority, and location.
impl<S: Storage, M: BlockMetadata> PartialEq for OffloadRequest<S, M> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}
impl<S: Storage, M: BlockMetadata> Eq for OffloadRequest<S, M> {}
pub struct OnboardRequest<Source: Storage, Target: Storage, M: BlockMetadata> {
pub blocks: Vec<ImmutableBlock<Source, M>>,
pub response_tx:
oneshot::Sender<std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
}
impl<Source: Storage, Target: Storage, M: BlockMetadata> OnboardRequest<Source, Target, M> {
pub fn new(
blocks: Vec<ImmutableBlock<Source, M>>,
response_tx: oneshot::Sender<Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
) -> Self {
Self {
blocks,
response_tx,
}
}
}
...@@ -147,6 +147,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -147,6 +147,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
continue; continue;
} }
let mut offload = true;
let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash) let mutable = if let Some(raw_block) = self.inactive.match_sequence_hash(sequence_hash)
{ {
assert!(matches!(raw_block.state(), BlockState::Registered(_))); assert!(matches!(raw_block.state(), BlockState::Registered(_)));
...@@ -164,6 +166,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -164,6 +166,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
} }
Err(BlockRegistationError::BlockAlreadyRegistered(_)) => { Err(BlockRegistationError::BlockAlreadyRegistered(_)) => {
// Block is already registered, wait for it to be returned // Block is already registered, wait for it to be returned
offload = false;
let raw_block = let raw_block =
self.wait_for_returned_block(sequence_hash, return_rx).await; self.wait_for_returned_block(sequence_hash, return_rx).await;
MutableBlock::new(raw_block, self.return_tx.clone()) MutableBlock::new(raw_block, self.return_tx.clone())
...@@ -176,6 +179,11 @@ impl<S: Storage, M: BlockMetadata> State<S, M> { ...@@ -176,6 +179,11 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let immutable = self.active.register(mutable)?; let immutable = self.active.register(mutable)?;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if offload {
immutable.enqueue_offload(0).await.unwrap();
}
immutable_blocks.push(immutable); immutable_blocks.push(immutable);
} }
......
...@@ -15,8 +15,12 @@ ...@@ -15,8 +15,12 @@
use super::*; use super::*;
use super::{block::Block, config::NixlOptions}; use super::offload::OffloadManager;
use super::{
block::{Block, ImmutableBlock},
config::NixlOptions,
pool::BlockPoolError,
};
use cudarc::driver::CudaStream; use cudarc::driver::CudaStream;
use std::sync::Arc; use std::sync::Arc;
...@@ -47,11 +51,13 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> { ...@@ -47,11 +51,13 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent: Option<NixlAgent>, nixl_agent: Option<NixlAgent>,
nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>, nixl_backends: HashMap<String, Arc<nixl_sys::Backend>>,
host_pool: Option<BlockPool<PinnedStorage, Metadata>>, host_pool: Arc<Option<BlockPool<PinnedStorage, Metadata>>>,
device_pool: Option<BlockPool<DeviceStorage, Metadata>>, device_pool: Arc<Option<BlockPool<DeviceStorage, Metadata>>>,
local_block_set: NixlBlockSet, local_block_set: NixlBlockSet,
remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>, remote_block_sets: RwLock<HashMap<WorkerID, HashMap<usize, RemoteBlocks>>>,
offload_manager: Arc<OffloadManager<Metadata>>,
} }
impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
...@@ -114,10 +120,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -114,10 +120,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Some(pool), Some(blocks)) (Arc::new(Some(pool)), Some(blocks))
} else { } else {
tracing::debug!("No host layout provided; will not allocate host blocks."); tracing::debug!("No host layout provided; will not allocate host blocks.");
(None, None) (Arc::new(None), None)
}; };
// Create the device block pool if a device layout is provided // Create the device block pool if a device layout is provided
...@@ -132,10 +138,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -132,10 +138,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token.clone(), cancellation_token.clone(),
worker_id, worker_id,
)?; )?;
(Some(pool), Some(blocks)) (Arc::new(Some(pool)), Some(blocks))
} else { } else {
tracing::debug!("No device layout provided; will not allocate device blocks."); tracing::debug!("No device layout provided; will not allocate device blocks.");
(None, None) (Arc::new(None), None)
}; };
// Finalize the local block set by adding NIXL metadata // Finalize the local block set by adding NIXL metadata
...@@ -144,6 +150,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -144,6 +150,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?); local_block_set.set_nixl_metadata(nixl_agent.get_local_md()?);
} }
let offload_manager = OffloadManager::new(device_pool.clone(), host_pool.clone())?;
let state = Arc::new(Self { let state = Arc::new(Self {
worker_id, worker_id,
cancellation_token, cancellation_token,
...@@ -153,6 +161,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -153,6 +161,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool, device_pool,
local_block_set, local_block_set,
remote_block_sets: RwLock::new(HashMap::new()), remote_block_sets: RwLock::new(HashMap::new()),
offload_manager,
}); });
if let Some(mut blocks) = host_blocks { if let Some(mut blocks) = host_blocks {
...@@ -163,6 +172,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -163,6 +172,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state state
.host_pool .host_pool
.as_ref() .as_ref()
.as_ref()
.unwrap() .unwrap()
.add_blocks_blocking(blocks)?; .add_blocks_blocking(blocks)?;
} }
...@@ -175,6 +185,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -175,6 +185,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state state
.device_pool .device_pool
.as_ref() .as_ref()
.as_ref()
.unwrap() .unwrap()
.add_blocks_blocking(blocks)?; .add_blocks_blocking(blocks)?;
} }
...@@ -334,16 +345,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> { ...@@ -334,16 +345,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
} }
pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> { pub fn host(&self) -> Option<&BlockPool<PinnedStorage, Metadata>> {
self.host_pool.as_ref() self.host_pool.as_ref().as_ref()
} }
pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> { pub fn device(&self) -> Option<&BlockPool<DeviceStorage, Metadata>> {
self.device_pool.as_ref() self.device_pool.as_ref().as_ref()
} }
pub fn worker_id(&self) -> WorkerID { pub fn worker_id(&self) -> WorkerID {
self.worker_id self.worker_id
} }
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
&self,
block: &ImmutableBlock<S, Metadata>,
priority: u64,
) -> Result<()> {
self.offload_manager.offload(block, priority).await?;
Ok(())
}
pub async fn onboard_blocks(
&self,
blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
self.offload_manager.onboard(blocks).await
}
} }
impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> { impl<Metadata: BlockMetadata> std::fmt::Debug for KvBlockManagerState<Metadata> {
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![deny(missing_docs)] // TODO: Add docs.
#![allow(missing_docs)]
//! # Storage Management //! # Storage Management
//! //!
...@@ -121,7 +122,8 @@ pub trait Remote {} ...@@ -121,7 +122,8 @@ pub trait Remote {}
/// Marker trait for [`Storage`] types that can be accessed by the standard /// Marker trait for [`Storage`] types that can be accessed by the standard
/// mechanisms of the system, e.g. `memcpy`, `memset`, etc. /// mechanisms of the system, e.g. `memcpy`, `memset`, etc.
pub trait SystemAccessible: Storage {} pub trait SystemAccessible {}
pub trait CudaAccessible {}
/// Errors that can occur during storage operations /// Errors that can occur during storage operations
#[derive(Debug, Error)] #[derive(Debug, Error)]
...@@ -139,15 +141,15 @@ pub enum StorageError { ...@@ -139,15 +141,15 @@ pub enum StorageError {
#[error("Storage operation failed: {0}")] #[error("Storage operation failed: {0}")]
OperationFailed(String), OperationFailed(String),
#[error("CUDA error: {0}")]
Cuda(#[from] cudarc::driver::DriverError),
#[error("Registration key already exists: {0}")] #[error("Registration key already exists: {0}")]
RegistrationKeyExists(String), RegistrationKeyExists(String),
#[error("Handle not found for key: {0}")] #[error("Handle not found for key: {0}")]
HandleNotFound(String), HandleNotFound(String),
#[error("CUDA error: {0}")]
CudaError(#[from] cudarc::driver::DriverError),
#[error("NIXL error: {0}")] #[error("NIXL error: {0}")]
NixlError(#[from] nixl_sys::NixlError), NixlError(#[from] nixl_sys::NixlError),
} }
......
...@@ -114,7 +114,7 @@ impl Cuda { ...@@ -114,7 +114,7 @@ impl Cuda {
/// If the context does not exist, it will return None. /// If the context does not exist, it will return None.
/// ///
/// This will not lazily instantiate a context for a device. Use /// This will not lazily instantiate a context for a device. Use
/// [Cuda::get_or_init_device] /// [Cuda::device_or_create]
pub fn device(device_id: usize) -> Option<Arc<CudaContext>> { pub fn device(device_id: usize) -> Option<Arc<CudaContext>> {
Cuda::instance() Cuda::instance()
.lock() .lock()
...@@ -127,7 +127,7 @@ impl Cuda { ...@@ -127,7 +127,7 @@ impl Cuda {
/// ///
/// This will lazily instantiate a context for a device. Use /// This will lazily instantiate a context for a device. Use
/// [CudaContextManager::device] to get an existing context. /// [CudaContextManager::device] to get an existing context.
pub fn get_or_init_device(device_id: usize) -> Result<Arc<CudaContext>, StorageError> { pub fn device_or_create(device_id: usize) -> Result<Arc<CudaContext>, StorageError> {
Cuda::instance().lock().unwrap().get_context(device_id) Cuda::instance().lock().unwrap().get_context(device_id)
} }
...@@ -159,12 +159,12 @@ impl Cuda { ...@@ -159,12 +159,12 @@ impl Cuda {
} }
// Get a context if it exists, but don't create one // Get a context if it exists, but don't create one
fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> { pub fn get_existing_context(&self, device_id: usize) -> Option<Arc<CudaContext>> {
self.contexts.get(&device_id).cloned() self.contexts.get(&device_id).cloned()
} }
// Check if a context exists for a device // Check if a context exists for a device
fn has_context(&self, device_id: usize) -> bool { pub fn has_context(&self, device_id: usize) -> bool {
self.contexts.contains_key(&device_id) self.contexts.contains_key(&device_id)
} }
} }
...@@ -186,10 +186,10 @@ impl PinnedStorage { ...@@ -186,10 +186,10 @@ impl PinnedStorage {
/// Create a new pinned storage with the given size /// Create a new pinned storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> { pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
unsafe { unsafe {
ctx.bind_to_thread().map_err(StorageError::CudaError)?; ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED) let ptr = cudarc::driver::result::malloc_host(size, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
.map_err(StorageError::CudaError)?; .map_err(StorageError::Cuda)?;
let ptr = ptr as *mut u8; let ptr = ptr as *mut u8;
assert!(!ptr.is_null(), "Failed to allocate pinned memory"); assert!(!ptr.is_null(), "Failed to allocate pinned memory");
...@@ -283,7 +283,7 @@ pub struct PinnedAllocator { ...@@ -283,7 +283,7 @@ pub struct PinnedAllocator {
impl Default for PinnedAllocator { impl Default for PinnedAllocator {
fn default() -> Self { fn default() -> Self {
Self { Self {
ctx: Cuda::get_or_init_device(0).expect("Failed to create CUDA context"), ctx: Cuda::device_or_create(0).expect("Failed to create CUDA context"),
} }
} }
} }
...@@ -292,7 +292,7 @@ impl PinnedAllocator { ...@@ -292,7 +292,7 @@ impl PinnedAllocator {
/// Create a new pinned allocator /// Create a new pinned allocator
pub fn new() -> Result<Self, StorageError> { pub fn new() -> Result<Self, StorageError> {
Ok(Self { Ok(Self {
ctx: Cuda::get_or_init_device(0)?, ctx: Cuda::device_or_create(0)?,
}) })
} }
} }
...@@ -318,9 +318,8 @@ impl CudaAccessible for DeviceStorage {} ...@@ -318,9 +318,8 @@ impl CudaAccessible for DeviceStorage {}
impl DeviceStorage { impl DeviceStorage {
/// Create a new device storage with the given size /// Create a new device storage with the given size
pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> { pub fn new(ctx: &Arc<CudaContext>, size: usize) -> Result<Self, StorageError> {
ctx.bind_to_thread().map_err(StorageError::CudaError)?; ctx.bind_to_thread().map_err(StorageError::Cuda)?;
let ptr = let ptr = unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::Cuda)? };
unsafe { cudarc::driver::result::malloc_sync(size).map_err(StorageError::CudaError)? };
Ok(Self { Ok(Self {
ptr, ptr,
...@@ -406,11 +405,10 @@ impl DeviceAllocator { ...@@ -406,11 +405,10 @@ impl DeviceAllocator {
/// Create a new device allocator /// Create a new device allocator
pub fn new(device_id: usize) -> Result<Self, StorageError> { pub fn new(device_id: usize) -> Result<Self, StorageError> {
Ok(Self { Ok(Self {
ctx: Cuda::get_or_init_device(device_id)?, ctx: Cuda::device_or_create(device_id)?,
}) })
} }
/// Get the CUDA context
pub fn ctx(&self) -> &Arc<CudaContext> { pub fn ctx(&self) -> &Arc<CudaContext> {
&self.ctx &self.ctx
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment