// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![cfg(feature = "block-manager")]

use super::*;
use dynamo_llm::block_manager::block::BlockDataExt;
use dynamo_llm::block_manager::block::BlockDataProviderMut;
use pyo3::{types::PyTuple, PyObject, PyResult, Python};
use std::sync::{Arc, Mutex};

/// Python-visible handle to a single layer inside a block.
///
/// The block is held behind an `Arc<Mutex<_>>`, so it may be shared with other
/// holders of the same `Arc`; this struct only records which layer of it to expose.
#[pyclass]
pub struct Layer {
    /// Index of the exposed layer, passed to `layer_view_mut` when exporting.
    layer_idx: usize,
    /// Element type handed to `dlpack::dlpack` for exported tensors.
    dtype: dynamo_llm::common::dtype::DType,
    /// Device id handed to `dlpack::dlpack` / `dlpack::dlpack_device`.
    device_id: usize,
    /// The block this layer belongs to, guarded by a mutex.
    inner: Arc<Mutex<block::BlockType>>,
}

impl Layer {
    pub fn from_rust(
        block: Arc<Mutex<block::BlockType>>,
        layer_idx: usize,
        dtype: dynamo_llm::common::dtype::DType,
        device_id: usize,
    ) -> Self {
        Self {
            inner: block,
            layer_idx,
            dtype,
            device_id,
        }
    }
}

#[pymethods]
impl Layer {
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<bool>,
    ) -> PyResult<PyObject> {
        // Return error if any arguments are provided
        if stream.is_some() {
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "stream argument is not supported",
            ));
        }
        if max_version.is_some() {
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "max_version argument is not supported",
            ));
        }
        if dl_device.is_some() {
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "dl_device argument is not supported",
            ));
        }
        if copy.is_some() {
            return Err(pyo3::exceptions::PyNotImplementedError::new_err(
                "copy argument is not supported",
            ));
        }

        // Extract all necessary data for dlpack
        let ptr: *mut std::ffi::c_void;
        let num_outer_dims: i64;
        let page_size: i64;
        let inner_dim: i64;
        {
            let mut mutable_block = self.inner.lock().unwrap();
            ptr = match &mut *mutable_block {
                block::BlockType::Pinned(block) => {
Ryan Olson's avatar
Ryan Olson committed
91
92
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
93
                    let mut layer_view_mut =
Ryan Olson's avatar
Ryan Olson committed
94
                        block_data.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?;
95
96
97
                    (unsafe { layer_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
                block::BlockType::Device(block) => {
Ryan Olson's avatar
Ryan Olson committed
98
99
                    use dynamo_llm::block_manager::block::private::PrivateToken;
                    let block_data = block.block_data_mut(PrivateToken);
100
                    let mut layer_view_mut =
Ryan Olson's avatar
Ryan Olson committed
101
                        block_data.layer_view_mut(self.layer_idx, 0).map_err(to_pyerr)?;
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
                    (unsafe { layer_view_mut.as_mut_ptr() }) as *mut std::ffi::c_void
                }
            };
            (num_outer_dims, page_size, inner_dim) = match &*mutable_block {
                block::BlockType::Pinned(block) => (
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
                block::BlockType::Device(block) => (
                    block.num_outer_dims() as i64,
                    block.page_size() as i64,
                    block.inner_dim() as i64,
                ),
            };
        }

        // Create the DLPack tensor
        dlpack::dlpack(
            py,
            self.inner.clone(),
            ptr,
            vec![1, 1, num_outer_dims, page_size, inner_dim],
Ryan Olson's avatar
Ryan Olson committed
125
            self.dtype,
126
127
128
129
130
131
132
133
134
            self.device_id,
        )
    }

    #[pyo3(signature = ())]
    fn __dlpack_device__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
        dlpack::dlpack_device(py, self.inner.clone(), self.device_id)
    }
}