// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#![cfg(feature = "block-manager")]

use super::*;
use dynamo_llm::block_manager::block::BlockDataExt;
Ryan Olson's avatar
Ryan Olson committed
8
use dynamo_llm::block_manager::block::BlockDataProviderMut;
9
10
11
12
13
use pyo3::{
    types::{PyList, PyTuple},
    PyObject, PyResult, Python,
};
use std::sync::{Arc, Mutex};
14
15
16
17
18

pub enum BlockType {
    Pinned(
        dynamo_llm::block_manager::block::MutableBlock<
            dynamo_llm::block_manager::storage::PinnedStorage,
Ryan Olson's avatar
Ryan Olson committed
19
            dynamo_llm::block_manager::block::locality::Local,
20
21
22
23
24
25
            dynamo_llm::block_manager::block::BasicMetadata,
        >,
    ),
    Device(
        dynamo_llm::block_manager::block::MutableBlock<
            dynamo_llm::block_manager::storage::DeviceStorage,
Ryan Olson's avatar
Ryan Olson committed
26
            dynamo_llm::block_manager::block::locality::Local,
27
28
29
30
31
32
33
34
35
36
37
            dynamo_llm::block_manager::block::BasicMetadata,
        >,
    ),
}

// Python-visible wrapper around a single `BlockType`.
//
// The wrapped block lives behind `Arc<Mutex<..>>`; the methods in this file hand
// clones of that `Arc` to `layer::Layer::from_rust` and to the `dlpack` helpers.
#[pyclass]
pub struct Block {
    // Shared, lock-protected block.
    inner: Arc<Mutex<BlockType>>,
    // TODO: Metadata should be stored in the block manager?
    // Data type forwarded to every `Layer` / DLPack tensor created from this block.
    dtype: dynamo_llm::common::dtype::DType,
    // Device index forwarded together with `dtype`.
    device_id: usize,
    // Python iterator state: index of the layer the next `__next__` call returns.
    // `__iter__` resets it to 0.
    py_itr_idx: usize,
}

impl Block {
    /// Builds the Python-facing wrapper around an already shared block.
    ///
    /// `dtype` and `device_id` are stored unchanged and later forwarded to the
    /// `Layer` and DLPack constructors. The Python iterator cursor starts at 0.
    pub fn from_rust(
        block: Arc<Mutex<BlockType>>,
        dtype: dynamo_llm::common::dtype::DType,
        device_id: usize,
    ) -> Self {
        Self {
            py_itr_idx: 0,
            inner: block,
            dtype,
            device_id,
        }
    }

    /// Returns the layer count reported by the wrapped block.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the block is poisoned.
    fn num_layers(&self) -> usize {
        let guard = self.inner.lock().unwrap();
        match *guard {
            BlockType::Pinned(ref block) => block.num_layers(),
            BlockType::Device(ref block) => block.num_layers(),
        }
    }
}

#[pymethods]
impl Block {
67
68
69
70
    // Returns a Python list holding one `Layer` per layer index of this block,
    // in ascending index order.
    //
    // Each `Layer` receives its own clone of the shared block handle together
    // with this block's dtype and device id. Unlike iteration via
    // `__iter__`/`__next__`, the list does not use the block's iterator cursor.
    #[pyo3(signature = ())]
    fn to_list<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
        let layer_count = self.num_layers();
        let mut layers: Vec<layer::Layer> = Vec::with_capacity(layer_count);
        for layer_idx in 0..layer_count {
            layers.push(layer::Layer::from_rust(
                Arc::clone(&self.inner),
                layer_idx,
                self.dtype,
                self.device_id,
            ));
        }
        PyList::new(py, layers)
    }

    // Python `len(block)`: the number of layers reported by the wrapped block.
    fn __len__(&self) -> PyResult<usize> {
        let layer_count = self.num_layers();
        Ok(layer_count)
    }

    // Python `block[index]`: returns the `Layer` at `index`.
    //
    // Negative indices count back from the end (`block[-1]` is the last layer),
    // following the convention of Python's built-in sequences. PyO3 does not
    // adjust negative indices for `__getitem__`, so the index is taken as a
    // signed integer and resolved here.
    //
    // Raises `IndexError` when the resolved index is outside `0..num_layers`.
    // The message reports the index exactly as the caller passed it.
    fn __getitem__(&self, index: isize) -> PyResult<layer::Layer> {
        let num_layers = self.num_layers();

        // `checked_sub` yields `None` when a negative index reaches past the
        // first layer, which then falls through to the IndexError arm.
        let resolved = if index < 0 {
            num_layers.checked_sub(index.unsigned_abs())
        } else {
            Some(index.unsigned_abs())
        };

        match resolved {
            Some(layer_idx) if layer_idx < num_layers => Ok(layer::Layer::from_rust(
                self.inner.clone(),
                layer_idx,
                self.dtype,
                self.device_id,
            )),
            _ => Err(pyo3::exceptions::PyIndexError::new_err(format!(
                "Index {} out of range for Block with {} layers",
                index, num_layers
            ))),
        }
    }

    // Python `iter(block)`: the block acts as its own iterator.
    //
    // The cursor is rewound to 0 on every call. Because the cursor is stored on
    // the block itself, starting a new iteration restarts any iteration that is
    // already in progress over the same object; use `to_list()` when several
    // independent traversals are needed.
    fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
        let block: &mut Self = &mut slf;
        block.py_itr_idx = 0;
        Ok(slf)
    }

    // Python `next(block)`: returns the `Layer` at the current cursor position
    // and advances the cursor by one.
    //
    // Raises `StopIteration` once the cursor has reached the layer count.
    fn __next__(&mut self) -> PyResult<layer::Layer> {
        let cursor = self.py_itr_idx;
        if cursor < self.num_layers() {
            let layer = layer::Layer::from_rust(
                Arc::clone(&self.inner),
                cursor,
                self.dtype,
                self.device_id,
            );
            self.py_itr_idx = cursor + 1;
            return Ok(layer);
        }
        Err(pyo3::exceptions::PyStopIteration::new_err(
            "No more items in Block",
        ))
    }

116
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
117
    fn __dlpack__<'py>(
118
        &self,
119
        py: Python<'py>,
120
121
122
123
124
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
125
        // Return error if any arguments are provided
126
        if stream.is_some() {
127
128
129
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "stream argument is not supported",
            ));
130
131
        }
        if max_version.is_some() {
132
133
134
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "max_version argument is not supported",
            ));
135
136
        }
        if dl_device.is_some() {
137
138
139
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "dl_device argument is not supported",
            ));
140
141
        }
        if copy.is_some() {
142
143
144
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "copy argument is not supported",
            ));
145
146
        }

147
148
149
150
151
152
153
154
155
156
157
        // Extract all necessary data for dlpack
        let ptr: *mut std::ffi::c_void;
        let num_blocks: i64;
        let num_layers: i64;
        let num_outer_dims: i64;
        let page_size: i64;
        let inner_dim: i64;
        {
            let mut mutable_block = self.inner.lock().unwrap();
            ptr = match &mut *mutable_block {
                BlockType::Pinned(block) => {
Ryan Olson's avatar
Ryan Olson committed
158
159
160
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
                    let mut block_view_mut = block_data.block_view_mut().map_err(to_pyerr)?;
161
162
163
                    (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
                BlockType::Device(block) => {
Ryan Olson's avatar
Ryan Olson committed
164
165
166
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
                    let mut block_view_mut = block_data.block_view_mut().map_err(to_pyerr)?;
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
                    (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
            };
            (num_blocks, num_layers, num_outer_dims, page_size, inner_dim) = match &*mutable_block {
                BlockType::Pinned(block) => (
                    block.num_blocks() as i64,
                    block.num_layers() as i64,
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
                BlockType::Device(block) => (
                    block.num_blocks() as i64,
                    block.num_layers() as i64,
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
185
            };
186
187
188
189
190
191
192
193
        }

        // Create the DLPack tensor
        dlpack::dlpack(
            py,
            self.inner.clone(),
            ptr,
            vec![num_blocks, num_layers, num_outer_dims, page_size, inner_dim],
Ryan Olson's avatar
Ryan Olson committed
194
            self.dtype,
195
196
            self.device_id,
        )
197
198
    }

199
200
201
    #[pyo3(signature = ())]
    fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        dlpack::dlpack_device(py, self.inner.clone(), self.device_id)
202
    }
203
}