Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
7677f74f
"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "5602dd2f6380e9e42fb120c3e8f14a5b6fd026fb"
Unverified
Commit
7677f74f
authored
May 29, 2025
by
Jacky
Committed by
GitHub
May 29, 2025
Browse files
feat: KVBM async Python bindings and Layer class (#1141)
parent
a0512bd1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
784 additions
and
245 deletions
+784
-245
lib/bindings/python/rust/llm/block_manager.rs
lib/bindings/python/rust/llm/block_manager.rs
+83
-26
lib/bindings/python/rust/llm/block_manager/block.rs
lib/bindings/python/rust/llm/block_manager/block.rs
+139
-146
lib/bindings/python/rust/llm/block_manager/block_list.rs
lib/bindings/python/rust/llm/block_manager/block_list.rs
+12
-20
lib/bindings/python/rust/llm/block_manager/dlpack.rs
lib/bindings/python/rust/llm/block_manager/dlpack.rs
+129
-0
lib/bindings/python/rust/llm/block_manager/layer.rs
lib/bindings/python/rust/llm/block_manager/layer.rs
+129
-0
lib/bindings/python/src/dynamo/_core.pyi
lib/bindings/python/src/dynamo/_core.pyi
+83
-1
lib/bindings/python/tests/test_block_manager.py
lib/bindings/python/tests/test_block_manager.py
+209
-52
No files found.
lib/bindings/python/rust/llm/block_manager.rs
View file @
7677f74f
...
@@ -14,18 +14,18 @@
...
@@ -14,18 +14,18 @@
// limitations under the License.
// limitations under the License.
#![cfg(feature
=
"block-manager"
)]
#![cfg(feature
=
"block-manager"
)]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use
super
::
*
;
use
super
::
*
;
use
pyo3
::
PyResult
;
use
pyo3
::
PyResult
;
use
tokio
;
mod
block
;
mod
block
;
mod
block_list
;
mod
block_list
;
mod
dlpack
;
mod
layer
;
/// Add bindings from this crate to the provided module
/// Add bindings from this crate to the provided module
pub
fn
add_to_module
(
m
:
&
Bound
<
'_
,
PyModule
>
)
->
PyResult
<
()
>
{
pub
fn
add_to_module
(
m
:
&
Bound
<
'_
,
PyModule
>
)
->
PyResult
<
()
>
{
m
.add_class
::
<
layer
::
Layer
>
()
?
;
m
.add_class
::
<
block
::
Block
>
()
?
;
m
.add_class
::
<
block
::
Block
>
()
?
;
m
.add_class
::
<
block_list
::
BlockList
>
()
?
;
m
.add_class
::
<
block_list
::
BlockList
>
()
?
;
m
.add_class
::
<
BlockManager
>
()
?
;
m
.add_class
::
<
BlockManager
>
()
?
;
...
@@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
...
@@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
#[pyclass]
#[pyclass]
pub
struct
BlockManager
{
pub
struct
BlockManager
{
// TODO: Can this be implicitly created and referenced?
tokio_runtime
:
tokio
::
runtime
::
Runtime
,
// Block manager
inner
:
Arc
<
dynamo_llm
::
block_manager
::
ReferenceBlockManager
>
,
inner
:
Arc
<
dynamo_llm
::
block_manager
::
ReferenceBlockManager
>
,
// TODO: Metadata should be stored in the block manager?
// TODO: Metadata should be stored in the block manager?
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
...
@@ -62,7 +59,7 @@ impl BlockManager {
...
@@ -62,7 +59,7 @@ impl BlockManager {
dynamo_llm
::
block_manager
::
KvManagerRuntimeConfig
::
builder
()
dynamo_llm
::
block_manager
::
KvManagerRuntimeConfig
::
builder
()
.worker_id
(
worker_id
)
.worker_id
(
worker_id
)
.build
()
.build
()
.
unwrap
()
,
.
map_err
(
to_pyerr
)
?
,
);
);
let
mut
model_config
=
dynamo_llm
::
block_manager
::
KvManagerModelConfig
::
builder
()
let
mut
model_config
=
dynamo_llm
::
block_manager
::
KvManagerModelConfig
::
builder
()
.num_layers
(
num_layer
)
.num_layers
(
num_layer
)
...
@@ -93,14 +90,17 @@ impl BlockManager {
...
@@ -93,14 +90,17 @@ impl BlockManager {
};
};
}
}
model_config
=
model_config
.dtype
(
dtype_
.clone
());
model_config
=
model_config
.dtype
(
dtype_
.clone
());
config
=
config
.model
(
model_config
.build
()
.
unwrap
()
);
config
=
config
.model
(
model_config
.build
()
.
map_err
(
to_pyerr
)
?
);
if
let
Some
(
host_num_blocks
)
=
host_num_blocks
{
if
let
Some
(
host_num_blocks
)
=
host_num_blocks
{
config
=
config
.host_layout
(
config
=
config
.host_layout
(
dynamo_llm
::
block_manager
::
KvManagerLayoutConfig
::
builder
()
dynamo_llm
::
block_manager
::
KvManagerLayoutConfig
::
builder
()
.num_blocks
(
host_num_blocks
)
.num_blocks
(
host_num_blocks
)
.allocator
(
dynamo_llm
::
block_manager
::
storage
::
PinnedAllocator
::
new
()
.unwrap
())
.allocator
(
dynamo_llm
::
block_manager
::
storage
::
PinnedAllocator
::
new
()
.map_err
(
to_pyerr
)
?
,
)
.build
()
.build
()
.
unwrap
()
,
.
map_err
(
to_pyerr
)
?
,
);
);
}
}
if
let
Some
(
device_num_blocks
)
=
device_num_blocks
{
if
let
Some
(
device_num_blocks
)
=
device_num_blocks
{
...
@@ -109,23 +109,22 @@ impl BlockManager {
...
@@ -109,23 +109,22 @@ impl BlockManager {
.num_blocks
(
device_num_blocks
)
.num_blocks
(
device_num_blocks
)
.allocator
(
.allocator
(
dynamo_llm
::
block_manager
::
storage
::
DeviceAllocator
::
new
(
device_id
)
dynamo_llm
::
block_manager
::
storage
::
DeviceAllocator
::
new
(
device_id
)
.
unwrap
()
,
.
map_err
(
to_pyerr
)
?
,
)
)
.build
()
.build
()
.
unwrap
()
,
.
map_err
(
to_pyerr
)
?
,
);
);
}
}
let
config
=
config
.build
()
.unwrap
();
let
config
=
config
.build
()
.map_err
(
to_pyerr
)
?
;
let
tokio_runtime
=
tokio
::
runtime
::
Builder
::
new_multi_thread
()
let
tokio_runtime
=
pyo3_async_runtimes
::
tokio
::
get_runtime
();
.enable_all
()
.build
()
.unwrap
();
let
block_manager
=
tokio_runtime
.block_on
(
async
{
dynamo_llm
::
block_manager
::
ReferenceBlockManager
::
new
(
config
)
.unwrap
()
});
Ok
(
BlockManager
{
Ok
(
BlockManager
{
tokio_runtime
:
tokio_runtime
,
inner
:
Arc
::
from
(
inner
:
Arc
::
from
(
block_manager
),
tokio_runtime
.block_on
(
async
{
dynamo_llm
::
block_manager
::
ReferenceBlockManager
::
new
(
config
)
})
.map_err
(
to_pyerr
)
?
,
),
dtype
:
dtype_
,
dtype
:
dtype_
,
device_id
:
device_id
,
device_id
:
device_id
,
})
})
...
@@ -135,9 +134,11 @@ impl BlockManager {
...
@@ -135,9 +134,11 @@ impl BlockManager {
let
blocks
=
self
let
blocks
=
self
.inner
.inner
.host
()
.host
()
.unwrap
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
"Host allocator not available"
)
})
?
.allocate_blocks_blocking
(
count
)
.allocate_blocks_blocking
(
count
)
.
unwrap
()
;
.
map_err
(
to_pyerr
)
?
;
// Wrap each block in an enum accounting for Pinned & Device block
// Wrap each block in an enum accounting for Pinned & Device block
let
blocks
=
blocks
let
blocks
=
blocks
.into_iter
()
.into_iter
()
...
@@ -150,13 +151,42 @@ impl BlockManager {
...
@@ -150,13 +151,42 @@ impl BlockManager {
))
))
}
}
#[pyo3(signature
=
(count))]
fn
allocate_host_blocks
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
,
count
:
usize
,
)
->
PyResult
<
Bound
<
'py
,
PyAny
>>
{
let
inner
=
self
.inner
.clone
();
let
dtype
=
self
.dtype
.clone
();
let
device_id
=
self
.device_id
;
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
blocks
=
inner
.host
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
"Host allocator not available"
)
})
?
.allocate_blocks
(
count
)
.await
.map_err
(
to_pyerr
)
?
;
// Wrap each block in an enum accounting for Pinned & Device block
let
blocks
=
blocks
.into_iter
()
.map
(|
b
|
block
::
BlockType
::
Pinned
(
b
))
.collect
();
Ok
(
block_list
::
BlockList
::
from_rust
(
blocks
,
dtype
,
device_id
))
})
}
fn
allocate_device_blocks_blocking
(
&
self
,
count
:
usize
)
->
PyResult
<
block_list
::
BlockList
>
{
fn
allocate_device_blocks_blocking
(
&
self
,
count
:
usize
)
->
PyResult
<
block_list
::
BlockList
>
{
let
blocks
=
self
let
blocks
=
self
.inner
.inner
.device
()
.device
()
.unwrap
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
"Device allocator not available"
)
})
?
.allocate_blocks_blocking
(
count
)
.allocate_blocks_blocking
(
count
)
.
unwrap
()
;
.
map_err
(
to_pyerr
)
?
;
// Wrap each block in an enum accounting for Pinned & Device block
// Wrap each block in an enum accounting for Pinned & Device block
let
blocks
=
blocks
let
blocks
=
blocks
.into_iter
()
.into_iter
()
...
@@ -168,4 +198,31 @@ impl BlockManager {
...
@@ -168,4 +198,31 @@ impl BlockManager {
self
.device_id
,
self
.device_id
,
))
))
}
}
#[pyo3(signature
=
(count))]
fn
allocate_device_blocks
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
,
count
:
usize
,
)
->
PyResult
<
Bound
<
'py
,
PyAny
>>
{
let
inner
=
self
.inner
.clone
();
let
dtype
=
self
.dtype
.clone
();
let
device_id
=
self
.device_id
;
pyo3_async_runtimes
::
tokio
::
future_into_py
(
py
,
async
move
{
let
blocks
=
inner
.device
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
"Device allocator not available"
)
})
?
.allocate_blocks
(
count
)
.await
.map_err
(
to_pyerr
)
?
;
// Wrap each block in an enum accounting for Pinned & Device block
let
blocks
=
blocks
.into_iter
()
.map
(|
b
|
block
::
BlockType
::
Device
(
b
))
.collect
();
Ok
(
block_list
::
BlockList
::
from_rust
(
blocks
,
dtype
,
device_id
))
})
}
}
}
lib/bindings/python/rust/llm/block_manager/block.rs
View file @
7677f74f
...
@@ -14,16 +14,14 @@
...
@@ -14,16 +14,14 @@
// limitations under the License.
// limitations under the License.
#![cfg(feature
=
"block-manager"
)]
#![cfg(feature
=
"block-manager"
)]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use
super
::
*
;
use
super
::
*
;
use
dlpark
::
prelude
::{
DataType
,
Device
,
ManagerCtx
,
ShapeAndStrides
,
ToTensor
};
use
pyo3
::{
ffi
::
c_str
,
prelude
::
IntoPy
,
types
::
PyTuple
,
PyObject
,
PyResult
,
Python
};
use
std
::
sync
::{
Arc
,
Mutex
};
use
dynamo_llm
::
block_manager
::
block
::
BlockDataExt
;
use
dynamo_llm
::
block_manager
::
block
::
BlockDataExt
;
use
pyo3
::{
types
::{
PyList
,
PyTuple
},
PyObject
,
PyResult
,
Python
,
};
use
std
::
sync
::{
Arc
,
Mutex
};
pub
enum
BlockType
{
pub
enum
BlockType
{
Pinned
(
Pinned
(
...
@@ -40,111 +38,14 @@ pub enum BlockType {
...
@@ -40,111 +38,14 @@ pub enum BlockType {
),
),
}
}
struct
DlPackTensor
{
block
:
Arc
<
Mutex
<
BlockType
>>
,
// TODO: Metadata should be stored in the block manager?
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
}
impl
ToTensor
for
DlPackTensor
{
fn
data_ptr
(
&
self
)
->
*
mut
std
::
ffi
::
c_void
{
let
mut
mutable_block
=
self
.block
.lock
()
.unwrap
();
let
ptr
=
match
&
mut
*
mutable_block
{
BlockType
::
Pinned
(
block
)
=>
{
let
mut
block_view_mut
=
block
.block_view_mut
()
.expect
(
"Failed to get mutable Pinned block view"
);
unsafe
{
block_view_mut
.as_mut_ptr
()
}
}
BlockType
::
Device
(
block
)
=>
{
let
mut
block_view_mut
=
block
.block_view_mut
()
.expect
(
"Failed to get mutable Device block view"
);
unsafe
{
block_view_mut
.as_mut_ptr
()
}
}
};
ptr
as
*
mut
std
::
ffi
::
c_void
}
fn
byte_offset
(
&
self
)
->
u64
{
0
}
fn
device
(
&
self
)
->
Device
{
let
mutable_block
=
self
.block
.lock
()
.unwrap
();
match
&*
mutable_block
{
BlockType
::
Pinned
(
_
)
=>
{
// TODO: Why torch does not support CPU_PINNED here?
/*Device {
device_type: DeviceType::CudaHost,
device_id: 0,
}*/
Device
::
CPU
}
BlockType
::
Device
(
_
)
=>
Device
::
cuda
(
self
.device_id
),
}
}
fn
dtype
(
&
self
)
->
DataType
{
// Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType
match
self
.dtype
{
dynamo_llm
::
common
::
dtype
::
DType
::
FP8
=>
{
// No direct FP8 equivalent, use U8 as closest alternative
DataType
::
U8
}
dynamo_llm
::
common
::
dtype
::
DType
::
FP16
=>
DataType
::
F16
,
dynamo_llm
::
common
::
dtype
::
DType
::
BF16
=>
DataType
::
BF16
,
dynamo_llm
::
common
::
dtype
::
DType
::
FP32
=>
DataType
::
F32
,
dynamo_llm
::
common
::
dtype
::
DType
::
U8
=>
DataType
::
U8
,
dynamo_llm
::
common
::
dtype
::
DType
::
U16
=>
DataType
::
U16
,
dynamo_llm
::
common
::
dtype
::
DType
::
U32
=>
DataType
::
U32
,
dynamo_llm
::
common
::
dtype
::
DType
::
U64
=>
DataType
::
U64
,
dynamo_llm
::
common
::
dtype
::
DType
::
I8
=>
DataType
::
I8
,
dynamo_llm
::
common
::
dtype
::
DType
::
I16
=>
DataType
::
I16
,
dynamo_llm
::
common
::
dtype
::
DType
::
I32
=>
DataType
::
I32
,
dynamo_llm
::
common
::
dtype
::
DType
::
I64
=>
DataType
::
I64
,
}
}
fn
shape_and_strides
(
&
self
)
->
ShapeAndStrides
{
let
mutable_block
=
self
.block
.lock
()
.unwrap
();
let
(
num_blocks
,
num_layers
,
page_size
,
inner_dim
)
=
match
&*
mutable_block
{
BlockType
::
Pinned
(
block
)
=>
(
block
.num_blocks
(),
block
.num_layers
(),
block
.page_size
(),
block
.inner_dim
(),
),
BlockType
::
Device
(
block
)
=>
(
block
.num_blocks
(),
block
.num_layers
(),
block
.page_size
(),
block
.inner_dim
(),
),
};
let
shape_i64
:
Vec
<
i64
>
=
vec!
[
num_blocks
as
i64
,
num_layers
as
i64
,
page_size
as
i64
,
inner_dim
as
i64
,
];
ShapeAndStrides
::
new_contiguous
(
&
shape_i64
)
}
}
/*impl Drop for DlPackTensor {
fn drop(&mut self) {
println!("Dropping DlPackTensor");
}
}*/
#[pyclass]
#[pyclass]
pub
struct
Block
{
pub
struct
Block
{
inner
:
Arc
<
Mutex
<
BlockType
>>
,
inner
:
Arc
<
Mutex
<
BlockType
>>
,
// TODO: Metadata should be stored in the block manager?
// TODO: Metadata should be stored in the block manager?
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
device_id
:
usize
,
// Python iterator state
py_itr_idx
:
usize
,
}
}
impl
Block
{
impl
Block
{
...
@@ -157,69 +58,161 @@ impl Block {
...
@@ -157,69 +58,161 @@ impl Block {
inner
:
block
,
inner
:
block
,
dtype
:
dtype
,
dtype
:
dtype
,
device_id
:
device_id
,
device_id
:
device_id
,
py_itr_idx
:
0
,
}
}
fn
num_layers
(
&
self
)
->
usize
{
let
mutable_block
=
self
.inner
.lock
()
.unwrap
();
match
&*
mutable_block
{
BlockType
::
Pinned
(
block
)
=>
block
.num_layers
(),
BlockType
::
Device
(
block
)
=>
block
.num_layers
(),
}
}
}
}
}
}
#[pymethods]
#[pymethods]
impl
Block
{
impl
Block
{
#[pyo3(signature
=
())]
fn
to_list
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
)
->
PyResult
<
Bound
<
'py
,
PyList
>>
{
let
layers
:
Vec
<
layer
::
Layer
>
=
(
0
..
self
.num_layers
())
.map
(|
layer_idx
|
{
layer
::
Layer
::
from_rust
(
self
.inner
.clone
(),
layer_idx
,
self
.dtype
.clone
(),
self
.device_id
,
)
})
.collect
();
PyList
::
new
(
py
,
layers
)
}
fn
__
len__
(
&
self
)
->
PyResult
<
usize
>
{
Ok
(
self
.num_layers
())
}
fn
__
getitem__
(
&
self
,
index
:
usize
)
->
PyResult
<
layer
::
Layer
>
{
let
num_layers
=
self
.num_layers
();
if
index
>=
num_layers
{
return
Err
(
pyo3
::
exceptions
::
PyIndexError
::
new_err
(
format!
(
"Index {} out of range for Block with {} layers"
,
index
,
num_layers
)));
}
let
layer
=
layer
::
Layer
::
from_rust
(
self
.inner
.clone
(),
index
,
self
.dtype
.clone
(),
self
.device_id
,
);
Ok
(
layer
)
}
fn
__
iter__
(
mut
slf
:
PyRefMut
<
'_
,
Self
>
)
->
PyResult
<
PyRefMut
<
'_
,
Self
>>
{
// Reset iterator index at the beginning of each iteration
// Use to_list() for iterating concurrently
slf
.py_itr_idx
=
0
;
Ok
(
slf
)
}
fn
__
next__
(
&
mut
self
)
->
PyResult
<
layer
::
Layer
>
{
if
self
.py_itr_idx
>=
self
.num_layers
()
{
return
Err
(
pyo3
::
exceptions
::
PyStopIteration
::
new_err
(
"No more items in Block"
,
));
}
let
layer
=
layer
::
Layer
::
from_rust
(
self
.inner
.clone
(),
self
.py_itr_idx
,
self
.dtype
.clone
(),
self
.device_id
,
);
self
.py_itr_idx
+=
1
;
Ok
(
layer
)
}
#[pyo3(signature
=
(stream=None,
max_version=None,
dl_device=None,
copy=None))]
#[pyo3(signature
=
(stream=None,
max_version=None,
dl_device=None,
copy=None))]
fn
__
dlpack__
(
fn
__
dlpack__
<
'py
>
(
&
self
,
&
self
,
py
:
Python
<
'py
>
,
stream
:
Option
<
PyObject
>
,
stream
:
Option
<
PyObject
>
,
max_version
:
Option
<
PyObject
>
,
max_version
:
Option
<
PyObject
>
,
dl_device
:
Option
<
PyObject
>
,
dl_device
:
Option
<
PyObject
>
,
copy
:
Option
<
bool
>
,
copy
:
Option
<
bool
>
,
)
->
PyResult
<
PyObject
>
{
)
->
PyResult
<
PyObject
>
{
//
Panic
if any arguments are provided
//
Return error
if any arguments are provided
if
stream
.is_some
()
{
if
stream
.is_some
()
{
panic!
(
"stream argument is not supported"
);
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"stream argument is not supported"
,
));
}
}
if
max_version
.is_some
()
{
if
max_version
.is_some
()
{
panic!
(
"max_version argument is not supported"
);
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"max_version argument is not supported"
,
));
}
}
if
dl_device
.is_some
()
{
if
dl_device
.is_some
()
{
panic!
(
"dl_device argument is not supported"
);
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"dl_device argument is not supported"
,
));
}
}
if
copy
.is_some
()
{
if
copy
.is_some
()
{
panic!
(
"copy argument is not supported"
);
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"copy argument is not supported"
,
));
}
}
// Create DLPack PyCapsule
// Extract all necessary data for dlpack
let
manager_ctx
=
ManagerCtx
::
new
(
DlPackTensor
{
let
ptr
:
*
mut
std
::
ffi
::
c_void
;
block
:
self
.inner
.clone
(),
let
num_blocks
:
i64
;
dtype
:
self
.dtype
.clone
(),
let
num_layers
:
i64
;
device_id
:
self
.device_id
,
let
num_outer_dims
:
i64
;
});
let
page_size
:
i64
;
let
py_capsule
=
Python
::
with_gil
(|
py
|
manager_ctx
.into_py
(
py
));
let
inner_dim
:
i64
;
Ok
(
py_capsule
)
{
}
let
mut
mutable_block
=
self
.inner
.lock
()
.unwrap
();
ptr
=
match
&
mut
*
mutable_block
{
fn
__
dlpack_device__
(
&
self
)
->
PyResult
<
Py
<
PyTuple
>>
{
BlockType
::
Pinned
(
block
)
=>
{
let
dlpack_device
=
Python
::
with_gil
(|
py
|
{
let
mut
block_view_mut
=
block
.block_view_mut
()
.map_err
(
to_pyerr
)
?
;
let
device_type_list
=
py
.eval
(
c_str!
(
"[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"
),
None
,
None
)
.unwrap
();
(
unsafe
{
block_view_mut
.as_mut_ptr
()
})
as
*
mut
std
::
ffi
::
c_void
let
device_type_enum
=
py
}
.import
(
"enum"
)
BlockType
::
Device
(
block
)
=>
{
.unwrap
()
let
mut
block_view_mut
=
block
.block_view_mut
()
.map_err
(
to_pyerr
)
?
;
.getattr
(
"Enum"
)
(
unsafe
{
block_view_mut
.as_mut_ptr
()
})
as
*
mut
std
::
ffi
::
c_void
.unwrap
()
}
.call1
((
"DLDeviceType"
,
device_type_list
))
};
.unwrap
();
(
num_blocks
,
num_layers
,
num_outer_dims
,
page_size
,
inner_dim
)
=
match
&*
mutable_block
{
let
block
=
self
.inner
.lock
()
.unwrap
();
BlockType
::
Pinned
(
block
)
=>
(
let
device_type
=
match
&*
block
{
block
.num_blocks
()
as
i64
,
BlockType
::
Pinned
(
_
)
=>
device_type_enum
.getattr
(
"CPU_PINNED"
)
.unwrap
(),
block
.num_layers
()
as
i64
,
BlockType
::
Device
(
_
)
=>
device_type_enum
.getattr
(
"CUDA"
)
.unwrap
(),
block
.num_outer_dims
()
as
i64
,
block
.page_size
()
as
i64
,
block
.inner_dim
()
as
i64
,
),
BlockType
::
Device
(
block
)
=>
(
block
.num_blocks
()
as
i64
,
block
.num_layers
()
as
i64
,
block
.num_outer_dims
()
as
i64
,
block
.page_size
()
as
i64
,
block
.inner_dim
()
as
i64
,
),
};
};
let
device_id
=
self
.device_id
.into_py
(
py
)
.into_bound
(
py
);
}
let
device
=
vec!
[
device_type
,
device_id
];
PyTuple
::
new
(
py
,
device
)
.unwrap
()
.unbind
()
// Create the DLPack tensor
});
dlpack
::
dlpack
(
Ok
(
dlpack_device
)
py
,
self
.inner
.clone
(),
ptr
,
vec!
[
num_blocks
,
num_layers
,
num_outer_dims
,
page_size
,
inner_dim
],
self
.dtype
.clone
(),
self
.device_id
,
)
}
}
}
/*impl Drop for Block {
#[pyo3(signature
=
())]
fn
drop(&mut self)
{
fn
__
dlpack_device__
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
)
->
PyResult
<
Bound
<
'py
,
PyTuple
>>
{
println!("Dropping Block");
dlpack
::
dlpack_device
(
py
,
self
.inner
.clone
(),
self
.device_id
)
}
}
}
*/
}
lib/bindings/python/rust/llm/block_manager/block_list.rs
View file @
7677f74f
...
@@ -14,11 +14,8 @@
...
@@ -14,11 +14,8 @@
// limitations under the License.
// limitations under the License.
#![cfg(feature
=
"block-manager"
)]
#![cfg(feature
=
"block-manager"
)]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use
super
::
*
;
use
super
::
*
;
use
pyo3
::{
types
::
PyList
,
PyResult
,
Python
};
use
pyo3
::{
types
::
PyList
,
PyResult
,
Python
};
use
std
::
sync
::{
Arc
,
Mutex
};
use
std
::
sync
::{
Arc
,
Mutex
};
...
@@ -52,16 +49,14 @@ impl BlockList {
...
@@ -52,16 +49,14 @@ impl BlockList {
#[pymethods]
#[pymethods]
impl
BlockList
{
impl
BlockList
{
fn
to_list
(
&
self
)
->
PyResult
<
Py
<
PyList
>>
{
#[pyo3(signature
=
())]
let
py_list
=
Python
::
with_gil
(|
py
|
{
fn
to_list
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
)
->
PyResult
<
Bound
<
'py
,
PyList
>>
{
let
blocks
:
Vec
<
block
::
Block
>
=
self
let
blocks
:
Vec
<
block
::
Block
>
=
self
.inner
.inner
.iter
()
.iter
()
.map
(|
b
|
block
::
Block
::
from_rust
(
b
.clone
(),
self
.dtype
.clone
(),
self
.device_id
))
.map
(|
b
|
block
::
Block
::
from_rust
(
b
.clone
(),
self
.dtype
.clone
(),
self
.device_id
))
.collect
();
.collect
();
PyList
::
new
(
py
,
blocks
)
.unwrap
()
.unbind
()
PyList
::
new
(
py
,
blocks
)
});
Ok
(
py_list
)
}
}
fn
__
len__
(
&
self
)
->
PyResult
<
usize
>
{
fn
__
len__
(
&
self
)
->
PyResult
<
usize
>
{
...
@@ -84,13 +79,10 @@ impl BlockList {
...
@@ -84,13 +79,10 @@ impl BlockList {
Ok
(
block
)
Ok
(
block
)
}
}
fn
__
iter__
(
slf
:
Py
<
Self
>
)
->
PyResult
<
Py
<
Self
>>
{
fn
__
iter__
(
mut
slf
:
PyRefMut
<
'_
,
Self
>
)
->
PyResult
<
PyRefMut
<
'_
,
Self
>>
{
Python
::
with_gil
(|
py
|
{
// Reset iterator index at the beginning of each iteration
let
mut
slf
=
slf
.borrow_mut
(
py
);
// Use to_list() for iterating concurrently
// Reset iterator index at the beginning of each iteration
slf
.py_itr_idx
=
0
;
// Use to_list() for iterating concurrently
slf
.py_itr_idx
=
0
;
});
Ok
(
slf
)
Ok
(
slf
)
}
}
...
...
lib/bindings/python/rust/llm/block_manager/dlpack.rs
0 → 100644
View file @
7677f74f
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg(feature
=
"block-manager"
)]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use
super
::
*
;
use
dlpark
::
prelude
::{
DataType
,
Device
,
ManagerCtx
,
ShapeAndStrides
,
ToTensor
};
use
pyo3
::{
ffi
::
c_str
,
prelude
::
IntoPy
,
types
::
PyTuple
,
PyObject
,
PyResult
,
Python
};
use
std
::
sync
::{
Arc
,
Mutex
};
struct
DlPackTensor
{
block
:
Arc
<
Mutex
<
block
::
BlockType
>>
,
ptr
:
*
mut
std
::
ffi
::
c_void
,
shape
:
Vec
<
i64
>
,
// TODO: Metadata should be stored in the block?
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
}
impl
ToTensor
for
DlPackTensor
{
fn
data_ptr
(
&
self
)
->
*
mut
std
::
ffi
::
c_void
{
self
.ptr
}
fn
byte_offset
(
&
self
)
->
u64
{
0
}
fn
device
(
&
self
)
->
Device
{
let
mutable_block
=
self
.block
.lock
()
.unwrap
();
match
&*
mutable_block
{
block
::
BlockType
::
Pinned
(
_
)
=>
{
// TODO: Why does torch not support CPU_PINNED here?
/*Device {
device_type: DeviceType::CudaHost,
device_id: 0,
}*/
Device
::
CPU
}
block
::
BlockType
::
Device
(
_
)
=>
Device
::
cuda
(
self
.device_id
),
}
}
fn
dtype
(
&
self
)
->
DataType
{
// Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType
match
self
.dtype
{
dynamo_llm
::
common
::
dtype
::
DType
::
FP8
=>
{
// No direct FP8 equivalent, use U8 as closest alternative
DataType
::
U8
}
dynamo_llm
::
common
::
dtype
::
DType
::
FP16
=>
DataType
::
F16
,
dynamo_llm
::
common
::
dtype
::
DType
::
BF16
=>
DataType
::
BF16
,
dynamo_llm
::
common
::
dtype
::
DType
::
FP32
=>
DataType
::
F32
,
dynamo_llm
::
common
::
dtype
::
DType
::
U8
=>
DataType
::
U8
,
dynamo_llm
::
common
::
dtype
::
DType
::
U16
=>
DataType
::
U16
,
dynamo_llm
::
common
::
dtype
::
DType
::
U32
=>
DataType
::
U32
,
dynamo_llm
::
common
::
dtype
::
DType
::
U64
=>
DataType
::
U64
,
dynamo_llm
::
common
::
dtype
::
DType
::
I8
=>
DataType
::
I8
,
dynamo_llm
::
common
::
dtype
::
DType
::
I16
=>
DataType
::
I16
,
dynamo_llm
::
common
::
dtype
::
DType
::
I32
=>
DataType
::
I32
,
dynamo_llm
::
common
::
dtype
::
DType
::
I64
=>
DataType
::
I64
,
}
}
fn
shape_and_strides
(
&
self
)
->
ShapeAndStrides
{
ShapeAndStrides
::
new_contiguous
(
&
self
.shape
)
}
}
/*impl Drop for DlPackTensor {
fn drop(&mut self) {
println!("Dropping DlPackTensor");
}
}*/
pub
fn
dlpack
<
'py
>
(
py
:
Python
<
'py
>
,
block
:
Arc
<
Mutex
<
block
::
BlockType
>>
,
ptr
:
*
mut
std
::
ffi
::
c_void
,
shape
:
Vec
<
i64
>
,
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
)
->
PyResult
<
PyObject
>
{
let
manager_ctx
=
ManagerCtx
::
new
(
DlPackTensor
{
block
:
block
,
ptr
:
ptr
,
shape
:
shape
,
dtype
:
dtype
,
device_id
:
device_id
,
});
let
py_capsule
=
manager_ctx
.into_py
(
py
);
Ok
(
py_capsule
)
}
pub
fn
dlpack_device
<
'py
>
(
py
:
Python
<
'py
>
,
block
:
Arc
<
Mutex
<
block
::
BlockType
>>
,
device_id
:
usize
,
)
->
PyResult
<
Bound
<
'py
,
PyTuple
>>
{
let
dev_type_list
=
py
.eval
(
c_str!
(
"[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"
),
None
,
None
)
.unwrap
();
let
dev_type_enum
=
py
.import
(
"enum"
)
.unwrap
()
.getattr
(
"Enum"
)
.unwrap
()
.call1
((
"DLDeviceType"
,
dev_type_list
))
.unwrap
();
let
dev_type
=
match
&*
block
.lock
()
.unwrap
()
{
block
::
BlockType
::
Pinned
(
_
)
=>
dev_type_enum
.getattr
(
"CPU_PINNED"
)
.unwrap
(),
block
::
BlockType
::
Device
(
_
)
=>
dev_type_enum
.getattr
(
"CUDA"
)
.unwrap
(),
};
let
dev_id
=
device_id
.into_py
(
py
)
.into_bound
(
py
);
let
dev
=
vec!
[
dev_type
,
dev_id
];
PyTuple
::
new
(
py
,
dev
)
}
lib/bindings/python/rust/llm/block_manager/layer.rs
0 → 100644
View file @
7677f74f
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg(feature
=
"block-manager"
)]
use
super
::
*
;
use
dynamo_llm
::
block_manager
::
block
::
BlockDataExt
;
use
pyo3
::{
types
::
PyTuple
,
PyObject
,
PyResult
,
Python
};
use
std
::
sync
::{
Arc
,
Mutex
};
// Layer struct that represents a layer within a block
#[pyclass]
pub
struct
Layer
{
inner
:
Arc
<
Mutex
<
block
::
BlockType
>>
,
layer_idx
:
usize
,
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
}
impl
Layer
{
pub
fn
from_rust
(
block
:
Arc
<
Mutex
<
block
::
BlockType
>>
,
layer_idx
:
usize
,
dtype
:
dynamo_llm
::
common
::
dtype
::
DType
,
device_id
:
usize
,
)
->
Self
{
Self
{
inner
:
block
,
layer_idx
,
dtype
,
device_id
,
}
}
}
#[pymethods]
impl
Layer
{
#[pyo3(signature
=
(stream=None,
max_version=None,
dl_device=None,
copy=None))]
fn
__
dlpack__
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
,
stream
:
Option
<
PyObject
>
,
max_version
:
Option
<
PyObject
>
,
dl_device
:
Option
<
PyObject
>
,
copy
:
Option
<
bool
>
,
)
->
PyResult
<
PyObject
>
{
// Return error if any arguments are provided
if
stream
.is_some
()
{
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"stream argument is not supported"
,
));
}
if
max_version
.is_some
()
{
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"max_version argument is not supported"
,
));
}
if
dl_device
.is_some
()
{
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"dl_device argument is not supported"
,
));
}
if
copy
.is_some
()
{
return
Err
(
pyo3
::
exceptions
::
PyNotImplementedError
::
new_err
(
"copy argument is not supported"
,
));
}
// Extract all necessary data for dlpack
let
ptr
:
*
mut
std
::
ffi
::
c_void
;
let
num_outer_dims
:
i64
;
let
page_size
:
i64
;
let
inner_dim
:
i64
;
{
let
mut
mutable_block
=
self
.inner
.lock
()
.unwrap
();
ptr
=
match
&
mut
*
mutable_block
{
block
::
BlockType
::
Pinned
(
block
)
=>
{
let
mut
layer_view_mut
=
block
.layer_view_mut
(
self
.layer_idx
,
0
)
.map_err
(
to_pyerr
)
?
;
(
unsafe
{
layer_view_mut
.as_mut_ptr
()
})
as
*
mut
std
::
ffi
::
c_void
}
block
::
BlockType
::
Device
(
block
)
=>
{
let
mut
layer_view_mut
=
block
.layer_view_mut
(
self
.layer_idx
,
0
)
.map_err
(
to_pyerr
)
?
;
(
unsafe
{
layer_view_mut
.as_mut_ptr
()
})
as
*
mut
std
::
ffi
::
c_void
}
};
(
num_outer_dims
,
page_size
,
inner_dim
)
=
match
&*
mutable_block
{
block
::
BlockType
::
Pinned
(
block
)
=>
(
block
.num_outer_dims
()
as
i64
,
block
.page_size
()
as
i64
,
block
.inner_dim
()
as
i64
,
),
block
::
BlockType
::
Device
(
block
)
=>
(
block
.num_outer_dims
()
as
i64
,
block
.page_size
()
as
i64
,
block
.inner_dim
()
as
i64
,
),
};
}
// Create the DLPack tensor
dlpack
::
dlpack
(
py
,
self
.inner
.clone
(),
ptr
,
vec!
[
1
,
1
,
num_outer_dims
,
page_size
,
inner_dim
],
self
.dtype
.clone
(),
self
.device_id
,
)
}
#[pyo3(signature
=
())]
fn
__
dlpack_device__
<
'py
>
(
&
self
,
py
:
Python
<
'py
>
)
->
PyResult
<
Bound
<
'py
,
PyTuple
>>
{
dlpack
::
dlpack_device
(
py
,
self
.inner
.clone
(),
self
.device_id
)
}
}
lib/bindings/python/src/dynamo/_core.pyi
View file @
7677f74f
...
@@ -710,6 +710,25 @@ class NatsQueue:
...
@@ -710,6 +710,25 @@ class NatsQueue:
"""
"""
...
...
class Layer:
"""
A KV cache block layer
"""
...
def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any:
"""
Get a dlpack capsule of the layer
"""
...
def __dlpack_device__(self) -> Any:
"""
Get the dlpack device of the layer
"""
...
class Block:
class Block:
"""
"""
A KV cache block
A KV cache block
...
@@ -717,9 +736,40 @@ class Block:
...
@@ -717,9 +736,40 @@ class Block:
...
...
def __len__(self) -> int:
"""
Get the number of layers in the block
"""
...
def __getitem__(self, index: int) -> Layer:
"""
Get a layer by index
"""
...
def __iter__(self) -> 'Block':
"""
Get an iterator over the layers
"""
...
def __next__(self) -> Block:
"""
Get the next layer in the iterator
"""
...
def to_list(self) -> List[Layer]:
"""
Get a list of layers
"""
...
def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any:
def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any:
"""
"""
Get a dlpack capsule from the block
Get a dlpack capsule of the block
Raises an exception if the block is not contiguous
"""
"""
...
...
...
@@ -822,6 +872,22 @@ class BlockManager:
...
@@ -822,6 +872,22 @@ class BlockManager:
"""
"""
...
...
async def allocate_host_blocks(self, count: int) -> BlockList:
"""
Allocate a list of host blocks
Parameters:
-----------
count: int
Number of blocks to allocate
Returns:
--------
BlockList
List of allocated blocks
"""
...
def allocate_device_blocks_blocking(self, count: int) -> BlockList:
def allocate_device_blocks_blocking(self, count: int) -> BlockList:
"""
"""
Allocate a list of device blocks (blocking call)
Allocate a list of device blocks (blocking call)
...
@@ -837,3 +903,19 @@ class BlockManager:
...
@@ -837,3 +903,19 @@ class BlockManager:
List of allocated blocks
List of allocated blocks
"""
"""
...
...
async def allocate_device_blocks(self, count: int) -> BlockList:
    """
    Allocate a list of device blocks

    This is a coroutine and must be awaited; `allocate_device_blocks_blocking`
    is the variant that is called without `await`.

    Parameters:
    -----------
    count: int
        Number of blocks to allocate

    Returns:
    --------
    BlockList
        List of allocated blocks
    """
    ...
lib/bindings/python/tests/test_block_manager.py
View file @
7677f74f
...
@@ -35,9 +35,7 @@ DEVICE_NUM_BLOCKS = 16
...
@@ -35,9 +35,7 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID
=
0
DEVICE_ID
=
0
@
pytest
.
fixture
def
new_block_manager
():
def
block_manager
():
"""Pytest fixture for creating a BlockManager instance."""
return
BlockManager
(
return
BlockManager
(
WORKER_ID
,
WORKER_ID
,
NUM_LAYER
,
NUM_LAYER
,
...
@@ -51,6 +49,11 @@ def block_manager():
...
@@ -51,6 +49,11 @@ def block_manager():
)
)
@
pytest
.
fixture
def
block_manager
():
return
new_block_manager
()
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_manager_initialization
():
async
def
test_block_manager_initialization
():
# Python should drop the BlockManager instance as soon as it goes out of scope, but
# Python should drop the BlockManager instance as soon as it goes out of scope, but
...
@@ -106,22 +109,22 @@ async def test_block_manager_initialization():
...
@@ -106,22 +109,22 @@ async def test_block_manager_initialization():
async
def
test_cpu_block_access
(
block_manager
:
BlockManager
):
async
def
test_cpu_block_access
(
block_manager
:
BlockManager
):
block_count
=
2
block_count
=
2
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
block_count
)
py_
blocks
=
block_list
.
to_list
()
blocks
=
block_list
.
to_list
()
assert
len
(
py_
blocks
)
==
block_count
assert
len
(
blocks
)
==
block_count
tensors
=
[
torch
.
from_dlpack
(
b
)
for
b
in
py_
blocks
]
tensors
=
[
torch
.
from_dlpack
(
b
)
for
b
in
blocks
]
for
tensor
in
tensors
:
for
tensor
in
tensors
:
assert
tensor
.
get_device
()
==
-
1
# CPU
assert
tensor
.
get_device
()
==
-
1
# CPU
assert
tensor
.
shape
==
(
1
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
shape
==
(
1
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
dtype
==
TORCH_DTYPE
assert
tensor
.
dtype
==
TORCH_DTYPE
# print(tensors)
# print(tensors)
for
tensor
in
tensors
:
for
tensor
in
tensors
:
tensor
[
0
][
0
][
0
][
0
]
=
1.0
tensor
[
0
][
0
][
0
][
0
]
[
0
]
=
1.0
tensor
[
0
][
NUM_LAYER
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
tensor
[
0
][
NUM_LAYER
-
1
][
OUTER_DIM
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
# print(tensors)
# print(tensors)
py_
blocks_
=
block_list
.
to_list
()
blocks_
=
block_list
.
to_list
()
assert
py_
blocks
is
not
py_
blocks_
assert
blocks
is
not
blocks_
assert
len
(
py_
blocks
)
==
len
(
py_
blocks_
)
assert
len
(
blocks
)
==
len
(
blocks_
)
tensors_
=
[
torch
.
from_dlpack
(
b
)
for
b
in
py_
blocks_
]
tensors_
=
[
torch
.
from_dlpack
(
b
)
for
b
in
blocks_
]
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
assert
tensor
is
not
tensor_
assert
tensor
is
not
tensor_
assert
tensor
.
shape
==
tensor_
.
shape
assert
tensor
.
shape
==
tensor_
.
shape
...
@@ -133,22 +136,22 @@ async def test_cpu_block_access(block_manager: BlockManager):
...
@@ -133,22 +136,22 @@ async def test_cpu_block_access(block_manager: BlockManager):
async
def
test_gpu_block_access
(
block_manager
:
BlockManager
):
async
def
test_gpu_block_access
(
block_manager
:
BlockManager
):
block_count
=
6
block_count
=
6
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
block_count
)
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
block_count
)
py_
blocks
=
block_list
.
to_list
()
blocks
=
block_list
.
to_list
()
assert
len
(
py_
blocks
)
==
block_count
assert
len
(
blocks
)
==
block_count
tensors
=
[
torch
.
from_dlpack
(
b
)
for
b
in
py_
blocks
]
tensors
=
[
torch
.
from_dlpack
(
b
)
for
b
in
blocks
]
for
tensor
in
tensors
:
for
tensor
in
tensors
:
assert
tensor
.
get_device
()
==
DEVICE_ID
# GPU
assert
tensor
.
get_device
()
==
DEVICE_ID
# GPU
assert
tensor
.
shape
==
(
1
,
NUM_LAYER
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
shape
==
(
1
,
NUM_LAYER
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
dtype
==
TORCH_DTYPE
assert
tensor
.
dtype
==
TORCH_DTYPE
# print(tensors)
# print(tensors)
for
tensor
in
tensors
:
for
tensor
in
tensors
:
tensor
[
0
][
0
][
0
][
0
]
=
1.0
tensor
[
0
][
0
][
0
][
0
]
[
0
]
=
1.0
tensor
[
0
][
NUM_LAYER
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
tensor
[
0
][
NUM_LAYER
-
1
][
OUTER_DIM
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
# print(tensors)
# print(tensors)
py_
blocks_
=
block_list
.
to_list
()
blocks_
=
block_list
.
to_list
()
assert
py_
blocks
is
not
py_
blocks_
assert
blocks
is
not
blocks_
assert
len
(
py_
blocks
)
==
len
(
py_
blocks_
)
assert
len
(
blocks
)
==
len
(
blocks_
)
tensors_
=
[
torch
.
from_dlpack
(
b
)
for
b
in
py_
blocks_
]
tensors_
=
[
torch
.
from_dlpack
(
b
)
for
b
in
blocks_
]
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
assert
tensor
is
not
tensor_
assert
tensor
is
not
tensor_
assert
tensor
.
shape
==
tensor_
.
shape
assert
tensor
.
shape
==
tensor_
.
shape
...
@@ -159,27 +162,27 @@ async def test_gpu_block_access(block_manager: BlockManager):
...
@@ -159,27 +162,27 @@ async def test_gpu_block_access(block_manager: BlockManager):
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_list_iteration
(
block_manager
:
BlockManager
):
async
def
test_block_list_iteration
(
block_manager
:
BlockManager
):
block_count
=
4
block_count
=
4
block_list
=
block_manager
.
allocate_host_blocks
_blocking
(
block_count
)
block_list
=
await
block_manager
.
allocate_host_blocks
(
block_count
)
# Test __len__()
# Test __len__()
assert
len
(
block_list
)
==
block_count
assert
len
(
block_list
)
==
block_count
# Test __getitem__()
# Test __getitem__()
for
i
in
range
(
block_count
):
for
i
in
range
(
block_count
):
block
=
block_list
[
i
]
block
=
block_list
[
i
]
tensor
=
torch
.
from_dlpack
(
block
)
tensor
=
torch
.
from_dlpack
(
block
)
tensor
[
0
][
0
][
0
][
0
]
=
1.0
+
i
tensor
[
0
][
0
][
0
][
0
]
[
0
]
=
1.0
+
i
# Test __iter__() and __next__()
# Test __iter__() and __next__()
idx
=
1.0
idx
=
1.0
for
block
in
block_list
:
for
block
in
block_list
:
tensor
=
torch
.
from_dlpack
(
block
)
tensor
=
torch
.
from_dlpack
(
block
)
assert
tensor
[
0
][
0
][
0
][
0
]
==
idx
assert
tensor
[
0
][
0
][
0
][
0
]
[
0
]
==
idx
tensor
[
0
][
0
][
0
][
0
]
+=
0.5
tensor
[
0
][
0
][
0
][
0
]
[
0
]
+=
0.5
idx
+=
1.0
idx
+=
1.0
assert
idx
==
1.0
+
block_count
assert
idx
==
1.0
+
block_count
# Test __iter__() should reset current index
# Test __iter__() should reset current index
idx
=
1.0
idx
=
1.0
for
block
in
block_list
:
for
block
in
block_list
:
tensor
=
torch
.
from_dlpack
(
block
)
tensor
=
torch
.
from_dlpack
(
block
)
assert
tensor
[
0
][
0
][
0
][
0
]
==
idx
+
0.5
assert
tensor
[
0
][
0
][
0
][
0
]
[
0
]
==
idx
+
0.5
idx
+=
1.0
idx
+=
1.0
assert
idx
==
1.0
+
block_count
assert
idx
==
1.0
+
block_count
...
@@ -187,27 +190,37 @@ async def test_block_list_iteration(block_manager: BlockManager):
...
@@ -187,27 +190,37 @@ async def test_block_list_iteration(block_manager: BlockManager):
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_copy_g1_g2
(
block_manager
:
BlockManager
):
async
def
test_block_copy_g1_g2
(
block_manager
:
BlockManager
):
# Allocate device (G1) and host (G2) block
# Allocate device (G1) and host (G2) block
host_block_list
=
block_manager
.
allocate_host_blocks
_blocking
(
1
)
host_block_list
=
await
block_manager
.
allocate_host_blocks
(
1
)
device_block_list
=
block_manager
.
allocate_device_blocks
_blocking
(
1
)
device_block_list
=
await
block_manager
.
allocate_device_blocks
(
1
)
# Populate host block with unique values
# Populate host block with unique values
host_tensor
=
torch
.
from_dlpack
(
host_block_list
[
0
])
host_tensor
=
torch
.
from_dlpack
(
host_block_list
[
0
])
for
i
in
range
(
NUM_LAYER
):
for
i
in
range
(
NUM_LAYER
):
for
j
in
range
(
PAGE_SIZE
):
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
INNER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
host_tensor
[
0
][
i
][
j
][
k
]
=
i
*
PAGE_SIZE
*
INNER_DIM
+
j
*
INNER_DIM
+
k
for
w
in
range
(
INNER_DIM
):
host_tensor
[
0
][
i
][
j
][
k
][
w
]
=
(
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
# Copy host block to device block after permuting
# Copy host block to device block after permuting
permute_dims
=
(
0
,
2
,
3
,
1
)
permute_dims
=
(
0
,
2
,
4
,
3
,
1
)
device_tensor_
=
torch
.
from_dlpack
(
device_block_list
[
0
]).
permute
(
*
permute_dims
)
device_tensor_
=
torch
.
from_dlpack
(
device_block_list
[
0
]).
permute
(
*
permute_dims
)
device_tensor_
.
copy_
(
host_tensor
.
permute
(
*
permute_dims
))
device_tensor_
.
copy_
(
host_tensor
.
permute
(
*
permute_dims
))
# Assert device block is contiguous and updated in block manager
# Assert device block is contiguous and updated in block manager
device_tensor
=
torch
.
from_dlpack
(
device_block_list
[
0
])
device_tensor
=
torch
.
from_dlpack
(
device_block_list
[
0
])
for
i
in
range
(
NUM_LAYER
):
for
i
in
range
(
NUM_LAYER
):
for
j
in
range
(
PAGE_SIZE
):
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
INNER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
assert
(
for
w
in
range
(
INNER_DIM
):
device_tensor
[
0
][
i
][
j
][
k
]
assert
(
==
i
*
PAGE_SIZE
*
INNER_DIM
+
j
*
INNER_DIM
+
k
device_tensor
[
0
][
i
][
j
][
k
][
w
]
)
==
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
# Set host block to zero and assert updated in block manager
# Set host block to zero and assert updated in block manager
host_tensor_
=
torch
.
from_dlpack
(
host_block_list
[
0
]).
permute
(
*
permute_dims
)
host_tensor_
=
torch
.
from_dlpack
(
host_block_list
[
0
]).
permute
(
*
permute_dims
)
host_tensor_
.
zero_
()
host_tensor_
.
zero_
()
...
@@ -216,22 +229,166 @@ async def test_block_copy_g1_g2(block_manager: BlockManager):
...
@@ -216,22 +229,166 @@ async def test_block_copy_g1_g2(block_manager: BlockManager):
host_tensor_
.
copy_
(
device_tensor_
)
host_tensor_
.
copy_
(
device_tensor_
)
# Assert host block is updated in block manager
# Assert host block is updated in block manager
for
i
in
range
(
NUM_LAYER
):
for
i
in
range
(
NUM_LAYER
):
for
j
in
range
(
PAGE_SIZE
):
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
INNER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
assert
(
for
w
in
range
(
INNER_DIM
):
host_tensor
[
0
][
i
][
j
][
k
]
assert
(
==
i
*
PAGE_SIZE
*
INNER_DIM
+
j
*
INNER_DIM
+
k
host_tensor
[
0
][
i
][
j
][
k
][
w
]
)
==
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_cpu_layer_access
(
block_manager
:
BlockManager
):
block_list
=
block_manager
.
allocate_host_blocks_blocking
(
1
)
block
=
block_list
[
0
]
layers
=
block
.
to_list
()
assert
len
(
layers
)
==
NUM_LAYER
tensors
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
layers
]
for
tensor
in
tensors
:
assert
tensor
.
get_device
()
==
-
1
# CPU
assert
tensor
.
shape
==
(
1
,
1
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
dtype
==
TORCH_DTYPE
# print(tensors)
for
tensor
in
tensors
:
tensor
[
0
][
0
][
0
][
0
][
0
]
=
1.0
tensor
[
0
][
0
][
OUTER_DIM
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
# print(tensors)
layers_
=
block
.
to_list
()
assert
layers
is
not
layers_
assert
len
(
layers
)
==
len
(
layers_
)
tensors_
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
layers_
]
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
assert
tensor
is
not
tensor_
assert
tensor
.
shape
==
tensor_
.
shape
assert
tensor
.
dtype
==
tensor_
.
dtype
assert
torch
.
allclose
(
tensor
,
tensor_
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_gpu_layer_access
(
block_manager
:
BlockManager
):
block_list
=
block_manager
.
allocate_device_blocks_blocking
(
1
)
block
=
block_list
[
0
]
layers
=
block
.
to_list
()
assert
len
(
layers
)
==
NUM_LAYER
tensors
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
layers
]
for
tensor
in
tensors
:
assert
tensor
.
get_device
()
==
DEVICE_ID
# GPU
assert
tensor
.
shape
==
(
1
,
1
,
OUTER_DIM
,
PAGE_SIZE
,
INNER_DIM
)
assert
tensor
.
dtype
==
TORCH_DTYPE
# print(tensors)
for
tensor
in
tensors
:
tensor
[
0
][
0
][
0
][
0
][
0
]
=
1.0
tensor
[
0
][
0
][
OUTER_DIM
-
1
][
PAGE_SIZE
-
1
][
INNER_DIM
-
1
]
=
1.0
# print(tensors)
layers_
=
block
.
to_list
()
assert
layers
is
not
layers_
assert
len
(
layers
)
==
len
(
layers_
)
tensors_
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
layers_
]
for
tensor
,
tensor_
in
zip
(
tensors
,
tensors_
):
assert
tensor
is
not
tensor_
assert
tensor
.
shape
==
tensor_
.
shape
assert
tensor
.
dtype
==
tensor_
.
dtype
assert
torch
.
allclose
(
tensor
,
tensor_
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_iteration
(
block_manager
:
BlockManager
):
block
=
(
await
block_manager
.
allocate_host_blocks
(
1
))[
0
]
# Test __len__()
assert
len
(
block
)
==
NUM_LAYER
# Test __getitem__()
for
i
in
range
(
NUM_LAYER
):
layer
=
block
[
i
]
tensor
=
torch
.
from_dlpack
(
layer
)
tensor
[
0
][
0
][
0
][
0
][
0
]
=
1.0
+
i
# Test __iter__() and __next__()
idx
=
1.0
for
layer
in
block
:
tensor
=
torch
.
from_dlpack
(
layer
)
assert
tensor
[
0
][
0
][
0
][
0
][
0
]
==
idx
tensor
[
0
][
0
][
0
][
0
][
0
]
+=
0.5
idx
+=
1.0
assert
idx
==
1.0
+
NUM_LAYER
# Test __iter__() should reset current index
idx
=
1.0
for
layer
in
block
:
tensor
=
torch
.
from_dlpack
(
layer
)
assert
tensor
[
0
][
0
][
0
][
0
][
0
]
==
idx
+
0.5
idx
+=
1.0
assert
idx
==
1.0
+
NUM_LAYER
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"CUDA unavailable"
)
async
def
test_block_layer_copy_g1_g2
(
block_manager
:
BlockManager
):
# Allocate device (G1) and host (G2) block
host_block
=
(
await
block_manager
.
allocate_host_blocks
(
1
))[
0
]
device_block
=
(
await
block_manager
.
allocate_device_blocks
(
1
))[
0
]
# Populate host block at layer level with unique values
host_layer_tensors
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
host_block
]
for
i
in
range
(
NUM_LAYER
):
host_layer_tensor
=
host_layer_tensors
[
i
]
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
for
w
in
range
(
INNER_DIM
):
host_layer_tensor
[
0
][
0
][
j
][
k
][
w
]
=
(
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
# Copy host block to device block after permuting
permute_dims
=
(
0
,
2
,
4
,
3
,
1
)
host_block_tensor_
=
torch
.
from_dlpack
(
host_block
).
permute
(
*
permute_dims
)
device_block_tensor_
=
torch
.
from_dlpack
(
device_block
).
permute
(
*
permute_dims
)
device_block_tensor_
.
copy_
(
host_block_tensor_
)
# Assert device block is contiguous and updated in block manager at layer level
device_layer_tensors
=
[
torch
.
from_dlpack
(
bl
)
for
bl
in
device_block
]
for
i
in
range
(
NUM_LAYER
):
device_layer_tensor
=
device_layer_tensors
[
i
]
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
for
w
in
range
(
INNER_DIM
):
assert
(
device_layer_tensor
[
0
][
0
][
j
][
k
][
w
]
==
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
# Set host block to zero and assert updated in block manager
host_block_tensor
=
torch
.
from_dlpack
(
host_block
)
host_block_tensor
.
zero_
()
assert
torch
.
all
(
host_block_tensor_
==
0
)
# Copy device block back to host block
host_block_tensor_
.
copy_
(
device_block_tensor_
)
# Assert host block is updated in block manager
for
i
in
range
(
NUM_LAYER
):
for
j
in
range
(
OUTER_DIM
):
for
k
in
range
(
PAGE_SIZE
):
for
w
in
range
(
INNER_DIM
):
assert
(
host_block_tensor
[
0
][
i
][
j
][
k
][
w
]
==
i
*
OUTER_DIM
*
PAGE_SIZE
*
INNER_DIM
+
j
*
PAGE_SIZE
*
INNER_DIM
+
k
*
INNER_DIM
+
w
)
async
def
main
():
async
def
main
():
await
test_block_manager_initialization
()
await
test_block_manager_initialization
()
await
test_cpu_block_access
(
new_block_manager
())
# todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v)
await
test_gpu_block_access
(
new_block_manager
())
# await test_cpu_block_access()
await
test_block_list_iteration
(
new_block_manager
())
# await test_gpu_block_access()
await
test_block_copy_g1_g2
(
new_block_manager
())
# await test_block_list_iteration()
await
test_cpu_layer_access
(
new_block_manager
())
# await test_block_copy_g1_g2()
await
test_gpu_layer_access
(
new_block_manager
())
await
test_block_iteration
(
new_block_manager
())
await
test_block_layer_copy_g1_g2
(
new_block_manager
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment