// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! NIXL registration wrapper for storage types.

mod agent;
mod config;

9
use super::{MemoryDescriptor, StorageKind};
10
11
use std::any::Any;
use std::fmt;
12
use std::sync::Arc;
13
14
15
16

pub use agent::NixlAgent;
pub use config::NixlBackendConfig;

17
pub use nixl_sys::{
Ryan Olson's avatar
Ryan Olson committed
18
19
    Agent, MemType, NotificationMap, OptArgs, RegistrationHandle, XferDescList, XferOp,
    XferRequest, is_stub,
20
};
21
pub use serde::{Deserialize, Serialize};
22
23
24
25
26
27
28
29
30

/// Trait for storage types that can be registered with NIXL.
///
/// Implementors describe their backing memory region as a raw tuple. Within
/// this module that tuple is converted into a [`NixlDescriptor`] both when
/// registering the storage and when reporting the descriptor of an already
/// wrapped storage.
pub trait NixlCompatible {
    /// Get parameters needed for NIXL registration.
    ///
    /// Returns `(ptr, size, mem_type, device_id)`:
    /// * `ptr` - base address of the region (stored as `NixlDescriptor::addr`)
    /// * `size` - size of the region (stored as `NixlDescriptor::size`)
    /// * `mem_type` - NIXL memory type of the region
    /// * `device_id` - device identifier associated with the region
    fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
}

31
32
33
34
35
36
37
38
39
40
/// Combined trait for memory that can be registered with NIXL.
///
/// This supertrait enables type erasure via `Arc<dyn NixlMemory>`.
/// Any type implementing both `MemoryDescriptor` and `NixlCompatible`
/// automatically implements this trait via the blanket implementation.
///
/// The trait has no methods of its own; it only bundles its two supertraits
/// so that a single trait object exposes both sets of methods.
pub trait NixlMemory: MemoryDescriptor + NixlCompatible {}

// Blanket impl - any type with both traits automatically implements NixlMemory.
// The `?Sized` bound lets the impl also cover unsized types.
impl<T: MemoryDescriptor + NixlCompatible + ?Sized> NixlMemory for T {}

41
/// NIXL descriptor containing registration information.
42
43
44
///
/// This struct holds the information needed to describe a memory region
/// to NIXL for transfer operations.
45
#[derive(Debug, Clone, Serialize, Deserialize)]
46
pub struct NixlDescriptor {
47
    /// Base address of the memory region.
48
    pub addr: u64,
49
    /// Size of the memory region in bytes.
50
    pub size: usize,
51
    /// Type of memory (host, device, etc.).
52
    pub mem_type: MemType,
53
    /// Device identifier (GPU index for device memory, 0 for host memory).
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    pub device_id: u64,
}

// Exposes the descriptor's address/size pair through `nixl_sys::MemoryRegion`.
impl nixl_sys::MemoryRegion for NixlDescriptor {
    /// Returns the stored base address reinterpreted as a raw pointer.
    ///
    /// The descriptor only stores an integer address; nothing in this type
    /// establishes that the resulting pointer is valid to dereference.
    unsafe fn as_ptr(&self) -> *const u8 {
        self.addr as *const u8
    }

    /// Returns the stored size of the region.
    fn size(&self) -> usize {
        self.size
    }
}

// Exposes the descriptor's memory type and device id through the
// `nixl_sys::NixlDescriptor` trait (same name as this struct, different crate).
impl nixl_sys::NixlDescriptor for NixlDescriptor {
    /// Returns the stored memory type.
    fn mem_type(&self) -> MemType {
        self.mem_type
    }

    /// Returns the stored device identifier.
    fn device_id(&self) -> u64 {
        self.device_id
    }
}

/// View trait for accessing registration information without unwrapping.
///
/// Both methods take `&self`, so callers can inspect the registration
/// without consuming the wrapper that holds the storage.
pub trait RegisteredView {
    /// Get the name of the NIXL agent that registered this memory.
    fn agent_name(&self) -> &str;

    /// Get the NIXL descriptor for this registered memory.
    ///
    /// Returns an owned [`NixlDescriptor`] value.
    fn descriptor(&self) -> NixlDescriptor;
}

/// Wrapper for storage that has been registered with NIXL.
///
/// This wrapper ensures proper drop order: the registration handle is
/// dropped before the storage, ensuring deregistration happens before
/// the memory is freed.
///
/// Rust drops struct fields in declaration order, which on its own would drop
/// `storage` before `handle`; the `Drop` impl for this type therefore takes
/// the handle out and drops it explicitly first.
pub struct NixlRegistered<S: NixlCompatible> {
    /// The wrapped storage.
    storage: S,
    /// Registration handle returned by the agent. `None` when the wrapper was
    /// created without a new registration (the storage already reported a
    /// NIXL descriptor) or after the handle has been taken.
    handle: Option<RegistrationHandle>,
    /// Name of the agent passed to the registration call.
    agent_name: String,
}

impl<S: NixlCompatible> Drop for NixlRegistered<S> {
    /// Releases the registration handle ahead of the wrapped storage.
    fn drop(&mut self) {
        // Pull the handle out and drop it now, while `storage` is still alive.
        // The remaining fields are dropped by the compiler once this returns.
        if let Some(registration) = self.handle.take() {
            drop(registration);
        }
    }
}

impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
    /// Formats the wrapper as a struct with the storage, the agent name, and
    /// a boolean telling whether a registration handle is held (the handle
    /// itself is not printed).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_handle = self.handle.is_some();
        let mut builder = f.debug_struct("NixlRegistered");
        builder.field("storage", &self.storage);
        builder.field("agent_name", &self.agent_name);
        builder.field("handle", &has_handle);
        builder.finish()
    }
}

115
impl<S: MemoryDescriptor + NixlCompatible + 'static> MemoryDescriptor for NixlRegistered<S> {
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    fn addr(&self) -> usize {
        self.storage.addr()
    }

    fn size(&self) -> usize {
        self.storage.size()
    }

    fn storage_kind(&self) -> StorageKind {
        self.storage.storage_kind()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        Some(self.descriptor())
    }
}

137
impl<S: MemoryDescriptor + NixlCompatible> RegisteredView for NixlRegistered<S> {
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    fn agent_name(&self) -> &str {
        &self.agent_name
    }

    fn descriptor(&self) -> NixlDescriptor {
        let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
        NixlDescriptor {
            addr: ptr as u64,
            size,
            mem_type,
            device_id,
        }
    }
}

153
impl<S: MemoryDescriptor + NixlCompatible> NixlRegistered<S> {
    /// Get a reference to the underlying storage.
    pub fn storage(&self) -> &S {
        &self.storage
    }

    /// Get a mutable reference to the underlying storage.
    pub fn storage_mut(&mut self) -> &mut S {
        &mut self.storage
    }

    /// Check whether this wrapper currently holds a registration handle.
    ///
    /// This is `false` for wrappers created without a new registration
    /// (i.e. with `handle: None`).
    pub fn is_registered(&self) -> bool {
        self.handle.is_some()
    }

    /// Consume this wrapper and return the underlying storage.
    ///
    /// This will deregister the storage from NIXL: the registration handle
    /// (if any) is dropped before the storage is handed back.
    pub fn into_storage(mut self) -> S {
        // Drop the registration first, while the storage is still owned here.
        drop(self.handle.take());

        // `NixlRegistered` implements `Drop`, so fields cannot be moved out of
        // it directly. Suppress the destructor and move the fields out by hand.
        let mut this = std::mem::ManuallyDrop::new(self);

        // SAFETY: `this` is wrapped in `ManuallyDrop`, so neither its `Drop`
        // impl nor its field destructors will run afterwards. Each field is
        // therefore disposed of exactly once:
        //   * `storage` is moved out with `ptr::read` and returned to the caller;
        //   * `agent_name` is dropped in place and never touched again;
        //   * `handle` was set to `None` by `take()` above, so it owns nothing.
        // Both pointers come from references to live, properly aligned fields.
        unsafe {
            let storage = std::ptr::read(&this.storage);
            std::ptr::drop_in_place(&mut this.agent_name);
            storage
        }
    }
}

/// Register storage with a NIXL agent.
///
/// This consumes the storage and returns a `NixlRegistered` wrapper that
/// manages the registration lifetime. The registration handle will be
/// automatically dropped when the wrapper is dropped, ensuring proper
/// cleanup order.
///
/// # Arguments
/// * `storage` - The storage to register (consumed)
/// * `agent` - The NIXL agent to register with
/// * `opt` - Optional arguments for registration
///
/// # Returns
/// A `NixlRegistered` wrapper containing the storage and registration handle.
pub fn register_with_nixl<S>(
    storage: S,
199
    agent: &Agent,
200
201
202
    opt: Option<&OptArgs>,
) -> std::result::Result<NixlRegistered<S>, S>
where
203
    S: MemoryDescriptor + NixlCompatible,
204
{
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
    // let storage_kind = storage.storage_kind();

    // // Determine if registration is needed based on storage type and available backends
    // let should_register = match storage_kind {
    //     StorageKind::System | StorageKind::Pinned => {
    //         // System/Pinned memory needs UCX for remote transfers
    //         agent.has_backend("UCX") || agent.has_backend("POSIX")
    //     }
    //     StorageKind::Device(_) => {
    //         // Device memory needs UCX for remote transfers OR GDS for direct disk transfers
    //         agent.has_backend("UCX") || agent.has_backend("GDS_MT")
    //     }
    //     StorageKind::Disk(_) => {
    //         // Disk storage needs POSIX for regular I/O OR GDS for GPU direct I/O
    //         agent.has_backend("POSIX") || agent.has_backend("GDS_MT")
    //     } // StorageKind::Object(_) => {
    //       //     // Object storage is always registered via NIXL's OBJ plugin
    //       //     agent.has_backend("OBJ")
    //       // }
    // };

    // this is not true for our future object storage. so let's rethink this.
    // for object, if there is no device_id or device_id is 0, then we need to register
    // alternatively, the object storage holds it's own internal metadata but does not
    // expose as a nixl descriptor, thus ObjectStorag will by default like all other storage
    // types have a None for nixl_descriptor(), and we will use the internal
    if storage.nixl_descriptor().is_some() {
        return Ok(NixlRegistered {
            storage,
            handle: None,
            agent_name: agent.name().to_string(),
        });
    }

239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    // Get NIXL parameters
    let (ptr, size, mem_type, device_id) = storage.nixl_params();

    // Create a NIXL descriptor for registration
    let descriptor = NixlDescriptor {
        addr: ptr as u64,
        size,
        mem_type,
        device_id,
    };

    match agent.register_memory(&descriptor, opt) {
        Ok(handle) => Ok(NixlRegistered {
            storage,
            handle: Some(handle),
            agent_name: agent.name().to_string(),
        }),
        Err(_) => Err(storage),
    }
}
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324

// =============================================================================
// Arc<dyn NixlMemory> support
// =============================================================================

// Lets a type-erased, shared `Arc<dyn NixlMemory + Send + Sync>` be used
// wherever a `NixlCompatible` value is expected.
impl NixlCompatible for Arc<dyn NixlMemory + Send + Sync> {
    /// Forwards to the inner trait object.
    fn nixl_params(&self) -> (*const u8, usize, MemType, u64) {
        // `**self` strips both the `&` and the `Arc`, so the call dispatches
        // to the inner `dyn NixlMemory` rather than back to this impl.
        (**self).nixl_params()
    }
}

// Lets a type-erased, shared `Arc<dyn NixlMemory + Send + Sync>` be used
// wherever a `MemoryDescriptor` value is expected. Every method uses `**self`
// to reach the inner trait object, so calls dispatch to the wrapped value
// rather than back to this impl.
impl MemoryDescriptor for Arc<dyn NixlMemory + Send + Sync> {
    /// Forwards to the inner trait object.
    fn addr(&self) -> usize {
        (**self).addr()
    }

    /// Forwards to the inner trait object.
    fn size(&self) -> usize {
        (**self).size()
    }

    /// Forwards to the inner trait object.
    fn storage_kind(&self) -> StorageKind {
        (**self).storage_kind()
    }

    /// Forwards to the inner trait object, so the returned `&dyn Any` is
    /// whatever the wrapped value returns, not the `Arc` itself.
    fn as_any(&self) -> &dyn Any {
        (**self).as_any()
    }

    /// Forwards to the inner trait object.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        (**self).nixl_descriptor()
    }
}

// =============================================================================
// Extension trait for ergonomic API
// =============================================================================

/// Extension trait providing ergonomic `.register()` method for NIXL registration.
///
/// This trait is automatically implemented for all types that implement both
/// `MemoryDescriptor` and `NixlCompatible`. Import this trait to use the
/// method syntax:
///
/// ```ignore
/// // `storage` implements `MemoryDescriptor + NixlCompatible`;
/// // `agent` is a `NixlAgent`.
/// let registered = storage.register(&agent, None);
/// ```
pub trait NixlRegisterExt: MemoryDescriptor + NixlCompatible + Sized {
    /// Get this memory as NIXL-registered.
    ///
    /// This operation is idempotent - it's a no-op if the memory is already registered.
    /// It delegates to `register_with_nixl`, which wraps storage that already
    /// reports a NIXL descriptor without creating a new registration handle.
    ///
    /// # Arguments
    /// * `agent` - The NIXL agent to register with
    /// * `opt` - Optional arguments for registration
    ///
    /// # Returns
    /// A `NixlRegistered` wrapper on success, or the original storage on failure.
    fn register(
        self,
        agent: &NixlAgent,
        opt: Option<&OptArgs>,
    ) -> std::result::Result<NixlRegistered<Self>, Self> {
        // NOTE(review): `register_with_nixl` takes `&Agent`; passing a
        // `&NixlAgent` presumably relies on deref coercion - confirm in `agent.rs`.
        register_with_nixl(self, agent, opt)
    }
}

// Blanket impl for all compatible types
impl<T: MemoryDescriptor + NixlCompatible + Sized> NixlRegisterExt for T {}