Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b813befa
Unverified
Commit
b813befa
authored
May 14, 2025
by
jthomson04
Committed by
GitHub
May 14, 2025
Browse files
feat: KV Cache Manager block offloading (#1030)
parent
29813508
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1058 additions
and
37 deletions
+1058
-37
lib/llm/src/block_manager.rs
lib/llm/src/block_manager.rs
+2
-0
lib/llm/src/block_manager/block.rs
lib/llm/src/block_manager/block.rs
+37
-6
lib/llm/src/block_manager/block/registry.rs
lib/llm/src/block_manager/block/registry.rs
+3
-0
lib/llm/src/block_manager/block/transfer.rs
lib/llm/src/block_manager/block/transfer.rs
+13
-3
lib/llm/src/block_manager/offload.rs
lib/llm/src/block_manager/offload.rs
+751
-0
lib/llm/src/block_manager/offload/pending.rs
lib/llm/src/block_manager/offload/pending.rs
+112
-0
lib/llm/src/block_manager/offload/request.rs
lib/llm/src/block_manager/offload/request.rs
+76
-0
lib/llm/src/block_manager/pool/state.rs
lib/llm/src/block_manager/pool/state.rs
+8
-0
lib/llm/src/block_manager/state.rs
lib/llm/src/block_manager/state.rs
+38
-10
lib/llm/src/block_manager/storage.rs
lib/llm/src/block_manager/storage.rs
+7
-5
lib/llm/src/block_manager/storage/cuda.rs
lib/llm/src/block_manager/storage/cuda.rs
+11
-13
No files found.
lib/llm/src/block_manager.rs
View file @
b813befa
...
...
@@ -25,6 +25,7 @@ mod state;
pub
mod
block
;
pub
mod
events
;
pub
mod
layout
;
pub
mod
offload
;
pub
mod
pool
;
pub
mod
storage
;
...
...
@@ -61,6 +62,7 @@ pub type WorkerID = u64;
pub
type
ReferenceBlockManager
=
KvBlockManager
<
BasicMetadata
>
;
/// Represents the different cache levels for KV blocks
#[derive(Copy,
Clone,
Debug,
Eq,
Hash,
PartialEq)]
pub
enum
CacheLevel
{
/// Represents KV blocks in GPU memory
G1
,
...
...
lib/llm/src/block_manager/block.rs
View file @
b813befa
...
...
@@ -24,7 +24,7 @@ use nixl_sys::NixlDescriptor;
pub
use
state
::{
BlockState
,
BlockStateInvalid
};
use
crate
::
block_manager
::{
state
::
{
KvBlockManagerState
as
BlockManager
,
TransferContext
},
state
::
KvBlockManagerState
as
BlockManager
,
storage
::{
Local
,
Remote
,
Storage
},
};
use
crate
::
tokens
::{
SaltHash
,
SequenceHash
,
Token
,
TokenBlock
,
Tokens
};
...
...
@@ -100,10 +100,6 @@ pub trait ReadableBlock: BlockDataProvider {
fn
storage_type_id
(
&
self
)
->
std
::
any
::
TypeId
{
std
::
any
::
TypeId
::
of
::
<<
Self
as
ReadableBlock
>
::
StorageType
>
()
}
fn
transfer_context
(
&
self
)
->
&
TransferContext
{
unimplemented!
()
}
}
pub
trait
ReadableBlocks
{}
...
...
@@ -683,10 +679,27 @@ pub struct ImmutableBlock<S: Storage, M: BlockMetadata> {
block
:
Arc
<
MutableBlock
<
S
,
M
>>
,
}
impl
<
S
:
Storage
,
M
:
BlockMetadata
>
Clone
for
ImmutableBlock
<
S
,
M
>
{
fn
clone
(
&
self
)
->
Self
{
Self
{
block
:
self
.block
.clone
(),
}
}
}
impl<S, M> ImmutableBlock<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    /// Wraps an already shared [`MutableBlock`] in an immutable handle.
    pub(crate) fn new(block: Arc<MutableBlock<S, M>>) -> Self {
        Self { block }
    }

    /// Returns the block manager recorded on the underlying block, or `None`
    /// when the `manager` field is unset.
    pub fn manager(&self) -> Option<&Arc<BlockManager<M>>> {
        // `manager` is not a field of `ImmutableBlock` itself; the access is
        // resolved through deref to the wrapped block.
        self.manager.as_ref()
    }

    /// Borrows the shared [`MutableBlock`] that this handle wraps.
    pub fn mutable_block(&self) -> &Arc<MutableBlock<S, M>> {
        let shared = &self.block;
        shared
    }
}
impl
<
S
:
Storage
+
NixlDescriptor
,
M
:
BlockMetadata
>
ReadableBlock
for
ImmutableBlock
<
S
,
M
>
{
...
...
@@ -743,8 +756,17 @@ impl<'a, S: Storage, M: BlockMetadata> AsBlockSlice<'a, ImmutableBlock<S, M>>
}
}
pub
mod
nixl
{
impl<S, M> ImmutableBlock<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    /// Asks this block's manager to enqueue the block for offloading with the
    /// given `priority`.
    ///
    /// When the block has no manager nothing is enqueued and `Ok(())` is
    /// returned. Errors from the manager's `enqueue_offload_block` are
    /// propagated.
    pub async fn enqueue_offload(&self, priority: u64) -> Result<()> {
        // TODO: Is it ok to silently fail if the block is not managed?
        match self.manager() {
            Some(manager) => {
                manager.enqueue_offload_block(self, priority).await?;
                Ok(())
            }
            None => Ok(()),
        }
    }
}
pub
mod
nixl
{
use
super
::
*
;
use
super
::
view
::{
BlockKind
,
Kind
,
LayerKind
};
...
...
@@ -1411,6 +1433,15 @@ pub mod nixl {
}
}
/// Helpers that are only compiled for test builds.
#[cfg(test)]
pub mod test_utils {
    /// Returns a `PrivateToken` value from the parent module's `private`
    /// submodule.
    pub fn get_private_token() -> super::private::PrivateToken {
        super::private::PrivateToken
    }
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
...
...
lib/llm/src/block_manager/block/registry.rs
View file @
b813befa
...
...
@@ -140,6 +140,8 @@ pub struct RegistrationHandle {
#[getter(skip)]
release_manager
:
Arc
<
dyn
EventReleaseManager
>
,
token_block
:
TokenBlock
,
}
impl
RegistrationHandle
{
...
...
@@ -152,6 +154,7 @@ impl RegistrationHandle {
sequence_hash
:
token_block
.sequence_hash
(),
parent_sequence_hash
:
token_block
.parent_sequence_hash
(),
release_manager
,
token_block
:
token_block
.clone
(),
}
}
}
...
...
lib/llm/src/block_manager/block/transfer.rs
View file @
b813befa
...
...
@@ -30,6 +30,7 @@ use cudarc::driver::CudaStream;
use
std
::
ops
::
Range
;
pub
use
crate
::
block_manager
::
state
::
TransferContext
;
pub
use
crate
::
block_manager
::
storage
::{
CudaAccessible
,
Local
,
Remote
};
pub
use
async_trait
::
async_trait
;
...
...
@@ -129,15 +130,24 @@ where
}
pub
trait
WriteTo
<
Target
>
{
fn
write_to
(
&
self
,
dst
:
&
mut
Target
,
notify
:
Option
<
String
>
)
->
Result
<
(),
TransferError
>
;
fn
write_to
(
&
self
,
dst
:
&
mut
Target
,
notify
:
Option
<
String
>
,
ctx
:
&
TransferContext
,
)
->
Result
<
(),
TransferError
>
;
}
impl
<
RB
:
ReadableBlock
,
WB
:
WritableBlock
>
WriteTo
<
WB
>
for
RB
where
RB
:
WriteToStrategy
<
WB
>
+
Local
,
{
fn
write_to
(
&
self
,
dst
:
&
mut
WB
,
notify
:
Option
<
String
>
)
->
Result
<
(),
TransferError
>
{
let
ctx
=
self
.transfer_context
();
fn
write_to
(
&
self
,
dst
:
&
mut
WB
,
notify
:
Option
<
String
>
,
ctx
:
&
TransferContext
,
)
->
Result
<
(),
TransferError
>
{
match
Self
::
write_to_strategy
()
{
TransferStrategy
::
Memcpy
=>
memcpy
::
copy_block
(
self
,
dst
),
TransferStrategy
::
CudaAsyncH2D
...
...
lib/llm/src/block_manager/offload.rs
0 → 100644
View file @
b813befa
This diff is collapsed.
Click to expand it.
lib/llm/src/block_manager/offload/pending.rs
0 → 100644
View file @
b813befa
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
sync
::
Arc
;
use
std
::
thread
::
spawn
;
use
tokio
::
sync
::
mpsc
;
use
crate
::
block_manager
::
block
::{
BlockMetadata
,
ImmutableBlock
,
MutableBlock
};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
storage
::
Storage
;
use
crate
::
block_manager
::
BlockPool
;
use
anyhow
::
Result
;
use
cudarc
::
driver
::
CudaEvent
;
/// Payload sent through a transfer's completion indicator: either the
/// [`ImmutableBlock`]s for the `Target` storage or a [`BlockPoolError`].
type OnboardResult<Target, Metadata> =
    std::result::Result<Vec<ImmutableBlock<Target, Metadata>>, BlockPoolError>;
/// Manage a set of pending transfers.
pub
struct
PendingTransfer
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
{
/// The block being copied from.
_
sources
:
Vec
<
Arc
<
MutableBlock
<
Source
,
Metadata
>>>
,
/// The block being copied to.
targets
:
Vec
<
MutableBlock
<
Target
,
Metadata
>>
,
/// The Cuda event that indicates the completion of the transfer.
event
:
CudaEvent
,
/// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
completion_indicator
:
Option
<
oneshot
::
Sender
<
OnboardResult
<
Target
,
Metadata
>>>
,
/// The target pool that will receive the registered block.
target_pool
:
Arc
<
Option
<
BlockPool
<
Target
,
Metadata
>>>
,
}
impl<Source, Target, Metadata> PendingTransfer<Source, Target, Metadata>
where
    Source: Storage,
    Target: Storage,
    Metadata: BlockMetadata,
{
    /// Bundles the pieces of a transfer into a [`PendingTransfer`].
    ///
    /// * `sources` - blocks being copied from.
    /// * `targets` - blocks being copied to.
    /// * `event` - CUDA event that signals completion of the copy.
    /// * `completion_indicator` - optional sender for the registered blocks.
    /// * `target_pool` - pool in which `targets` are registered afterwards.
    pub fn new(
        sources: Vec<Arc<MutableBlock<Source, Metadata>>>,
        targets: Vec<MutableBlock<Target, Metadata>>,
        event: CudaEvent,
        completion_indicator: Option<oneshot::Sender<OnboardResult<Target, Metadata>>>,
        target_pool: Arc<Option<BlockPool<Target, Metadata>>>,
    ) -> Self {
        let _sources = sources;
        Self {
            _sources,
            targets,
            event,
            completion_indicator,
            target_pool,
        }
    }
}
/// Handle used to submit [`PendingTransfer`]s; it owns the sending half of
/// the queue that carries them.
pub struct TransferManager<Source, Target, Metadata>
where
    Source: Storage,
    Target: Storage,
    Metadata: BlockMetadata,
{
    /// Sending half of the bounded queue of pending transfers.
    pending_transfer_q: mpsc::Sender<PendingTransfer<Source, Target, Metadata>>,
}
impl
<
Source
:
Storage
,
Target
:
Storage
,
Metadata
:
BlockMetadata
>
TransferManager
<
Source
,
Target
,
Metadata
>
{
pub
fn
new
(
max_depth
:
usize
)
->
Self
{
let
(
tx
,
mut
rx
)
=
mpsc
::
channel
::
<
PendingTransfer
<
Source
,
Target
,
Metadata
>>
(
max_depth
);
spawn
(
move
||
{
while
let
Some
(
pending_transfer
)
=
rx
.blocking_recv
()
{
// Wait for the event.
pending_transfer
.event
.synchronize
()
?
;
let
PendingTransfer
{
targets
,
target_pool
,
..
}
=
pending_transfer
;
if
let
Some
(
target_pool
)
=
target_pool
.as_ref
()
{
// Register the blocks in the new pool only AFTER the transfers have been completed.
// This way, we maintain the invariant that blocks that are registered in a pool
// are always available in that pool.
let
blocks
=
target_pool
.register_blocks_blocking
(
targets
)
?
;
if
let
Some
(
completion_indicator
)
=
pending_transfer
.completion_indicator
{
completion_indicator
.send
(
Ok
(
blocks
))
?
;
}
}
}
Ok
::
<
(),
anyhow
::
Error
>
(())
});
Self
{
pending_transfer_q
:
tx
,
}
}
pub
async
fn
handle_pending_transfer
(
&
self
,
pending_transfer
:
PendingTransfer
<
Source
,
Target
,
Metadata
>
,
)
->
Result
<
()
>
{
self
.pending_transfer_q
.send
(
pending_transfer
)
.await
?
;
Ok
(())
}
}
lib/llm/src/block_manager/offload/request.rs
0 → 100644
View file @
b813befa
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
cmp
::
Ordering
;
use
std
::
sync
::
Weak
;
use
crate
::
block_manager
::
block
::{
BlockMetadata
,
ImmutableBlock
,
MutableBlock
};
use
crate
::
block_manager
::
pool
::
BlockPoolError
;
use
crate
::
block_manager
::
storage
::
Storage
;
/// Sort key of a queued offload request.
///
/// The derived ordering compares `priority` first and uses `timestamp` only
/// to break ties, so the declaration order of the fields is significant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct OffloadRequestKey {
    /// Priority given when the offload was requested.
    pub priority: u64,
    /// Tie-breaker between requests of equal priority.
    // NOTE(review): presumably a monotonically increasing tick assigned at
    // enqueue time; confirm against the offload manager.
    pub timestamp: u64,
}
/// Everything required to offload one block.
///
/// Only a weak reference to the block is held while the request sits in the
/// offload queue, so the queue does not prevent the block from being reused
/// if needed.
pub struct OffloadRequest<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    /// Key used to order and compare requests.
    pub key: OffloadRequestKey,
    /// Weak handle to the block to be offloaded.
    pub block: Weak<MutableBlock<S, M>>,
    /// Sequence hash associated with the request.
    pub sequence_hash: u64,
}
/// Partial ordering that always agrees with the total order from [`Ord`].
impl<S, M> PartialOrd for OffloadRequest<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Offload requests are ordered by their `key` alone; `block` and
/// `sequence_hash` do not take part in the comparison.
// NOTE(review): the earlier comment described this as "priority, high to
// low", but the comparison is the key's own ordering. Whether high-priority
// requests are served first depends on how the consuming collection is read;
// confirm against the offload manager.
impl<S, M> Ord for OffloadRequest<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.key, &other.key)
    }
}
/// Two requests are equal when their `key`s are equal. The referenced block
/// and the sequence hash are not compared, which keeps equality consistent
/// with the [`Ord`] implementation.
impl<S, M> PartialEq for OffloadRequest<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.key, &other.key)
    }
}
/// Marks the key-based equality of [`OffloadRequest`] as a full equivalence
/// relation.
impl<S, M> Eq for OffloadRequest<S, M>
where
    S: Storage,
    M: BlockMetadata,
{
}
/// A request carrying blocks held in `Source` storage together with the
/// channel on which the corresponding `Target` blocks (or an error) are
/// returned.
pub struct OnboardRequest<Source, Target, M>
where
    Source: Storage,
    Target: Storage,
    M: BlockMetadata,
{
    /// The source blocks of the request.
    pub blocks: Vec<ImmutableBlock<Source, M>>,
    /// Oneshot sender for the outcome: the `Target` blocks on success, a
    /// [`BlockPoolError`] otherwise.
    pub response_tx:
        oneshot::Sender<std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>>,
}
impl<Source, Target, M> OnboardRequest<Source, Target, M>
where
    Source: Storage,
    Target: Storage,
    M: BlockMetadata,
{
    /// Builds a request from the source `blocks` and the sender on which the
    /// result is to be delivered.
    pub fn new(
        blocks: Vec<ImmutableBlock<Source, M>>,
        response_tx: oneshot::Sender<
            std::result::Result<Vec<ImmutableBlock<Target, M>>, BlockPoolError>,
        >,
    ) -> Self {
        Self {
            response_tx,
            blocks,
        }
    }
}
lib/llm/src/block_manager/pool/state.rs
View file @
b813befa
...
...
@@ -147,6 +147,8 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
continue
;
}
let
mut
offload
=
true
;
let
mutable
=
if
let
Some
(
raw_block
)
=
self
.inactive
.match_sequence_hash
(
sequence_hash
)
{
assert
!
(
matches!
(
raw_block
.state
(),
BlockState
::
Registered
(
_
)));
...
...
@@ -164,6 +166,7 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
}
Err
(
BlockRegistationError
::
BlockAlreadyRegistered
(
_
))
=>
{
// Block is already registered, wait for it to be returned
offload
=
false
;
let
raw_block
=
self
.wait_for_returned_block
(
sequence_hash
,
return_rx
)
.await
;
MutableBlock
::
new
(
raw_block
,
self
.return_tx
.clone
())
...
...
@@ -176,6 +179,11 @@ impl<S: Storage, M: BlockMetadata> State<S, M> {
let
immutable
=
self
.active
.register
(
mutable
)
?
;
// TODO: Make a way to set meaningful priority values, and maybe don't enqueue offloads for every registered block.
if
offload
{
immutable
.enqueue_offload
(
0
)
.await
.unwrap
();
}
immutable_blocks
.push
(
immutable
);
}
...
...
lib/llm/src/block_manager/state.rs
View file @
b813befa
...
...
@@ -15,8 +15,12 @@
use
super
::
*
;
use
super
::{
block
::
Block
,
config
::
NixlOptions
};
use
super
::
offload
::
OffloadManager
;
use
super
::{
block
::{
Block
,
ImmutableBlock
},
config
::
NixlOptions
,
pool
::
BlockPoolError
,
};
use
cudarc
::
driver
::
CudaStream
;
use
std
::
sync
::
Arc
;
...
...
@@ -47,11 +51,13 @@ pub struct KvBlockManagerState<Metadata: BlockMetadata> {
nixl_agent
:
Option
<
NixlAgent
>
,
nixl_backends
:
HashMap
<
String
,
Arc
<
nixl_sys
::
Backend
>>
,
host_pool
:
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>
,
device_pool
:
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>
,
host_pool
:
Arc
<
Option
<
BlockPool
<
PinnedStorage
,
Metadata
>>
>
,
device_pool
:
Arc
<
Option
<
BlockPool
<
DeviceStorage
,
Metadata
>>
>
,
local_block_set
:
NixlBlockSet
,
remote_block_sets
:
RwLock
<
HashMap
<
WorkerID
,
HashMap
<
usize
,
RemoteBlocks
>>>
,
offload_manager
:
Arc
<
OffloadManager
<
Metadata
>>
,
}
impl
<
Metadata
:
BlockMetadata
>
KvBlockManagerState
<
Metadata
>
{
...
...
@@ -114,10 +120,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
worker_id
,
)
?
;
(
Some
(
pool
),
Some
(
blocks
))
(
Arc
::
new
(
Some
(
pool
)
)
,
Some
(
blocks
))
}
else
{
tracing
::
debug!
(
"No host layout provided; will not allocate host blocks."
);
(
None
,
None
)
(
Arc
::
new
(
None
)
,
None
)
};
// Create the device block pool if a device layout is provided
...
...
@@ -132,10 +138,10 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
cancellation_token
.clone
(),
worker_id
,
)
?
;
(
Some
(
pool
),
Some
(
blocks
))
(
Arc
::
new
(
Some
(
pool
)
)
,
Some
(
blocks
))
}
else
{
tracing
::
debug!
(
"No device layout provided; will not allocate device blocks."
);
(
None
,
None
)
(
Arc
::
new
(
None
)
,
None
)
};
// Finalize the local block set by adding NIXL metadata
...
...
@@ -144,6 +150,8 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
local_block_set
.set_nixl_metadata
(
nixl_agent
.get_local_md
()
?
);
}
let
offload_manager
=
OffloadManager
::
new
(
device_pool
.clone
(),
host_pool
.clone
())
?
;
let
state
=
Arc
::
new
(
Self
{
worker_id
,
cancellation_token
,
...
...
@@ -153,6 +161,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
device_pool
,
local_block_set
,
remote_block_sets
:
RwLock
::
new
(
HashMap
::
new
()),
offload_manager
,
});
if
let
Some
(
mut
blocks
)
=
host_blocks
{
...
...
@@ -163,6 +172,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state
.host_pool
.as_ref
()
.as_ref
()
.unwrap
()
.add_blocks_blocking
(
blocks
)
?
;
}
...
...
@@ -175,6 +185,7 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
state
.device_pool
.as_ref
()
.as_ref
()
.unwrap
()
.add_blocks_blocking
(
blocks
)
?
;
}
...
...
@@ -334,16 +345,33 @@ impl<Metadata: BlockMetadata> KvBlockManagerState<Metadata> {
}
pub
fn
host
(
&
self
)
->
Option
<&
BlockPool
<
PinnedStorage
,
Metadata
>>
{
self
.host_pool
.as_ref
()
self
.host_pool
.as_ref
()
.as_ref
()
}
pub
fn
device
(
&
self
)
->
Option
<&
BlockPool
<
DeviceStorage
,
Metadata
>>
{
self
.device_pool
.as_ref
()
self
.device_pool
.as_ref
()
.as_ref
()
}
pub
fn
worker_id
(
&
self
)
->
WorkerID
{
self
.worker_id
}
/// Hands `block` to the offload manager's `offload` with the given
/// `priority`, propagating any error it reports.
pub(crate) async fn enqueue_offload_block<S: Storage + 'static>(
    &self,
    block: &ImmutableBlock<S, Metadata>,
    priority: u64,
) -> Result<()> {
    let outcome = self.offload_manager.offload(block, priority).await;
    outcome?;
    Ok(())
}
/// Passes the given pinned-storage `blocks` to the offload manager's
/// `onboard` and returns its result: the device-storage blocks on success or
/// a [`BlockPoolError`].
pub async fn onboard_blocks(
    &self,
    blocks: Vec<ImmutableBlock<PinnedStorage, Metadata>>,
) -> core::result::Result<Vec<ImmutableBlock<DeviceStorage, Metadata>>, BlockPoolError> {
    let onboarded = self.offload_manager.onboard(blocks).await;
    onboarded
}
}
impl
<
Metadata
:
BlockMetadata
>
std
::
fmt
::
Debug
for
KvBlockManagerState
<
Metadata
>
{
...
...
lib/llm/src/block_manager/storage.rs
View file @
b813befa
...
...
@@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![deny(missing_docs)]
// TODO: Add docs.
#![allow(missing_docs)]
//! # Storage Management
//!
...
...
@@ -121,7 +122,8 @@ pub trait Remote {}
/// Marker trait for [`Storage`] types that can be accessed by the standard
/// mechanisms of the system, e.g. `memcpy`, `memset`, etc.
pub
trait
SystemAccessible
:
Storage
{}
pub
trait
SystemAccessible
{}
pub
trait
CudaAccessible
{}
/// Errors that can occur during storage operations
#[derive(Debug,
Error)]
...
...
@@ -139,15 +141,15 @@ pub enum StorageError {
#[error(
"Storage operation failed: {0}"
)]
OperationFailed
(
String
),
#[error(
"CUDA error: {0}"
)]
Cuda
(
#[from]
cudarc
::
driver
::
DriverError
),
#[error(
"Registration key already exists: {0}"
)]
RegistrationKeyExists
(
String
),
#[error(
"Handle not found for key: {0}"
)]
HandleNotFound
(
String
),
#[error(
"CUDA error: {0}"
)]
CudaError
(
#[from]
cudarc
::
driver
::
DriverError
),
#[error(
"NIXL error: {0}"
)]
NixlError
(
#[from]
nixl_sys
::
NixlError
),
}
...
...
lib/llm/src/block_manager/storage/cuda.rs
View file @
b813befa
...
...
@@ -114,7 +114,7 @@ impl Cuda {
/// If the context does not exist, it will return None.
///
/// This will not lazily instantiate a context for a device. Use
/// [Cuda::
get_or_init_devic
e]
/// [Cuda::
device_or_creat
e]
pub
fn
device
(
device_id
:
usize
)
->
Option
<
Arc
<
CudaContext
>>
{
Cuda
::
instance
()
.lock
()
...
...
@@ -127,7 +127,7 @@ impl Cuda {
///
/// This will lazily instantiate a context for a device. Use
/// [Cuda::device] to get an existing context.
pub
fn
get_or_init_devic
e
(
device_id
:
usize
)
->
Result
<
Arc
<
CudaContext
>
,
StorageError
>
{
pub
fn
device_or_creat
e
(
device_id
:
usize
)
->
Result
<
Arc
<
CudaContext
>
,
StorageError
>
{
Cuda
::
instance
()
.lock
()
.unwrap
()
.get_context
(
device_id
)
}
...
...
@@ -159,12 +159,12 @@ impl Cuda {
}
// Get a context if it exists, but don't create one
fn
get_existing_context
(
&
self
,
device_id
:
usize
)
->
Option
<
Arc
<
CudaContext
>>
{
pub
fn
get_existing_context
(
&
self
,
device_id
:
usize
)
->
Option
<
Arc
<
CudaContext
>>
{
self
.contexts
.get
(
&
device_id
)
.cloned
()
}
// Check if a context exists for a device
fn
has_context
(
&
self
,
device_id
:
usize
)
->
bool
{
pub
fn
has_context
(
&
self
,
device_id
:
usize
)
->
bool
{
self
.contexts
.contains_key
(
&
device_id
)
}
}
...
...
@@ -186,10 +186,10 @@ impl PinnedStorage {
/// Create a new pinned storage with the given size
pub
fn
new
(
ctx
:
&
Arc
<
CudaContext
>
,
size
:
usize
)
->
Result
<
Self
,
StorageError
>
{
unsafe
{
ctx
.bind_to_thread
()
.map_err
(
StorageError
::
Cuda
Error
)
?
;
ctx
.bind_to_thread
()
.map_err
(
StorageError
::
Cuda
)
?
;
let
ptr
=
cudarc
::
driver
::
result
::
malloc_host
(
size
,
sys
::
CU_MEMHOSTALLOC_WRITECOMBINED
)
.map_err
(
StorageError
::
Cuda
Error
)
?
;
.map_err
(
StorageError
::
Cuda
)
?
;
let
ptr
=
ptr
as
*
mut
u8
;
assert
!
(
!
ptr
.is_null
(),
"Failed to allocate pinned memory"
);
...
...
@@ -283,7 +283,7 @@ pub struct PinnedAllocator {
impl
Default
for
PinnedAllocator
{
fn
default
()
->
Self
{
Self
{
ctx
:
Cuda
::
get_or_init_devic
e
(
0
)
.expect
(
"Failed to create CUDA context"
),
ctx
:
Cuda
::
device_or_creat
e
(
0
)
.expect
(
"Failed to create CUDA context"
),
}
}
}
...
...
@@ -292,7 +292,7 @@ impl PinnedAllocator {
/// Create a new pinned allocator
pub
fn
new
()
->
Result
<
Self
,
StorageError
>
{
Ok
(
Self
{
ctx
:
Cuda
::
get_or_init_devic
e
(
0
)
?
,
ctx
:
Cuda
::
device_or_creat
e
(
0
)
?
,
})
}
}
...
...
@@ -318,9 +318,8 @@ impl CudaAccessible for DeviceStorage {}
impl
DeviceStorage
{
/// Create a new device storage with the given size
pub
fn
new
(
ctx
:
&
Arc
<
CudaContext
>
,
size
:
usize
)
->
Result
<
Self
,
StorageError
>
{
ctx
.bind_to_thread
()
.map_err
(
StorageError
::
CudaError
)
?
;
let
ptr
=
unsafe
{
cudarc
::
driver
::
result
::
malloc_sync
(
size
)
.map_err
(
StorageError
::
CudaError
)
?
};
ctx
.bind_to_thread
()
.map_err
(
StorageError
::
Cuda
)
?
;
let
ptr
=
unsafe
{
cudarc
::
driver
::
result
::
malloc_sync
(
size
)
.map_err
(
StorageError
::
Cuda
)
?
};
Ok
(
Self
{
ptr
,
...
...
@@ -406,11 +405,10 @@ impl DeviceAllocator {
/// Create a new device allocator
pub
fn
new
(
device_id
:
usize
)
->
Result
<
Self
,
StorageError
>
{
Ok
(
Self
{
ctx
:
Cuda
::
get_or_init_devic
e
(
device_id
)
?
,
ctx
:
Cuda
::
device_or_creat
e
(
device_id
)
?
,
})
}
/// Get the CUDA context
pub
fn
ctx
(
&
self
)
->
&
Arc
<
CudaContext
>
{
&
self
.ctx
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment