// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![cfg(feature = "block-manager")]

use super::*;
use dynamo_llm::block_manager::block::BlockDataExt;
use dynamo_llm::block_manager::block::BlockDataProviderMut;
use pyo3::{
    types::{PyList, PyTuple},
    PyObject, PyResult, Python,
};
use std::sync::{Arc, Mutex};

pub enum BlockType {
    Pinned(
        dynamo_llm::block_manager::block::MutableBlock<
            dynamo_llm::block_manager::storage::PinnedStorage,
Ryan Olson's avatar
Ryan Olson committed
31
            dynamo_llm::block_manager::block::locality::Local,
32
33
34
35
36
37
            dynamo_llm::block_manager::block::BasicMetadata,
        >,
    ),
    Device(
        dynamo_llm::block_manager::block::MutableBlock<
            dynamo_llm::block_manager::storage::DeviceStorage,
Ryan Olson's avatar
Ryan Olson committed
38
            dynamo_llm::block_manager::block::locality::Local,
39
40
41
42
43
44
45
46
47
48
49
            dynamo_llm::block_manager::block::BasicMetadata,
        >,
    ),
}

/// Python-visible wrapper around a shared [`BlockType`].
///
/// The wrapper hands out `Layer` objects and DLPack exports that all share the
/// same `Arc<Mutex<BlockType>>` handle.
#[pyclass]
pub struct Block {
    /// Shared, lock-protected handle to the wrapped block.
    inner: Arc<Mutex<BlockType>>,
    // TODO: Metadata should be stored in the block manager?
    /// Element type forwarded to every `Layer` and DLPack export created from this block.
    dtype: dynamo_llm::common::dtype::DType,
    /// Device id forwarded to every `Layer` and DLPack export created from this block.
    device_id: usize,
    // Python iterator state
    /// Index of the next layer `__next__` will yield; reset to 0 by `__iter__`.
    py_itr_idx: usize,
}

impl Block {
    /// Wraps an existing shared block handle in a Python-facing [`Block`].
    ///
    /// `dtype` and `device_id` are stored as given; the Python iterator cursor
    /// starts at 0.
    pub fn from_rust(
        block: Arc<Mutex<BlockType>>,
        dtype: dynamo_llm::common::dtype::DType,
        device_id: usize,
    ) -> Self {
        Self {
            inner: block,
            dtype,
            device_id,
            py_itr_idx: 0,
        }
    }

    /// Returns the layer count reported by the wrapped block.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the block has been poisoned.
    fn num_layers(&self) -> usize {
        let mutable_block = self.inner.lock().unwrap();
        match &*mutable_block {
            BlockType::Pinned(block) => block.num_layers(),
            BlockType::Device(block) => block.num_layers(),
        }
    }
}

#[pymethods]
impl Block {
79
80
81
82
    #[pyo3(signature = ())]
    fn to_list<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
        let layers: Vec<layer::Layer> = (0..self.num_layers())
            .map(|layer_idx| {
Ryan Olson's avatar
Ryan Olson committed
83
                layer::Layer::from_rust(self.inner.clone(), layer_idx, self.dtype, self.device_id)
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
            })
            .collect();
        PyList::new(py, layers)
    }

    fn __len__(&self) -> PyResult<usize> {
        Ok(self.num_layers())
    }

    fn __getitem__(&self, index: usize) -> PyResult<layer::Layer> {
        let num_layers = self.num_layers();
        if index >= num_layers {
            return Err(pyo3::exceptions::PyIndexError::new_err(format!(
                "Index {} out of range for Block with {} layers",
                index, num_layers
            )));
        }
Ryan Olson's avatar
Ryan Olson committed
101
        let layer = layer::Layer::from_rust(self.inner.clone(), index, self.dtype, self.device_id);
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
        Ok(layer)
    }

    fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
        // Reset iterator index at the beginning of each iteration
        // Use to_list() for iterating concurrently
        slf.py_itr_idx = 0;
        Ok(slf)
    }

    fn __next__(&mut self) -> PyResult<layer::Layer> {
        if self.py_itr_idx >= self.num_layers() {
            return Err(pyo3::exceptions::PyStopIteration::new_err(
                "No more items in Block",
            ));
        }
        let layer = layer::Layer::from_rust(
            self.inner.clone(),
            self.py_itr_idx,
Ryan Olson's avatar
Ryan Olson committed
121
            self.dtype,
122
123
124
125
126
127
            self.device_id,
        );
        self.py_itr_idx += 1;
        Ok(layer)
    }

128
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
129
    fn __dlpack__<'py>(
130
        &self,
131
        py: Python<'py>,
132
133
134
135
136
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
137
        // Return error if any arguments are provided
138
        if stream.is_some() {
139
140
141
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "stream argument is not supported",
            ));
142
143
        }
        if max_version.is_some() {
144
145
146
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "max_version argument is not supported",
            ));
147
148
        }
        if dl_device.is_some() {
149
150
151
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "dl_device argument is not supported",
            ));
152
153
        }
        if copy.is_some() {
154
155
156
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "copy argument is not supported",
            ));
157
158
        }

159
160
161
162
163
164
165
166
167
168
169
        // Extract all necessary data for dlpack
        let ptr: *mut std::ffi::c_void;
        let num_blocks: i64;
        let num_layers: i64;
        let num_outer_dims: i64;
        let page_size: i64;
        let inner_dim: i64;
        {
            let mut mutable_block = self.inner.lock().unwrap();
            ptr = match &mut *mutable_block {
                BlockType::Pinned(block) => {
Ryan Olson's avatar
Ryan Olson committed
170
171
172
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
                    let mut block_view_mut = block_data.block_view_mut().map_err(to_pyerr)?;
173
174
175
                    (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
                BlockType::Device(block) => {
Ryan Olson's avatar
Ryan Olson committed
176
177
178
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
                    let mut block_view_mut = block_data.block_view_mut().map_err(to_pyerr)?;
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
                    (unsafe { block_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
            };
            (num_blocks, num_layers, num_outer_dims, page_size, inner_dim) = match &*mutable_block {
                BlockType::Pinned(block) => (
                    block.num_blocks() as i64,
                    block.num_layers() as i64,
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
                BlockType::Device(block) => (
                    block.num_blocks() as i64,
                    block.num_layers() as i64,
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
197
            };
198
199
200
201
202
203
204
205
        }

        // Create the DLPack tensor
        dlpack::dlpack(
            py,
            self.inner.clone(),
            ptr,
            vec![num_blocks, num_layers, num_outer_dims, page_size, inner_dim],
Ryan Olson's avatar
Ryan Olson committed
206
            self.dtype,
207
208
            self.device_id,
        )
209
210
    }

211
212
213
    #[pyo3(signature = ())]
    fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        dlpack::dlpack_device(py, self.inner.clone(), self.device_id)
214
    }
215
}