block_manager.rs 8.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![cfg(feature = "block-manager")]

use super::*;
use pyo3::PyResult;

mod block;
mod block_list;
23
24
mod dlpack;
mod layer;
25
26
27

/// Add bindings from this crate to the provided module.
///
/// Registers the `Layer`, `Block`, `BlockList` and `BlockManager` classes on
/// `m`, in that order.
///
/// # Errors
///
/// Returns the first error reported by `add_class`; classes after the failing
/// one are not registered.
pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<layer::Layer>()?;
    m.add_class::<block::Block>()?;
    m.add_class::<block_list::BlockList>()?;
    m.add_class::<BlockManager>()?;
    Ok(())
}

/// Python-visible handle to a shared `ReferenceBlockManager`.
///
/// Besides the manager itself, the handle remembers the element type and the
/// device id it was constructed with, because both are passed along to every
/// `BlockList` it hands out.
#[pyclass]
pub struct BlockManager {
    /// Reference-counted block manager; cloned into the futures created by the
    /// async allocation methods.
    inner: Arc<dynamo_llm::block_manager::ReferenceBlockManager>,
    // TODO: Metadata should be stored in the block manager?
    /// Element type selected at construction time (FP16 when none was given).
    dtype: dynamo_llm::common::dtype::DType,
    /// Device id given at construction time; also used to create the device
    /// allocator when a device layout is requested.
    device_id: usize,
}

#[pymethods]
impl BlockManager {
    #[new]
46
    #[pyo3(signature = (worker_id, num_layer, outer_dim, page_size, inner_dim, dtype=None, host_num_blocks=None, device_num_blocks=None, device_id=0))]
47
48
49
    fn new(
        worker_id: u64,
        num_layer: usize,
50
        outer_dim: usize,
51
52
53
54
55
56
57
58
59
60
61
        page_size: usize,
        inner_dim: usize,
        dtype: Option<String>,
        host_num_blocks: Option<usize>,
        device_num_blocks: Option<usize>,
        device_id: usize,
    ) -> PyResult<Self> {
        let mut config = dynamo_llm::block_manager::KvBlockManagerConfig::builder().runtime(
            dynamo_llm::block_manager::KvManagerRuntimeConfig::builder()
                .worker_id(worker_id)
                .build()
62
                .map_err(to_pyerr)?,
63
64
65
        );
        let mut model_config = dynamo_llm::block_manager::KvManagerModelConfig::builder()
            .num_layers(num_layer)
66
            .outer_dim(outer_dim)
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
            .page_size(page_size)
            .inner_dim(inner_dim);
        let mut dtype_ = dynamo_llm::common::dtype::DType::FP16; // Default in block_manager config
        if let Some(dtype_str) = dtype {
            dtype_ = match dtype_str.as_str() {
                "fp8" | "FP8" => dynamo_llm::common::dtype::DType::FP8,
                "fp16" | "FP16" => dynamo_llm::common::dtype::DType::FP16,
                "bf16" | "BF16" => dynamo_llm::common::dtype::DType::BF16,
                "fp32" | "FP32" => dynamo_llm::common::dtype::DType::FP32,
                "u8" | "U8" => dynamo_llm::common::dtype::DType::U8,
                "u16" | "U16" => dynamo_llm::common::dtype::DType::U16,
                "u32" | "U32" => dynamo_llm::common::dtype::DType::U32,
                "u64" | "U64" => dynamo_llm::common::dtype::DType::U64,
                "i8" | "I8" => dynamo_llm::common::dtype::DType::I8,
                "i16" | "I16" => dynamo_llm::common::dtype::DType::I16,
                "i32" | "I32" => dynamo_llm::common::dtype::DType::I32,
                "i64" | "I64" => dynamo_llm::common::dtype::DType::I64,
                _ => {
                    return Err(pyo3::exceptions::PyValueError::new_err(format!(
                        "Unsupported dtype: {}",
                        dtype_str
                    )))
                }
            };
        }
        model_config = model_config.dtype(dtype_.clone());
93
        config = config.model(model_config.build().map_err(to_pyerr)?);
94
95
96
97
        if let Some(host_num_blocks) = host_num_blocks {
            config = config.host_layout(
                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
                    .num_blocks(host_num_blocks)
98
99
100
101
                    .allocator(
                        dynamo_llm::block_manager::storage::PinnedAllocator::new()
                            .map_err(to_pyerr)?,
                    )
102
                    .build()
103
                    .map_err(to_pyerr)?,
104
105
106
107
108
109
110
111
            );
        }
        if let Some(device_num_blocks) = device_num_blocks {
            config = config.device_layout(
                dynamo_llm::block_manager::KvManagerLayoutConfig::builder()
                    .num_blocks(device_num_blocks)
                    .allocator(
                        dynamo_llm::block_manager::storage::DeviceAllocator::new(device_id)
112
                            .map_err(to_pyerr)?,
113
114
                    )
                    .build()
115
                    .map_err(to_pyerr)?,
116
117
            );
        }
118
119
        let config = config.build().map_err(to_pyerr)?;
        let tokio_runtime = pyo3_async_runtimes::tokio::get_runtime();
120
        Ok(BlockManager {
121
122
123
124
125
126
127
            inner: Arc::from(
                tokio_runtime
                    .block_on(async {
                        dynamo_llm::block_manager::ReferenceBlockManager::new(config)
                    })
                    .map_err(to_pyerr)?,
            ),
128
129
130
131
132
133
134
135
136
            dtype: dtype_,
            device_id: device_id,
        })
    }

    fn allocate_host_blocks_blocking(&self, count: usize) -> PyResult<block_list::BlockList> {
        let blocks = self
            .inner
            .host()
137
138
139
            .ok_or_else(|| {
                pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
            })?
140
            .allocate_blocks_blocking(count)
141
            .map_err(to_pyerr)?;
142
143
144
145
146
147
148
149
150
151
152
153
        // Wrap each block in an enum accounting for Pinned & Device block
        let blocks = blocks
            .into_iter()
            .map(|b| block::BlockType::Pinned(b))
            .collect();
        Ok(block_list::BlockList::from_rust(
            blocks,
            self.dtype.clone(),
            self.device_id,
        ))
    }

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    #[pyo3(signature = (count))]
    fn allocate_host_blocks<'py>(
        &self,
        py: Python<'py>,
        count: usize,
    ) -> PyResult<Bound<'py, PyAny>> {
        let inner = self.inner.clone();
        let dtype = self.dtype.clone();
        let device_id = self.device_id;
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let blocks = inner
                .host()
                .ok_or_else(|| {
                    pyo3::exceptions::PyRuntimeError::new_err("Host allocator not available")
                })?
                .allocate_blocks(count)
                .await
                .map_err(to_pyerr)?;
            // Wrap each block in an enum accounting for Pinned & Device block
            let blocks = blocks
                .into_iter()
                .map(|b| block::BlockType::Pinned(b))
                .collect();
            Ok(block_list::BlockList::from_rust(blocks, dtype, device_id))
        })
    }

181
182
183
184
    fn allocate_device_blocks_blocking(&self, count: usize) -> PyResult<block_list::BlockList> {
        let blocks = self
            .inner
            .device()
185
186
187
            .ok_or_else(|| {
                pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
            })?
188
            .allocate_blocks_blocking(count)
189
            .map_err(to_pyerr)?;
190
191
192
193
194
195
196
197
198
199
200
        // Wrap each block in an enum accounting for Pinned & Device block
        let blocks = blocks
            .into_iter()
            .map(|b| block::BlockType::Device(b))
            .collect();
        Ok(block_list::BlockList::from_rust(
            blocks,
            self.dtype.clone(),
            self.device_id,
        ))
    }
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227

    #[pyo3(signature = (count))]
    fn allocate_device_blocks<'py>(
        &self,
        py: Python<'py>,
        count: usize,
    ) -> PyResult<Bound<'py, PyAny>> {
        let inner = self.inner.clone();
        let dtype = self.dtype.clone();
        let device_id = self.device_id;
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let blocks = inner
                .device()
                .ok_or_else(|| {
                    pyo3::exceptions::PyRuntimeError::new_err("Device allocator not available")
                })?
                .allocate_blocks(count)
                .await
                .map_err(to_pyerr)?;
            // Wrap each block in an enum accounting for Pinned & Device block
            let blocks = blocks
                .into_iter()
                .map(|b| block::BlockType::Device(b))
                .collect();
            Ok(block_list::BlockList::from_rust(blocks, dtype, device_id))
        })
    }
228
}