"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "5602dd2f6380e9e42fb120c3e8f14a5b6fd026fb"
Unverified Commit 7677f74f authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: KVBM async Python bindings and Layer class (#1141)

parent a0512bd1
...@@ -14,18 +14,18 @@ ...@@ -14,18 +14,18 @@
// limitations under the License. // limitations under the License.
#![cfg(feature = "block-manager")] #![cfg(feature = "block-manager")]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use super::*; use super::*;
use pyo3::PyResult; use pyo3::PyResult;
use tokio;
mod block; mod block;
mod block_list; mod block_list;
mod dlpack;
mod layer;
/// Add bingings from this crate to the provided module /// Add bingings from this crate to the provided module
pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<layer::Layer>()?;
m.add_class::<block::Block>()?; m.add_class::<block::Block>()?;
m.add_class::<block_list::BlockList>()?; m.add_class::<block_list::BlockList>()?;
m.add_class::<BlockManager>()?; m.add_class::<BlockManager>()?;
...@@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -34,9 +34,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
#[pyclass] #[pyclass]
pub struct BlockManager { pub struct BlockManager {
// TODO: Can this be implicitly created and referenced?
tokio_runtime: tokio::runtime::Runtime,
// Block manager
inner: Arc<dynamo_llm::block_manager::ReferenceBlockManager>, inner: Arc<dynamo_llm::block_manager::ReferenceBlockManager>,
// TODO: Metadata should be stored in the block manager? // TODO: Metadata should be stored in the block manager?
dtype: dynamo_llm::common::dtype::DType, dtype: dynamo_llm::common::dtype::DType,
...@@ -62,7 +59,7 @@ impl BlockManager { ...@@ -62,7 +59,7 @@ impl BlockManager {
dynamo_llm::block_manager::KvManagerRuntimeConfig::builder() dynamo_llm::block_manager::KvManagerRuntimeConfig::builder()
.worker_id(worker_id) .worker_id(worker_id)
.build() .build()
.unwrap(), .map_err(to_pyerr)?,
); );
let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder() let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
.num_layers(num_layer) .num_layers(num_layer)
...@@ -93,14 +90,17 @@ impl BlockManager { ...@@ -93,14 +90,17 @@ impl BlockManager {
}; };
} }
model_config = model_config.dtype(dtype_.clone()); model_config = model_config.dtype(dtype_.clone());
config = config.model(model_config.build().unwrap()); config = config.model(model_config.build().map_err(to_pyerr)?);
if let Some(host_num_blocks) = host_num_blocks { if let Some(host_num_blocks) = host_num_blocks {
config = config.host_layout( config = config.host_layout(
dynamo_llm::block_manager::KvManagerLayoutConfig::builder() dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
.num_blocks(host_num_blocks) .num_blocks(host_num_blocks)
.allocator(dynamo_llm::block_manager::storage::PinnedAllocator::new().unwrap()) .allocator(
dynamo_llm::block_manager::storage::PinnedAllocator::new()
.map_err(to_pyerr)?,
)
.build() .build()
.unwrap(), .map_err(to_pyerr)?,
); );
} }
if let Some(device_num_blocks) = device_num_blocks { if let Some(device_num_blocks) = device_num_blocks {
...@@ -109,23 +109,22 @@ impl BlockManager { ...@@ -109,23 +109,22 @@ impl BlockManager {
.num_blocks(device_num_blocks) .num_blocks(device_num_blocks)
.allocator( .allocator(
dynamo_llm::block_manager::storage::DeviceAllocator::new(device_id) dynamo_llm::block_manager::storage::DeviceAllocator::new(device_id)
.unwrap(), .map_err(to_pyerr)?,
) )
.build() .build()
.unwrap(), .map_err(to_pyerr)?,
); );
} }
let config = config.build().unwrap(); let config = config.build().map_err(to_pyerr)?;
let tokio_runtime = tokio::runtime::Builder::new_multi_thread() let tokio_runtime = pyo3_async_runtimes::tokio::get_runtime();
.enable_all()
.build()
.unwrap();
let block_manager = tokio_runtime.block_on(async {
dynamo_llm::block_manager::ReferenceBlockManager::new(config).unwrap()
});
Ok(BlockManager { Ok(BlockManager {
tokio_runtime: tokio_runtime, inner: Arc::from(
inner: Arc::from(block_manager), tokio_runtime
.block_on(async {
dynamo_llm::block_manager::ReferenceBlockManager::new(config)
})
.map_err(to_pyerr)?,
),
dtype: dtype_, dtype: dtype_,
device_id: device_id, device_id: device_id,
}) })
...@@ -135,9 +134,11 @@ impl BlockManager { ...@@ -135,9 +134,11 @@ impl BlockManager {
let blocks = self let blocks = self
.inner .inner
.host() .host()
.unwrap() .ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
})?
.allocate_blocks_blocking(count) .allocate_blocks_blocking(count)
.unwrap(); .map_err(to_pyerr)?;
// Wrap each block in an enum accounting for Pinned & Device block // Wrap each block in an enum accounting for Pinned & Device block
let blocks = blocks let blocks = blocks
.into_iter() .into_iter()
...@@ -150,13 +151,42 @@ impl BlockManager { ...@@ -150,13 +151,42 @@ impl BlockManager {
)) ))
} }
// Asynchronous counterpart of `allocate_host_blocks_blocking`.
//
// Returns a Python awaitable (built by `pyo3_async_runtimes::tokio::future_into_py`)
// that resolves to a `BlockList` of `count` pinned blocks. When awaited it fails
// with `RuntimeError("Host allocator not available")` if `host()` yields nothing;
// allocation errors are converted through `to_pyerr`.
#[pyo3(signature = (count))]
fn allocate_host_blocks<'py>(
    &self,
    py: Python<'py>,
    count: usize,
) -> PyResult<Bound<'py, PyAny>> {
    // The future is `async move`, so copy everything it needs out of `self` first.
    let manager = self.inner.clone();
    let dtype = self.dtype.clone();
    let device_id = self.device_id;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let host_pool = manager.host().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
        })?;
        let allocated = host_pool
            .allocate_blocks(count)
            .await
            .map_err(to_pyerr)?;
        // Tag every block with the `Pinned` variant of the Pinned/Device enum.
        let wrapped = allocated
            .into_iter()
            .map(block::BlockType::Pinned)
            .collect();
        Ok(block_list::BlockList::from_rust(wrapped, dtype, device_id))
    })
}
fn allocate_device_blocks_blocking(&self, count: usize) -> PyResult<block_list::BlockList> { fn allocate_device_blocks_blocking(&self, count: usize) -> PyResult<block_list::BlockList> {
let blocks = self let blocks = self
.inner .inner
.device() .device()
.unwrap() .ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
})?
.allocate_blocks_blocking(count) .allocate_blocks_blocking(count)
.unwrap(); .map_err(to_pyerr)?;
// Wrap each block in an enum accounting for Pinned & Device block // Wrap each block in an enum accounting for Pinned & Device block
let blocks = blocks let blocks = blocks
.into_iter() .into_iter()
...@@ -168,4 +198,31 @@ impl BlockManager { ...@@ -168,4 +198,31 @@ impl BlockManager {
self.device_id, self.device_id,
)) ))
} }
// Asynchronous counterpart of `allocate_device_blocks_blocking`.
//
// Returns a Python awaitable (built by `pyo3_async_runtimes::tokio::future_into_py`)
// that resolves to a `BlockList` of `count` device blocks. When awaited it fails
// with `RuntimeError("Device allocator not available")` if `device()` yields
// nothing; allocation errors are converted through `to_pyerr`.
#[pyo3(signature = (count))]
fn allocate_device_blocks<'py>(
    &self,
    py: Python<'py>,
    count: usize,
) -> PyResult<Bound<'py, PyAny>> {
    // The future is `async move`, so copy everything it needs out of `self` first.
    let manager = self.inner.clone();
    let dtype = self.dtype.clone();
    let device_id = self.device_id;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let device_pool = manager.device().ok_or_else(|| {
            pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
        })?;
        let allocated = device_pool
            .allocate_blocks(count)
            .await
            .map_err(to_pyerr)?;
        // Tag every block with the `Device` variant of the Pinned/Device enum.
        let wrapped = allocated
            .into_iter()
            .map(block::BlockType::Device)
            .collect();
        Ok(block_list::BlockList::from_rust(wrapped, dtype, device_id))
    })
}
} }
...@@ -14,16 +14,14 @@ ...@@ -14,16 +14,14 @@
// limitations under the License. // limitations under the License.
#![cfg(feature = "block-manager")] #![cfg(feature = "block-manager")]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use super::*; use super::*;
use dlpark::prelude::{DataType, Device, ManagerCtx, ShapeAndStrides, ToTensor};
use pyo3::{ffi::c_str, prelude::IntoPy, types::PyTuple, PyObject, PyResult, Python};
use std::sync::{Arc, Mutex};
use dynamo_llm::block_manager::block::BlockDataExt; use dynamo_llm::block_manager::block::BlockDataExt;
use pyo3::{
types::{PyList, PyTuple},
PyObject, PyResult, Python,
};
use std::sync::{Arc, Mutex};
pub enum BlockType { pub enum BlockType {
Pinned( Pinned(
...@@ -40,111 +38,14 @@ pub enum BlockType { ...@@ -40,111 +38,14 @@ pub enum BlockType {
), ),
} }
struct DlPackTensor {
block: Arc<Mutex<BlockType>>,
// TODO: Metadata should be stored in the block manager?
dtype: dynamo_llm::common::dtype::DType,
device_id: usize,
}
impl ToTensor for DlPackTensor {
fn data_ptr(&self) -> *mut std::ffi::c_void {
let mut mutable_block = self.block.lock().unwrap();
let ptr = match &mut *mutable_block {
BlockType::Pinned(block) => {
let mut block_view_mut = block
.block_view_mut()
.expect("Failed to get mutable Pinned block view");
unsafe { block_view_mut.as_mut_ptr() }
}
BlockType::Device(block) => {
let mut block_view_mut = block
.block_view_mut()
.expect("Failed to get mutable Device block view");
unsafe { block_view_mut.as_mut_ptr() }
}
};
ptr as *mut std::ffi::c_void
}
fn byte_offset(&self) -> u64 {
0
}
fn device(&self) -> Device {
let mutable_block = self.block.lock().unwrap();
match &*mutable_block {
BlockType::Pinned(_) => {
// TODO: Why torch does not support CPU_PINNED here?
/*Device {
device_type: DeviceType::CudaHost,
device_id: 0,
}*/
Device::CPU
}
BlockType::Device(_) => Device::cuda(self.device_id),
}
}
fn dtype(&self) -> DataType {
// Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType
match self.dtype {
dynamo_llm::common::dtype::DType::FP8 => {
// No direct FP8 equivalent, use U8 as closest alternative
DataType::U8
}
dynamo_llm::common::dtype::DType::FP16 => DataType::F16,
dynamo_llm::common::dtype::DType::BF16 => DataType::BF16,
dynamo_llm::common::dtype::DType::FP32 => DataType::F32,
dynamo_llm::common::dtype::DType::U8 => DataType::U8,
dynamo_llm::common::dtype::DType::U16 => DataType::U16,
dynamo_llm::common::dtype::DType::U32 => DataType::U32,
dynamo_llm::common::dtype::DType::U64 => DataType::U64,
dynamo_llm::common::dtype::DType::I8 => DataType::I8,
dynamo_llm::common::dtype::DType::I16 => DataType::I16,
dynamo_llm::common::dtype::DType::I32 => DataType::I32,
dynamo_llm::common::dtype::DType::I64 => DataType::I64,
}
}
fn shape_and_strides(&self) -> ShapeAndStrides {
let mutable_block = self.block.lock().unwrap();
let (num_blocks, num_layers, page_size, inner_dim) = match &*mutable_block {
BlockType::Pinned(block) => (
block.num_blocks(),
block.num_layers(),
block.page_size(),
block.inner_dim(),
),
BlockType::Device(block) => (
block.num_blocks(),
block.num_layers(),
block.page_size(),
block.inner_dim(),
),
};
let shape_i64: Vec<i64> = vec![
num_blocks as i64,
num_layers as i64,
page_size as i64,
inner_dim as i64,
];
ShapeAndStrides::new_contiguous(&shape_i64)
}
}
/*impl Drop for DlPackTensor {
fn drop(&mut self) {
println!("Dropping DlPackTensor");
}
}*/
#[pyclass] #[pyclass]
pub struct Block { pub struct Block {
inner: Arc<Mutex<BlockType>>, inner: Arc<Mutex<BlockType>>,
// TODO: Metadata should be stored in the block manager? // TODO: Metadata should be stored in the block manager?
dtype: dynamo_llm::common::dtype::DType, dtype: dynamo_llm::common::dtype::DType,
device_id: usize, device_id: usize,
// Python iterator state
py_itr_idx: usize,
} }
impl Block { impl Block {
...@@ -157,69 +58,161 @@ impl Block { ...@@ -157,69 +58,161 @@ impl Block {
inner: block, inner: block,
dtype: dtype, dtype: dtype,
device_id: device_id, device_id: device_id,
py_itr_idx: 0,
}
}
fn num_layers(&self) -> usize {
let mutable_block = self.inner.lock().unwrap();
match &*mutable_block {
BlockType::Pinned(block) => block.num_layers(),
BlockType::Device(block) => block.num_layers(),
} }
} }
} }
#[pymethods] #[pymethods]
impl Block { impl Block {
// Builds a fresh Python list holding one `Layer` per layer of this block.
// Every `Layer` receives a clone of this block's shared handle plus the block's
// dtype and device id.
#[pyo3(signature = ())]
fn to_list<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
    let layer_count = self.num_layers();
    let mut layers: Vec<layer::Layer> = Vec::with_capacity(layer_count);
    for layer_idx in 0..layer_count {
        layers.push(layer::Layer::from_rust(
            self.inner.clone(),
            layer_idx,
            self.dtype.clone(),
            self.device_id,
        ));
    }
    PyList::new(py, layers)
}
// Python `len(block)`: the number of layers reported by `num_layers()`.
fn __len__(&self) -> PyResult<usize> {
    let layer_count = self.num_layers();
    Ok(layer_count)
}
// Python `block[index]`: returns the `Layer` at `index`, or raises `IndexError`
// when `index` is not smaller than the number of layers.
fn __getitem__(&self, index: usize) -> PyResult<layer::Layer> {
    let layer_count = self.num_layers();
    if index < layer_count {
        Ok(layer::Layer::from_rust(
            self.inner.clone(),
            index,
            self.dtype.clone(),
            self.device_id,
        ))
    } else {
        Err(pyo3::exceptions::PyIndexError::new_err(format!(
            "Index {} out of range for Block with {} layers",
            index, layer_count
        )))
    }
}
// Python `iter(block)`: the block acts as its own iterator, so this rewinds the
// cursor stored on the object and hands the same object back.
// Because the cursor lives on the `Block` itself, two iterations over the same
// object share (and reset) one position.
fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
// Reset iterator index at the beginning of each iteration
// Use to_list() for iterating concurrently
slf.py_itr_idx = 0;
Ok(slf)
}
// Python `next(block)`: yields the `Layer` at the current cursor and advances the
// cursor by one; raises `StopIteration` once the cursor reaches the layer count.
fn __next__(&mut self) -> PyResult<layer::Layer> {
    let cursor = self.py_itr_idx;
    if cursor >= self.num_layers() {
        return Err(pyo3::exceptions::PyStopIteration::new_err(
            "No more items in Block",
        ));
    }
    self.py_itr_idx = cursor + 1;
    Ok(layer::Layer::from_rust(
        self.inner.clone(),
        cursor,
        self.dtype.clone(),
        self.device_id,
    ))
}
#[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))] #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
fn __dlpack__( fn __dlpack__<'py>(
&self, &self,
py: Python<'py>,
stream: Option<PyObject>, stream: Option<PyObject>,
max_version: Option<PyObject>, max_version: Option<PyObject>,
dl_device: Option<PyObject>, dl_device: Option<PyObject>,
copy: Option<bool>, copy: Option<bool>,
) -> PyResult<PyObject> { ) -> PyResult<PyObject> {
// Panic if any arguments are provided // Return error if any arguments are provided
if stream.is_some() { if stream.is_some() {
panic!("stream argument is not supported"); return Err(pyo3::exceptions::PyNotImplementedError::new_err(
"stream argument is not supported",
));
} }
if max_version.is_some() { if max_version.is_some() {
panic!("max_version argument is not supported"); return Err(pyo3::exceptions::PyNotImplementedError::new_err(
"max_version argument is not supported",
));
} }
if dl_device.is_some() { if dl_device.is_some() {
panic!("dl_device argument is not supported"); return Err(pyo3::exceptions::PyNotImplementedError::new_err(
"dl_device argument is not supported",
));
} }
if copy.is_some() { if copy.is_some() {
panic!("copy argument is not supported"); return Err(pyo3::exceptions::PyNotImplementedError::new_err(
"copy argument is not supported",
));
} }
// Create DLPack PyCapsule // Extract all necessary data for dlpack
let manager_ctx = ManagerCtx::new(DlPackTensor { let ptr: *mut std::ffi::c_void;
block: self.inner.clone(), let num_blocks: i64;
dtype: self.dtype.clone(), let num_layers: i64;
device_id: self.device_id, let num_outer_dims: i64;
}); let page_size: i64;
let py_capsule = Python::with_gil(|py| manager_ctx.into_py(py)); let inner_dim: i64;
Ok(py_capsule) {
} let mut mutable_block = self.inner.lock().unwrap();
ptr = match &mut *mutable_block {
fn __dlpack_device__(&self) -> PyResult<Py<PyTuple>> { BlockType::Pinned(block) => {
let dlpack_device = Python::with_gil(|py| { let mut block_view_mut = block.block_view_mut().map_err(to_pyerr)?;
let device_type_list = py.eval(c_str!("[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"), None, None).unwrap(); (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
let device_type_enum = py }
.import("enum") BlockType::Device(block) => {
.unwrap() let mut block_view_mut = block.block_view_mut().map_err(to_pyerr)?;
.getattr("Enum") (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
.unwrap() }
.call1(("DLDeviceType", device_type_list)) };
.unwrap(); (num_blocks, num_layers, num_outer_dims, page_size, inner_dim) = match &*mutable_block {
let block = self.inner.lock().unwrap(); BlockType::Pinned(block) => (
let device_type = match &*block { block.num_blocks() as i64,
BlockType::Pinned(_) => device_type_enum.getattr("CPU_PINNED").unwrap(), block.num_layers() as i64,
BlockType::Device(_) => device_type_enum.getattr("CUDA").unwrap(), block.num_outer_dims() as i64,
block.page_size() as i64,
block.inner_dim() as i64,
),
BlockType::Device(block) => (
block.num_blocks() as i64,
block.num_layers() as i64,
block.num_outer_dims() as i64,
block.page_size() as i64,
block.inner_dim() as i64,
),
}; };
let device_id = self.device_id.into_py(py).into_bound(py); }
let device = vec![device_type, device_id];
PyTuple::new(py, device).unwrap().unbind() // Create the DLPack tensor
}); dlpack::dlpack(
Ok(dlpack_device) py,
self.inner.clone(),
ptr,
vec![num_blocks, num_layers, num_outer_dims, page_size, inner_dim],
self.dtype.clone(),
self.device_id,
)
} }
}
/*impl Drop for Block { #[pyo3(signature = ())]
fn drop(&mut self) { fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
println!("Dropping Block"); dlpack::dlpack_device(py, self.inner.clone(), self.device_id)
} }
}*/ }
...@@ -14,11 +14,8 @@ ...@@ -14,11 +14,8 @@
// limitations under the License. // limitations under the License.
#![cfg(feature = "block-manager")] #![cfg(feature = "block-manager")]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use super::*; use super::*;
use pyo3::{types::PyList, PyResult, Python}; use pyo3::{types::PyList, PyResult, Python};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
...@@ -52,16 +49,14 @@ impl BlockList { ...@@ -52,16 +49,14 @@ impl BlockList {
#[pymethods] #[pymethods]
impl BlockList { impl BlockList {
fn to_list(&self) -> PyResult<Py<PyList>> { #[pyo3(signature = ())]
let py_list = Python::with_gil(|py| { fn to_list<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
let blocks: Vec<block::Block> = self let blocks: Vec<block::Block> = self
.inner .inner
.iter() .iter()
.map(|b| block::Block::from_rust(b.clone(), self.dtype.clone(), self.device_id)) .map(|b| block::Block::from_rust(b.clone(), self.dtype.clone(), self.device_id))
.collect(); .collect();
PyList::new(py, blocks).unwrap().unbind() PyList::new(py, blocks)
});
Ok(py_list)
} }
fn __len__(&self) -> PyResult<usize> { fn __len__(&self) -> PyResult<usize> {
...@@ -84,13 +79,10 @@ impl BlockList { ...@@ -84,13 +79,10 @@ impl BlockList {
Ok(block) Ok(block)
} }
fn __iter__(slf: Py<Self>) -> PyResult<Py<Self>> { fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
Python::with_gil(|py| { // Reset iterator index at the beginning of each iteration
let mut slf = slf.borrow_mut(py); // Use to_list() for iterating concurrently
// Reset iterator index at the beginning of each iteration slf.py_itr_idx = 0;
// Use to_list() for iterating concurrently
slf.py_itr_idx = 0;
});
Ok(slf) Ok(slf)
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg(feature = "block-manager")]
// Silence warnings about deprecated features (like pyo3::IntoPy::into_py)
#![allow(deprecated)]
use super::*;
use dlpark::prelude::{DataType, Device, ManagerCtx, ShapeAndStrides, ToTensor};
use pyo3::{ffi::c_str, prelude::IntoPy, types::PyTuple, PyObject, PyResult, Python};
use std::sync::{Arc, Mutex};
// State handed to `ManagerCtx` when exporting memory through DLPack; the
// `ToTensor` impl below answers every DLPack query from these fields.
struct DlPackTensor {
// Shared handle to the block; locked in `device()` to tell pinned from device memory.
// NOTE(review): presumably holding this `Arc` also keeps the memory behind `ptr`
// alive for as long as the capsule exists — confirm.
block: Arc<Mutex<block::BlockType>>,
// Raw data pointer returned verbatim by `data_ptr()`.
ptr: *mut std::ffi::c_void,
// Tensor dimensions; `shape_and_strides()` reports them as a contiguous layout.
shape: Vec<i64>,
// TODO: Metadata should be stored in the block?
// Element type, mapped to a DLPack data type in `dtype()`.
dtype: dynamo_llm::common::dtype::DType,
// Device ordinal reported by `device()` for device-backed blocks.
device_id: usize,
}
impl ToTensor for DlPackTensor {
fn data_ptr(&self) -> *mut std::ffi::c_void {
self.ptr
}
fn byte_offset(&self) -> u64 {
0
}
fn device(&self) -> Device {
let mutable_block = self.block.lock().unwrap();
match &*mutable_block {
block::BlockType::Pinned(_) => {
// TODO: Why torch does not support CPU_PINNED here?
/*Device {
device_type: DeviceType::CudaHost,
device_id: 0,
}*/
Device::CPU
}
block::BlockType::Device(_) => Device::cuda(self.device_id),
}
}
fn dtype(&self) -> DataType {
// Map from dynamo_llm::common::dtype::DType to dlpark::prelude::DataType
match self.dtype {
dynamo_llm::common::dtype::DType::FP8 => {
// No direct FP8 equivalent, use U8 as closest alternative
DataType::U8
}
dynamo_llm::common::dtype::DType::FP16 => DataType::F16,
dynamo_llm::common::dtype::DType::BF16 => DataType::BF16,
dynamo_llm::common::dtype::DType::FP32 => DataType::F32,
dynamo_llm::common::dtype::DType::U8 => DataType::U8,
dynamo_llm::common::dtype::DType::U16 => DataType::U16,
dynamo_llm::common::dtype::DType::U32 => DataType::U32,
dynamo_llm::common::dtype::DType::U64 => DataType::U64,
dynamo_llm::common::dtype::DType::I8 => DataType::I8,
dynamo_llm::common::dtype::DType::I16 => DataType::I16,
dynamo_llm::common::dtype::DType::I32 => DataType::I32,
dynamo_llm::common::dtype::DType::I64 => DataType::I64,
}
}
fn shape_and_strides(&self) -> ShapeAndStrides {
ShapeAndStrides::new_contiguous(&self.shape)
}
}
/*impl Drop for DlPackTensor {
fn drop(&mut self) {
println!("Dropping DlPackTensor");
}
}*/
/// Wraps the given block handle, data pointer, shape, dtype and device id in a
/// `DlPackTensor` and converts it into a Python object via `ManagerCtx`.
///
/// The pointer and shape are stored as given; this function does not validate them.
pub fn dlpack<'py>(
    py: Python<'py>,
    block: Arc<Mutex<block::BlockType>>,
    ptr: *mut std::ffi::c_void,
    shape: Vec<i64>,
    dtype: dynamo_llm::common::dtype::DType,
    device_id: usize,
) -> PyResult<PyObject> {
    let tensor = DlPackTensor {
        block,
        ptr,
        shape,
        dtype,
        device_id,
    };
    Ok(ManagerCtx::new(tensor).into_py(py))
}
pub fn dlpack_device<'py>(
py: Python<'py>,
block: Arc<Mutex<block::BlockType>>,
device_id: usize,
) -> PyResult<Bound<'py, PyTuple>> {
let dev_type_list = py.eval(c_str!("[('CPU', 1), ('CUDA', 2), ('CPU_PINNED', 3), ('OPENCL', 4), ('VULKAN', 7), ('METAL', 8), ('VPI', 9), ('ROCM', 10)]"), None, None).unwrap();
let dev_type_enum = py
.import("enum")
.unwrap()
.getattr("Enum")
.unwrap()
.call1(("DLDeviceType", dev_type_list))
.unwrap();
let dev_type = match &*block.lock().unwrap() {
block::BlockType::Pinned(_) => dev_type_enum.getattr("CPU_PINNED").unwrap(),
block::BlockType::Device(_) => dev_type_enum.getattr("CUDA").unwrap(),
};
let dev_id = device_id.into_py(py).into_bound(py);
let dev = vec![dev_type, dev_id];
PyTuple::new(py, dev)
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg(feature = "block-manager")]
use super::*;
use dynamo_llm::block_manager::block::BlockDataExt;
use pyo3::{types::PyTuple, PyObject, PyResult, Python};
use std::sync::{Arc, Mutex};
// Layer struct that represents a layer within a block
// Exposed to Python through `#[pyclass]`; it does not own separate storage but
// shares the handle of the block it was created from.
#[pyclass]
pub struct Layer {
// Shared, mutex-protected handle to the block this layer belongs to.
inner: Arc<Mutex<block::BlockType>>,
// Index of this layer, passed to `layer_view_mut` when exporting via DLPack.
layer_idx: usize,
// Element type forwarded to the DLPack export.
dtype: dynamo_llm::common::dtype::DType,
// Device id forwarded to the DLPack export and to `__dlpack_device__`.
device_id: usize,
}
impl Layer {
pub fn from_rust(
block: Arc<Mutex<block::BlockType>>,
layer_idx: usize,
dtype: dynamo_llm::common::dtype::DType,
device_id: usize,
) -> Self {
Self {
inner: block,
layer_idx,
dtype,
device_id,
}
}
}
#[pymethods]
impl Layer {
    // Exports this layer as a DLPack capsule.
    //
    // None of the optional protocol arguments are implemented: supplying any of
    // them raises `NotImplementedError`. The reported shape is
    // `[1, 1, num_outer_dims, page_size, inner_dim]` and the data pointer is the
    // start of the layer view at outer index 0.
    // NOTE(review): this assumes all outer dims of the layer are laid out
    // contiguously starting at that pointer — confirm against the block layout.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
        // Checked in declaration order; the first argument that was supplied wins.
        let rejected = [
            (stream.is_some(), "stream argument is not supported"),
            (max_version.is_some(), "max_version argument is not supported"),
            (dl_device.is_some(), "dl_device argument is not supported"),
            (copy.is_some(), "copy argument is not supported"),
        ];
        if let Some((_, reason)) = rejected.iter().find(|(provided, _)| *provided) {
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(*reason));
        }

        // Gather pointer and dimensions under the lock; the guard is released at
        // the end of this block, before the capsule is built.
        let (ptr, num_outer_dims, page_size, inner_dim) = {
            let mut guard = self.inner.lock().unwrap();

            let data_ptr = match &mut *guard {
                block::BlockType::Pinned(block) => {
                    let mut view = block.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?;
                    (unsafe { view.as_mut_ptr() }) as *mut std::ffi::c_void
                }
                block::BlockType::Device(block) => {
                    let mut view = block.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?;
                    (unsafe { view.as_mut_ptr() }) as *mut std::ffi::c_void
                }
            };

            let (outer, page, inner) = match &*guard {
                block::BlockType::Pinned(block) => (
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
                block::BlockType::Device(block) => (
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
            };

            (data_ptr, outer, page, inner)
        };

        // The capsule receives its own clone of the block handle alongside the pointer.
        let shape = vec![1, 1, num_outer_dims, page_size, inner_dim];
        dlpack::dlpack(
            py,
            self.inner.clone(),
            ptr,
            shape,
            self.dtype.clone(),
            self.device_id,
        )
    }

    // Returns the DLPack device tuple for the block backing this layer.
    #[pyo3(signature = ())]
    fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        let handle = self.inner.clone();
        dlpack::dlpack_device(py, handle, self.device_id)
    }
}
...@@ -710,6 +710,25 @@ class NatsQueue: ...@@ -710,6 +710,25 @@ class NatsQueue:
""" """
... ...
class Layer:
"""
A KV cache block layer
Obtained by indexing or iterating a Block, or from Block.to_list()
"""
...
def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any:
"""
Get a dlpack capsule of the layer
NotImplementedError is raised if any of the optional arguments is provided
"""
...
def __dlpack_device__(self) -> Any:
"""
Get the dlpack device of the layer as a (device type, device id) tuple
"""
...
class Block: class Block:
""" """
A KV cache block A KV cache block
...@@ -717,9 +736,40 @@ class Block: ...@@ -717,9 +736,40 @@ class Block:
... ...
def __len__(self) -> int:
"""
Get the number of layers in the block
"""
...
def __getitem__(self, index: int) -> Layer:
"""
Get a layer by index
"""
...
def __iter__(self) -> 'Block':
"""
Get an iterator over the layers
"""
...
def __next__(self) -> Block:
"""
Get the next layer in the iterator
"""
...
def to_list(self) -> List[Layer]:
"""
Get a list of layers
"""
...
def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any: def __dlpack__(self, stream: Optional[Any] = None, max_version: Optional[Any] = None, dl_device: Optional[Any] = None, copy: Optional[bool] = None) -> Any:
""" """
Get a dlpack capsule from the block Get a dlpack capsule of the block
Exception raised if the block is not contiguous
""" """
... ...
...@@ -822,6 +872,22 @@ class BlockManager: ...@@ -822,6 +872,22 @@ class BlockManager:
""" """
... ...
async def allocate_host_blocks(self, count: int) -> BlockList:
"""
Allocate a list of host blocks
Parameters:
-----------
count: int
Number of blocks to allocate
Returns:
--------
BlockList
List of allocated blocks
Raises:
-------
RuntimeError
If the host allocator is not available
"""
...
def allocate_device_blocks_blocking(self, count: int) -> BlockList: def allocate_device_blocks_blocking(self, count: int) -> BlockList:
""" """
Allocate a list of device blocks (blocking call) Allocate a list of device blocks (blocking call)
...@@ -837,3 +903,19 @@ class BlockManager: ...@@ -837,3 +903,19 @@ class BlockManager:
List of allocated blocks List of allocated blocks
""" """
... ...
async def allocate_device_blocks(self, count: int) -> BlockList:
"""
Allocate a list of device blocks
Parameters:
-----------
count: int
Number of blocks to allocate
Returns:
--------
BlockList
List of allocated blocks
Raises:
-------
RuntimeError
If the device allocator is not available
"""
...
...@@ -35,9 +35,7 @@ DEVICE_NUM_BLOCKS = 16 ...@@ -35,9 +35,7 @@ DEVICE_NUM_BLOCKS = 16
DEVICE_ID = 0 DEVICE_ID = 0
@pytest.fixture def new_block_manager():
def block_manager():
"""Pytest fixture for creating a BlockManager instance."""
return BlockManager( return BlockManager(
WORKER_ID, WORKER_ID,
NUM_LAYER, NUM_LAYER,
...@@ -51,6 +49,11 @@ def block_manager(): ...@@ -51,6 +49,11 @@ def block_manager():
) )
@pytest.fixture
def block_manager():
return new_block_manager()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_manager_initialization(): async def test_block_manager_initialization():
# Python should drop the BlockManager instance as soon as it goes out of scope, but # Python should drop the BlockManager instance as soon as it goes out of scope, but
...@@ -106,22 +109,22 @@ async def test_block_manager_initialization(): ...@@ -106,22 +109,22 @@ async def test_block_manager_initialization():
async def test_cpu_block_access(block_manager: BlockManager): async def test_cpu_block_access(block_manager: BlockManager):
block_count = 2 block_count = 2
block_list = block_manager.allocate_host_blocks_blocking(block_count) block_list = block_manager.allocate_host_blocks_blocking(block_count)
py_blocks = block_list.to_list() blocks = block_list.to_list()
assert len(py_blocks) == block_count assert len(blocks) == block_count
tensors = [torch.from_dlpack(b) for b in py_blocks] tensors = [torch.from_dlpack(b) for b in blocks]
for tensor in tensors: for tensor in tensors:
assert tensor.get_device() == -1 # CPU assert tensor.get_device() == -1 # CPU
assert tensor.shape == (1, NUM_LAYER, PAGE_SIZE, INNER_DIM) assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
assert tensor.dtype == TORCH_DTYPE assert tensor.dtype == TORCH_DTYPE
# print(tensors) # print(tensors)
for tensor in tensors: for tensor in tensors:
tensor[0][0][0][0] = 1.0 tensor[0][0][0][0][0] = 1.0
tensor[0][NUM_LAYER - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0
# print(tensors) # print(tensors)
py_blocks_ = block_list.to_list() blocks_ = block_list.to_list()
assert py_blocks is not py_blocks_ assert blocks is not blocks_
assert len(py_blocks) == len(py_blocks_) assert len(blocks) == len(blocks_)
tensors_ = [torch.from_dlpack(b) for b in py_blocks_] tensors_ = [torch.from_dlpack(b) for b in blocks_]
for tensor, tensor_ in zip(tensors, tensors_): for tensor, tensor_ in zip(tensors, tensors_):
assert tensor is not tensor_ assert tensor is not tensor_
assert tensor.shape == tensor_.shape assert tensor.shape == tensor_.shape
...@@ -133,22 +136,22 @@ async def test_cpu_block_access(block_manager: BlockManager): ...@@ -133,22 +136,22 @@ async def test_cpu_block_access(block_manager: BlockManager):
async def test_gpu_block_access(block_manager: BlockManager): async def test_gpu_block_access(block_manager: BlockManager):
block_count = 6 block_count = 6
block_list = block_manager.allocate_device_blocks_blocking(block_count) block_list = block_manager.allocate_device_blocks_blocking(block_count)
py_blocks = block_list.to_list() blocks = block_list.to_list()
assert len(py_blocks) == block_count assert len(blocks) == block_count
tensors = [torch.from_dlpack(b) for b in py_blocks] tensors = [torch.from_dlpack(b) for b in blocks]
for tensor in tensors: for tensor in tensors:
assert tensor.get_device() == DEVICE_ID # GPU assert tensor.get_device() == DEVICE_ID # GPU
assert tensor.shape == (1, NUM_LAYER, PAGE_SIZE, INNER_DIM) assert tensor.shape == (1, NUM_LAYER, OUTER_DIM, PAGE_SIZE, INNER_DIM)
assert tensor.dtype == TORCH_DTYPE assert tensor.dtype == TORCH_DTYPE
# print(tensors) # print(tensors)
for tensor in tensors: for tensor in tensors:
tensor[0][0][0][0] = 1.0 tensor[0][0][0][0][0] = 1.0
tensor[0][NUM_LAYER - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0 tensor[0][NUM_LAYER - 1][OUTER_DIM - 1][PAGE_SIZE - 1][INNER_DIM - 1] = 1.0
# print(tensors) # print(tensors)
py_blocks_ = block_list.to_list() blocks_ = block_list.to_list()
assert py_blocks is not py_blocks_ assert blocks is not blocks_
assert len(py_blocks) == len(py_blocks_) assert len(blocks) == len(blocks_)
tensors_ = [torch.from_dlpack(b) for b in py_blocks_] tensors_ = [torch.from_dlpack(b) for b in blocks_]
for tensor, tensor_ in zip(tensors, tensors_): for tensor, tensor_ in zip(tensors, tensors_):
assert tensor is not tensor_ assert tensor is not tensor_
assert tensor.shape == tensor_.shape assert tensor.shape == tensor_.shape
...@@ -159,27 +162,27 @@ async def test_gpu_block_access(block_manager: BlockManager): ...@@ -159,27 +162,27 @@ async def test_gpu_block_access(block_manager: BlockManager):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_list_iteration(block_manager: BlockManager): async def test_block_list_iteration(block_manager: BlockManager):
block_count = 4 block_count = 4
block_list = block_manager.allocate_host_blocks_blocking(block_count) block_list = await block_manager.allocate_host_blocks(block_count)
# Test __len__() # Test __len__()
assert len(block_list) == block_count assert len(block_list) == block_count
# Test __getitem__() # Test __getitem__()
for i in range(block_count): for i in range(block_count):
block = block_list[i] block = block_list[i]
tensor = torch.from_dlpack(block) tensor = torch.from_dlpack(block)
tensor[0][0][0][0] = 1.0 + i tensor[0][0][0][0][0] = 1.0 + i
# Test __iter__() and __next__() # Test __iter__() and __next__()
idx = 1.0 idx = 1.0
for block in block_list: for block in block_list:
tensor = torch.from_dlpack(block) tensor = torch.from_dlpack(block)
assert tensor[0][0][0][0] == idx assert tensor[0][0][0][0][0] == idx
tensor[0][0][0][0] += 0.5 tensor[0][0][0][0][0] += 0.5
idx += 1.0 idx += 1.0
assert idx == 1.0 + block_count assert idx == 1.0 + block_count
# Test __iter__() should reset current index # Test __iter__() should reset current index
idx = 1.0 idx = 1.0
for block in block_list: for block in block_list:
tensor = torch.from_dlpack(block) tensor = torch.from_dlpack(block)
assert tensor[0][0][0][0] == idx + 0.5 assert tensor[0][0][0][0][0] == idx + 0.5
idx += 1.0 idx += 1.0
assert idx == 1.0 + block_count assert idx == 1.0 + block_count
...@@ -187,27 +190,37 @@ async def test_block_list_iteration(block_manager: BlockManager): ...@@ -187,27 +190,37 @@ async def test_block_list_iteration(block_manager: BlockManager):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_copy_g1_g2(block_manager: BlockManager): async def test_block_copy_g1_g2(block_manager: BlockManager):
# Allocate device (G1) and host (G2) block # Allocate device (G1) and host (G2) block
host_block_list = block_manager.allocate_host_blocks_blocking(1) host_block_list = await block_manager.allocate_host_blocks(1)
device_block_list = block_manager.allocate_device_blocks_blocking(1) device_block_list = await block_manager.allocate_device_blocks(1)
# Populate host block with unique values # Populate host block with unique values
host_tensor = torch.from_dlpack(host_block_list[0]) host_tensor = torch.from_dlpack(host_block_list[0])
for i in range(NUM_LAYER): for i in range(NUM_LAYER):
for j in range(PAGE_SIZE): for j in range(OUTER_DIM):
for k in range(INNER_DIM): for k in range(PAGE_SIZE):
host_tensor[0][i][j][k] = i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k for w in range(INNER_DIM):
host_tensor[0][i][j][k][w] = (
i * OUTER_DIM * PAGE_SIZE * INNER_DIM
+ j * PAGE_SIZE * INNER_DIM
+ k * INNER_DIM
+ w
)
# Copy host block to device block after permuting # Copy host block to device block after permuting
permute_dims = (0, 2, 3, 1) permute_dims = (0, 2, 4, 3, 1)
device_tensor_ = torch.from_dlpack(device_block_list[0]).permute(*permute_dims) device_tensor_ = torch.from_dlpack(device_block_list[0]).permute(*permute_dims)
device_tensor_.copy_(host_tensor.permute(*permute_dims)) device_tensor_.copy_(host_tensor.permute(*permute_dims))
# Assert device block is contiguous and updated in block manager # Assert device block is contiguous and updated in block manager
device_tensor = torch.from_dlpack(device_block_list[0]) device_tensor = torch.from_dlpack(device_block_list[0])
for i in range(NUM_LAYER): for i in range(NUM_LAYER):
for j in range(PAGE_SIZE): for j in range(OUTER_DIM):
for k in range(INNER_DIM): for k in range(PAGE_SIZE):
assert ( for w in range(INNER_DIM):
device_tensor[0][i][j][k] assert (
== i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k device_tensor[0][i][j][k][w]
) == i * OUTER_DIM * PAGE_SIZE * INNER_DIM
+ j * PAGE_SIZE * INNER_DIM
+ k * INNER_DIM
+ w
)
# Set host block to zero and assert updated in block manager # Set host block to zero and assert updated in block manager
host_tensor_ = torch.from_dlpack(host_block_list[0]).permute(*permute_dims) host_tensor_ = torch.from_dlpack(host_block_list[0]).permute(*permute_dims)
host_tensor_.zero_() host_tensor_.zero_()
...@@ -216,22 +229,166 @@ async def test_block_copy_g1_g2(block_manager: BlockManager): ...@@ -216,22 +229,166 @@ async def test_block_copy_g1_g2(block_manager: BlockManager):
host_tensor_.copy_(device_tensor_) host_tensor_.copy_(device_tensor_)
# Assert host block is updated in block manager # Assert host block is updated in block manager
for i in range(NUM_LAYER): for i in range(NUM_LAYER):
for j in range(PAGE_SIZE): for j in range(OUTER_DIM):
for k in range(INNER_DIM): for k in range(PAGE_SIZE):
assert ( for w in range(INNER_DIM):
host_tensor[0][i][j][k] assert (
== i * PAGE_SIZE * INNER_DIM + j * INNER_DIM + k host_tensor[0][i][j][k][w]
) == i * OUTER_DIM * PAGE_SIZE * INNER_DIM
+ j * PAGE_SIZE * INNER_DIM
+ k * INNER_DIM
+ w
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_cpu_layer_access(block_manager: BlockManager):
    """Check that the layers of a host block can be viewed and written via DLPack.

    One host block is allocated, its layers are exported twice with
    ``to_list()``, and the tensors built from the second export must observe
    the values written through the tensors built from the first export.
    """
    expected_shape = (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_corner = (0, 0, 0, 0, 0)
    last_corner = (0, 0, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)

    host_blocks = block_manager.allocate_host_blocks_blocking(1)
    block = host_blocks[0]

    # First export: one entry per layer.
    first_export = block.to_list()
    assert len(first_export) == NUM_LAYER
    first_views = list(map(torch.from_dlpack, first_export))

    # Validate every view before any of them is modified.
    for view in first_views:
        assert view.get_device() == -1  # CPU
        assert view.shape == expected_shape
        assert view.dtype == TORCH_DTYPE

    # Touch the first and the last element of each layer.
    for view in first_views:
        view[first_corner] = 1.0
        view[last_corner] = 1.0

    # Second export: a distinct list of the same length.
    second_export = block.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = list(map(torch.from_dlpack, second_export))

    # Distinct tensor objects, same metadata, same contents.
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_gpu_layer_access(block_manager: BlockManager):
    """Check that the layers of a device block can be viewed and written via DLPack.

    One device block is allocated, its layers are exported twice with
    ``to_list()``, and the tensors built from the second export must observe
    the values written through the tensors built from the first export.
    """
    expected_shape = (1, 1, OUTER_DIM, PAGE_SIZE, INNER_DIM)
    first_corner = (0, 0, 0, 0, 0)
    last_corner = (0, 0, OUTER_DIM - 1, PAGE_SIZE - 1, INNER_DIM - 1)

    device_blocks = block_manager.allocate_device_blocks_blocking(1)
    block = device_blocks[0]

    # First export: one entry per layer.
    first_export = block.to_list()
    assert len(first_export) == NUM_LAYER
    first_views = list(map(torch.from_dlpack, first_export))

    # Validate every view before any of them is modified.
    for view in first_views:
        assert view.get_device() == DEVICE_ID  # GPU
        assert view.shape == expected_shape
        assert view.dtype == TORCH_DTYPE

    # Touch the first and the last element of each layer.
    for view in first_views:
        view[first_corner] = 1.0
        view[last_corner] = 1.0

    # Second export: a distinct list of the same length.
    second_export = block.to_list()
    assert first_export is not second_export
    assert len(first_export) == len(second_export)
    second_views = list(map(torch.from_dlpack, second_export))

    # Distinct tensor objects, same metadata, same contents.
    for before, after in zip(first_views, second_views):
        assert before is not after
        assert before.shape == after.shape
        assert before.dtype == after.dtype
        assert torch.allclose(before, after)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_iteration(block_manager: BlockManager):
    """Exercise ``len()``, indexing and repeated iteration on a single host block.

    Each layer gets a distinct marker through indexing; two consecutive
    ``for`` loops over the block must then visit every layer, in order, and
    see the values left behind by the previous step.
    """
    origin = (0, 0, 0, 0, 0)
    host_blocks = await block_manager.allocate_host_blocks(1)
    block = host_blocks[0]

    # __len__(): one entry per layer.
    assert len(block) == NUM_LAYER

    # __getitem__(): tag layer n with the value 1.0 + n.
    for layer_index in range(NUM_LAYER):
        view = torch.from_dlpack(block[layer_index])
        view[origin] = 1.0 + layer_index

    # __iter__() / __next__(): layers come back in index order; bump each tag.
    visited = 0
    for layer in block:
        view = torch.from_dlpack(layer)
        assert view[origin] == 1.0 + visited
        view[origin] += 0.5
        visited += 1
    assert visited == NUM_LAYER

    # A second loop must start again from the first layer and see the bumps.
    visited = 0
    for layer in block:
        view = torch.from_dlpack(layer)
        assert view[origin] == 1.0 + visited + 0.5
        visited += 1
    assert visited == NUM_LAYER
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
async def test_block_layer_copy_g1_g2(block_manager: BlockManager):
    """Round-trip data between a host (G2) block and a device (G1) block.

    The host block is filled through its per-layer tensors, copied to the
    device block through permuted whole-block tensors, checked on the device
    through per-layer tensors, zeroed on the host, and finally copied back
    and checked through the whole-block host tensor.
    """

    def expected_value(layer, outer, page, inner):
        # Same integer as
        #   layer * OUTER_DIM * PAGE_SIZE * INNER_DIM
        #   + outer * PAGE_SIZE * INNER_DIM + page * INNER_DIM + inner
        # written in nested (Horner) form.
        return ((layer * OUTER_DIM + outer) * PAGE_SIZE + page) * INNER_DIM + inner

    def layer_coords():
        # Every (outer, page, inner) coordinate of one layer, innermost fastest.
        for outer in range(OUTER_DIM):
            for page in range(PAGE_SIZE):
                for inner in range(INNER_DIM):
                    yield outer, page, inner

    # One block on each side: device (G1) and host (G2).
    host_block = (await block_manager.allocate_host_blocks(1))[0]
    device_block = (await block_manager.allocate_device_blocks(1))[0]

    # Fill the host block one layer at a time with unique values.
    host_layers = [torch.from_dlpack(layer) for layer in host_block]
    for layer_index in range(NUM_LAYER):
        host_layer = host_layers[layer_index]
        for outer, page, inner in layer_coords():
            host_layer[0, 0, outer, page, inner] = expected_value(
                layer_index, outer, page, inner
            )

    # Host -> device copy, both sides viewed through the same permutation.
    axis_order = (0, 2, 4, 3, 1)
    host_permuted = torch.from_dlpack(host_block).permute(*axis_order)
    device_permuted = torch.from_dlpack(device_block).permute(*axis_order)
    device_permuted.copy_(host_permuted)

    # Freshly exported device layers must hold the values written on the host.
    device_layers = [torch.from_dlpack(layer) for layer in device_block]
    for layer_index in range(NUM_LAYER):
        device_layer = device_layers[layer_index]
        for outer, page, inner in layer_coords():
            assert device_layer[0, 0, outer, page, inner] == expected_value(
                layer_index, outer, page, inner
            )

    # Zero the host block through a new tensor; the earlier permuted view
    # must observe the zeros as well.
    host_whole = torch.from_dlpack(host_block)
    host_whole.zero_()
    assert torch.all(host_permuted == 0)

    # Device -> host copy restores the original values.
    host_permuted.copy_(device_permuted)

    # Verify through the whole-block host tensor.
    for layer_index in range(NUM_LAYER):
        for outer, page, inner in layer_coords():
            assert host_whole[0, layer_index, outer, page, inner] == expected_value(
                layer_index, outer, page, inner
            )
async def main(): async def main():
await test_block_manager_initialization() await test_block_manager_initialization()
await test_cpu_block_access(new_block_manager())
# todo: revise these tests to index into the block via block_id, layer_id, outer_id (k/v) await test_gpu_block_access(new_block_manager())
# await test_cpu_block_access() await test_block_list_iteration(new_block_manager())
# await test_gpu_block_access() await test_block_copy_g1_g2(new_block_manager())
# await test_block_list_iteration() await test_cpu_layer_access(new_block_manager())
# await test_block_copy_g1_g2() await test_gpu_layer_access(new_block_manager())
await test_block_iteration(new_block_manager())
await test_block_layer_copy_g1_g2(new_block_manager())
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment