// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # NIXL Integration for Block Layouts 🤝
//!
//! This module extends the core block layout functionalities defined in the parent `layout` module
//! with [NIXL](http://github.com/ai-dynamo/nixl) specific capabilities. It enables block layouts,
//! whose underlying storage is NIXL-registerable, to be registered with a NIXL agent and
//! serialized into a format suitable for sharing and reconstruction in distributed environments.
//!
//! ## Key Features & Components
//!
//! ### 1. NIXL-Specific Layout Traits
//! - [`NixlLayout`]: An umbrella trait that augments a [`BlockLayout`]. It requires the layout's
//!   associated `StorageType` to implement [`NixlRegisterableStorage`]. This trait provides the
//!   `nixl_register` method to register all underlying storage regions of the layout with a NIXL agent.
//! - [`BlockLayoutNixlStorage`]: A trait implemented by layouts to provide NIXL-specific memory
//!   information like `mem_type` and `device_id` directly from the layout structure, typically
//!   derived from its underlying storage.
//! - [`ToSerializedNixlBlockLayout`]: Implemented by layouts that can be converted into a
//!   [`SerializedNixlBlockLayout`]. This involves capturing the layout configuration and the NIXL
//!   descriptors of its storage.
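//!
//! The umbrella-trait pattern behind [`NixlLayout`] can be illustrated with standard-library types
//! only. This is a minimal sketch, not this module's implementation: `Registerable`, `Layout`,
//! `RegisterableLayout`, `Region`, `Contiguous`, and `register_demo` are hypothetical stand-ins,
//! and the agent is assumed to be a plain string rather than a real NIXL agent.
//!
//! ```rust
//! // Hypothetical stand-in for a storage type that can be registered with an agent.
//! trait Registerable {
//!     fn register(&mut self, agent: &str) -> Result<(), String>;
//! }
//!
//! // Hypothetical stand-in for a block layout that exposes its storage regions.
//! trait Layout {
//!     type StorageType;
//!     fn storage_mut(&mut self) -> Vec<&mut Self::StorageType>;
//! }
//!
//! // Umbrella trait: available on any layout whose storage is registerable.
//! trait RegisterableLayout: Layout {
//!     fn register_all(&mut self, agent: &str) -> Result<(), String>;
//! }
//!
//! // Blanket impl: no per-layout code is needed, only the bound on the storage type.
//! impl<T> RegisterableLayout for T
//! where
//!     T: Layout + ?Sized,
//!     T::StorageType: Registerable,
//! {
//!     fn register_all(&mut self, agent: &str) -> Result<(), String> {
//!         for storage in self.storage_mut() {
//!             storage.register(agent)?;
//!         }
//!         Ok(())
//!     }
//! }
//!
//! struct Region {
//!     registered_with: Option<String>,
//! }
//!
//! impl Registerable for Region {
//!     fn register(&mut self, agent: &str) -> Result<(), String> {
//!         if self.registered_with.is_some() {
//!             return Err("region is already registered".to_string());
//!         }
//!         self.registered_with = Some(agent.to_string());
//!         Ok(())
//!     }
//! }
//!
//! struct Contiguous {
//!     regions: Vec<Region>,
//! }
//!
//! impl Layout for Contiguous {
//!     type StorageType = Region;
//!     fn storage_mut(&mut self) -> Vec<&mut Region> {
//!         self.regions.iter_mut().collect()
//!     }
//! }
//!
//! // Registers every region of a two-region layout and reports who owns each one.
//! fn register_demo() -> Vec<Option<String>> {
//!     let mut layout = Contiguous {
//!         regions: vec![
//!             Region { registered_with: None },
//!             Region { registered_with: None },
//!         ],
//!     };
//!     layout.register_all("agent-a").expect("registration failed");
//!     layout.regions.into_iter().map(|r| r.registered_with).collect()
//! }
//!
//! fn main() {
//!     println!("{:?}", register_demo());
//! }
//! ```
//!
//! Because the bound sits on the associated storage type, a layout over non-registerable storage
//! simply does not get the `register_all` method in this sketch.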
//!
//! ### 2. Serializable NIXL Layout
//! - [`SerializedNixlBlockLayout`]: A struct that holds the serialized representation (as `Vec<u8>`)
//!   of a NIXL-compatible block layout. It can be deserialized to reconstruct the layout, typically
//!   on a remote node, assuming the described NIXL memory regions are accessible.
//! - `NixlBlockLayoutKinds`: An internal enum used during serialization to differentiate between
//!   different types of layouts (e.g., `FullyContiguous`).
//! - `SerializableNixlLayout<C>`: An internal generic struct that captures the configuration (`C`),
//!   base offset, NIXL storage descriptors, and storage type for a specific layout kind.
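//!
//! The idea of a kind-tagged payload that round-trips through `Vec<u8>` can be sketched without
//! any external crates. This module itself uses `serde_json`; the hand-rolled little-endian
//! encoding below is a swapped-in, std-only substitute chosen purely for illustration. The
//! `Descriptor` fields, `LayoutKind`, `encode`, `decode`, and `round_trip` are all hypothetical
//! and do not reflect the real `NixlStorage` layout or wire format.
//!
//! ```rust
//! // Hypothetical descriptor; the real NixlStorage fields may differ.
//! #[derive(Debug, Clone, PartialEq)]
//! struct Descriptor {
//!     addr: u64,
//!     size: u64,
//!     device_id: u64,
//! }
//!
//! // One variant per layout kind, mirroring the role of an internal "kinds" enum.
//! #[derive(Debug, Clone, PartialEq)]
//! enum LayoutKind {
//!     FullyContiguous { base_offset: u64, descriptors: Vec<Descriptor> },
//! }
//!
//! const TAG_FULLY_CONTIGUOUS: u8 = 0;
//!
//! // Writes a one-byte kind tag followed by the fields as little-endian u64 values.
//! fn encode(kind: &LayoutKind) -> Vec<u8> {
//!     let mut out = Vec::new();
//!     match kind {
//!         LayoutKind::FullyContiguous { base_offset, descriptors } => {
//!             out.push(TAG_FULLY_CONTIGUOUS);
//!             out.extend_from_slice(&base_offset.to_le_bytes());
//!             out.extend_from_slice(&(descriptors.len() as u64).to_le_bytes());
//!             for d in descriptors {
//!                 for field in [d.addr, d.size, d.device_id] {
//!                     out.extend_from_slice(&field.to_le_bytes());
//!                 }
//!             }
//!         }
//!     }
//!     out
//! }
//!
//! // Reads one little-endian u64 at `pos` and advances `pos` past it.
//! fn read_u64(bytes: &[u8], pos: &mut usize) -> Result<u64, String> {
//!     let end = (*pos).checked_add(8).ok_or("offset overflow")?;
//!     let chunk = bytes.get(*pos..end).ok_or("truncated input")?;
//!     *pos = end;
//!     let mut buf = [0u8; 8];
//!     buf.copy_from_slice(chunk);
//!     Ok(u64::from_le_bytes(buf))
//! }
//!
//! // Dispatches on the kind tag and enforces the single-descriptor rule for this kind.
//! fn decode(bytes: &[u8]) -> Result<LayoutKind, String> {
//!     let (&tag, _) = bytes.split_first().ok_or("empty input")?;
//!     let mut pos = 1;
//!     match tag {
//!         TAG_FULLY_CONTIGUOUS => {
//!             let base_offset = read_u64(bytes, &mut pos)?;
//!             let count = read_u64(bytes, &mut pos)?;
//!             if count != 1 {
//!                 return Err(format!("expected exactly one descriptor, got {}", count));
//!             }
//!             let addr = read_u64(bytes, &mut pos)?;
//!             let size = read_u64(bytes, &mut pos)?;
//!             let device_id = read_u64(bytes, &mut pos)?;
//!             Ok(LayoutKind::FullyContiguous {
//!                 base_offset,
//!                 descriptors: vec![Descriptor { addr, size, device_id }],
//!             })
//!         }
//!         other => Err(format!("unknown layout tag {}", other)),
//!     }
//! }
//!
//! // Encodes a layout and checks that decoding yields an equal value.
//! fn round_trip() -> bool {
//!     let original = LayoutKind::FullyContiguous {
//!         base_offset: 64,
//!         descriptors: vec![Descriptor { addr: 0x1000, size: 4096, device_id: 0 }],
//!     };
//!     decode(&encode(&original)) == Ok(original)
//! }
//!
//! fn main() {
//!     println!("round trip ok: {}", round_trip());
//! }
//! ```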
//!
//! ### 3. Integration with Core Layouts
//! The module provides implementations of these NIXL traits for concrete layout types from the
//! parent module, such as [`FullyContiguous`]. For example:
//! - `FullyContiguous<S>` (where `S:` [`NixlRegisterableStorage`]) implements [`NixlLayout`], allowing
//!   its storage to be registered.
//! - It also implements [`ToSerializedNixlBlockLayout`], enabling its configuration and NIXL storage
//!   descriptors to be serialized.
//!
//! ### 4. Layout Creation and Allocation Extensions
//! The [`LayoutConfig`] from the parent module is extended with methods like:
//! - `create_layout`: To create a NIXL-aware layout from existing NIXL-registerable storage.
//! - `allocate_layout`: To allocate storage using a NIXL-registerable storage allocator and then
//!   create the NIXL-aware layout.
//!
//! ## Usage Flow
//!
//! 1.  **Create/Allocate Layout**: A block layout (e.g., [`FullyContiguous`]) is created or allocated,
//!     ensuring its underlying storage is NIXL-compatible (e.g., using [`SystemStorage`] that implements
//!     [`NixlRegisterableStorage`]).
//! 2.  **Register with NIXL**: The [`NixlLayout::nixl_register`] method is called on the
//!     layout instance with a [`NixlAgent`].
//! 3.  **Serialize**: The [`ToSerializedNixlBlockLayout::serialize`] method is used to get a
//!     [`SerializedNixlBlockLayout`].
//! 4.  **Transmit**: The [`SerializedNixlBlockLayout`] (or its byte representation) is sent to another
//!     process/node.
//! 5.  **Deserialize**: On the receiving end, [`SerializedNixlBlockLayout::deserialize`] is called to
//!     reconstruct an `Arc<dyn BlockLayout<StorageType = NixlStorage>>`. This reconstructed layout now
//!     refers to the remote NIXL memory regions.
//!
//! ```rust
//! use dynamo_llm::block_manager::layout::{LayoutConfig, LayoutType};
//! use dynamo_llm::block_manager::layout::nixl::{NixlLayout, ToSerializedNixlBlockLayout, SerializedNixlBlockLayout};
//! use dynamo_llm::block_manager::storage::nixl::NixlAgent;
//! use dynamo_llm::block_manager::storage::PinnedAllocator; // Assumes PinnedStorage implements NixlRegisterableStorage
//! use std::sync::Arc;
//!
//! // Configuration
//! let config = LayoutConfig::builder()
//!     .num_blocks(10)
//!     .num_layers(2)
//!     .page_size(4)
//!     .inner_dim(13)
//!     .build().unwrap();
//!
//! // 1. Allocate a NIXL-compatible layout
//! let allocator = Arc::new(PinnedAllocator::new().unwrap()); // PinnedAllocator provides NixlRegisterable PinnedStorage
//! let mut layout = config.allocate_layout(LayoutType::FullyContiguous, allocator).unwrap();
//!
//! // 2. Register with NIXL Agent
//! let agent = NixlAgent::new("my_agent").unwrap();
//! layout.nixl_register(&agent, None).unwrap();
//!
//! // 3. Serialize the layout
//! let serialized_layout = layout.serialize().unwrap();
//!
//! // 4. (Transmit serialized_layout to another process)
//!
//! // 5. Deserialize on the other end
//! let reconstructed_layout = SerializedNixlBlockLayout::deserialize(&serialized_layout).unwrap();
//! println!("Reconstructed layout refers to storage type: {:?}", reconstructed_layout.storage_type());
//! ```
//!
//! This module effectively bridges the local layout definitions with the requirements of distributed memory management via NIXL.

use crate::block_manager::storage::StorageType;

use super::{BlockLayout, BlockLayoutConfig, LayoutConfig, LayoutError, LayoutType};

use super::super::storage::{
    nixl::{MemType, NixlAgent, NixlRegisterableStorage, NixlStorage, OptArgs},
    Storage, StorageAllocator,
};
use super::{FullyContiguous, FullyContiguousConfig};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Extends [BlockLayout] with NIXL-specific methods for registering with a NIXL agent.
pub trait NixlLayout: BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout {
    /// Register the layout with a NIXL agent
    ///
    /// This will register all the individual memory regions associated with the [BlockLayout].
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> anyhow::Result<()>;
}

/// Trait for providing NIXL-specific memory information
pub trait BlockLayoutNixlStorage {
    /// Returns the memory type of the storage
    fn mem_type(&self) -> MemType;

    /// Returns the device ID of the storage
    fn device_id(&self) -> u64;
}

// Blanket impl for every BlockLayout whose associated StorageType is NIXL-registerable
impl<T> NixlLayout for T
where
    T: BlockLayout + BlockLayoutNixlStorage + ToSerializedNixlBlockLayout + ?Sized, // `?Sized` also covers unsized types such as trait objects
    T::StorageType: NixlRegisterableStorage, // T's associated StorageType must implement NixlRegisterableStorage
{
    fn nixl_register(
        &mut self,
        agent: &NixlAgent,
        opt_args: Option<&OptArgs>,
    ) -> anyhow::Result<()> {
        for storage in self.storage_mut() {
            storage.nixl_register(agent, opt_args)?;
        }
        Ok(())
    }
}

impl LayoutConfig {
    /// Create a new NIXL-aware layout from existing NIXL-registerable storage.
    pub fn create_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        storage: Vec<S>,
    ) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
        match layout_type {
            LayoutType::FullyContiguous => FullyContiguous::new(self.clone(), storage),
        }
    }

    /// Allocate a new NIXL-aware layout using a NIXL-registerable storage allocator.
    pub fn allocate_layout<S: Storage + NixlRegisterableStorage>(
        &self,
        layout_type: LayoutType,
        allocator: Arc<dyn StorageAllocator<S>>,
    ) -> Result<impl NixlLayout<StorageType = S>, LayoutError> {
        match layout_type {
            LayoutType::FullyContiguous => {
                FullyContiguous::allocate(self.clone(), allocator.as_ref())
            }
        }
    }
}

/// Trait to convert a BlockLayout instance into its NIXL-specific serializable representation.
pub trait ToSerializedNixlBlockLayout: BlockLayout<StorageType: NixlRegisterableStorage> {
    /// Converts the layout into a serializable format, ensuring it's backed by NIXL storage.
    /// Returns an error if the layout is not backed by storage providing NIXL descriptors.
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError>;
}

/// Serializable representation of a BlockLayout backed by NIXL storage.
#[derive(Serialize, Deserialize, Clone)]
pub struct SerializedNixlBlockLayout(Vec<u8>);

/// Enum representing the serializable state of different BlockLayout types
/// specifically when backed by NIXL-compatible storage.
#[derive(Serialize, Deserialize, Debug, Clone)]
enum NixlBlockLayoutKinds {
    FullyContiguous(SerializableNixlLayout<FullyContiguousConfig>),
    // Add variants for other layout types here
}

/// Serializable representation of a layout with configuration `C` backed by NIXL storage.
#[derive(Serialize, Deserialize, Debug, Clone)]
struct SerializableNixlLayout<C: BlockLayoutConfig> {
    config: C,
    base_offset: usize,
    storage_descriptors: Vec<NixlStorage>,
    storage_type: StorageType,
}

impl<C> SerializableNixlLayout<C>
where
    C: BlockLayoutConfig + Serialize + for<'de> Deserialize<'de> + Clone + std::fmt::Debug,
{
    /// Create a new SerializableNixlLayout
    fn new(
        config: C,
        base_offset: usize,
        storage_descriptors: Vec<NixlStorage>,
        storage_type: StorageType,
    ) -> Self {
        Self {
            config,
            base_offset,
            storage_descriptors,
            storage_type,
        }
    }
}

impl<S: NixlRegisterableStorage> ToSerializedNixlBlockLayout for FullyContiguous<S> {
    fn serialize(&self) -> Result<SerializedNixlBlockLayout, LayoutError> {
        // Capture the layout configuration and base offset for serialization
        let config = self.config.clone();
        let base_offset = self.base_offset;

        let storages = self.storage();

        if storages.len() != 1 {
            return Err(LayoutError::InvalidConfig(
                "FullyContiguous serialization expects exactly one storage element"
                    .to_string(),
            ));
        }

        // FullyContiguous holds its storage in a Vec, which must contain exactly one element.
        let storage_instance = storages.first().ok_or_else(|| {
            LayoutError::OperationFailed("FullyContiguous requires one storage element".to_string())
        })?;

        let storage_descriptors =
            unsafe { storage_instance.as_nixl_descriptor() }.ok_or_else(|| {
                LayoutError::OperationFailed(
                    "Storage does not provide NIXL descriptors for serialization".to_string(),
                )
            })?;

        let serializable_data = SerializableNixlLayout::new(
            config,
            base_offset,
            vec![storage_descriptors],
            self.storage_type(),
        );

        let nixl_block_layout = NixlBlockLayoutKinds::FullyContiguous(serializable_data);

        Ok(SerializedNixlBlockLayout(serde_json::to_vec(
            &nixl_block_layout,
        )?))
    }
}

impl SerializedNixlBlockLayout {
    /// Reconstructs a dynamic BlockLayout trait object backed by NixlStorage
    /// from the serialized layout information.
    /// Assumes the NixlStorage regions described within already exist and are valid.
    pub fn deserialize(
        &self,
    ) -> Result<Arc<dyn BlockLayout<StorageType = NixlStorage>>, LayoutError> {
        let nixl_block_layout: NixlBlockLayoutKinds = serde_json::from_slice(&self.0)?;
        match nixl_block_layout {
            NixlBlockLayoutKinds::FullyContiguous(config) => {
                if config.storage_descriptors.len() != 1 {
                    return Err(LayoutError::InvalidConfig(
                        "FullyContiguous reconstruction expects exactly one NixlStorage descriptor"
                            .to_string(),
                    ));
                }
                // Clone the single NixlStorage descriptor to become the storage instance
                let storage = config.storage_descriptors[0].clone();

                // Use the internal constructor which skips allocation checks
                let layout = FullyContiguous::new_internal(
                    config.config.clone(),
                    storage, // Pass the NixlStorage instance
                    config.base_offset,
                    config.storage_type,
                )?;
                Ok(Arc::new(layout))
            } // Handle other variants when added...
        }
    }
}

impl<S> BlockLayoutNixlStorage for FullyContiguous<S>
where
    S: Storage + NixlRegisterableStorage,
{
    fn mem_type(&self) -> MemType {
        self.storage.mem_type()
    }

    fn device_id(&self) -> u64 {
        self.storage.device_id()
    }
}

#[cfg(test)]
mod tests {
    use super::super::*;
    use super::*;
    use crate::block_manager::storage::SystemAllocator;
    use dynamo_runtime::logging::init as init_logging;

    #[test]
    fn test_nixl_layout() {
        init_logging();

        let config = LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(2)
            .page_size(4)
            .inner_dim(13)
            .build()
            .unwrap();

        config.validate().unwrap();

        let mut layout = FullyContiguous::allocate(config, &SystemAllocator).unwrap();
        let agent = NixlAgent::new("test").unwrap();

        tracing::info!("Registering layout");
        layout.nixl_register(&agent, None).unwrap();
        tracing::info!("Layout registered");
        let local_storage_type = layout.storage_type();

        let serialized = layout.serialize().unwrap();

        let remote_layout = SerializedNixlBlockLayout::deserialize(&serialized).unwrap();
        println!("Nixl layout: {:?}", remote_layout);
        let remote_storage_type = remote_layout.storage_type();

        assert_eq!(local_storage_type, remote_storage_type);

        drop(layout);
        tracing::info!("Layout dropped");
    }
}