// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Clean, minimal storage API for the v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - A single trait for type erasure (`MemoryDescriptor`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via the `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (the registration handle drops before the memory)

12
13
#![deny(missing_docs)]

14
15
16
pub mod actions;
pub mod arena;
pub mod nixl;
17
#[cfg(target_os = "linux")]
18
19
20
pub mod numa;

/// Offset-based buffer views into underlying storage.
21
pub mod offset;
22
23

/// CUDA memory pool utilities.
24
pub mod pool;
25
26

/// Common imports for working with memory types.
27
28
29
pub mod prelude;

mod device;
30
#[cfg(target_os = "linux")]
31
mod disk;
32
mod external;
33
34
mod pinned;
mod system;
35
mod tensor;
36
37
38
39
40
41

#[cfg(test)]
mod tests;

pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub use device::DeviceStorage;
42
#[cfg(target_os = "linux")]
43
pub use disk::DiskStorage;
44
pub use external::ExternalDeviceMemory;
45
#[cfg(target_os = "linux")]
46
pub use numa::{NumaNode, is_numa_disabled};
47
pub use offset::OffsetBuffer;
48
pub use pinned::PinnedStorage;
49
pub use pool::{CudaMemPool, CudaMemPoolBuilder};
50
pub use system::SystemStorage;
51
pub use tensor::{TensorDescriptor, TensorDescriptorExt};
52
53
54
55
56
57
58
59
60
61

use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;

/// Result type for storage operations.
///
/// The error type defaults to [`StorageError`], so `Result<T>` is the usual
/// spelling. A different error type can still be named explicitly as
/// `Result<T, E>` without reaching for `std::result::Result`.
pub type Result<T, E = StorageError> = std::result::Result<T, E>;

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/// Type-erasable description of a contiguous memory region.
///
/// Concrete storage types implement this trait so they can be handled
/// uniformly behind `Arc<dyn MemoryDescriptor>` (see [`Buffer`]).
/// Implementors must also be `Send + Sync + Debug`.
pub trait MemoryDescriptor: Send + Sync + fmt::Debug {
    /// Starting address of the region.
    fn addr(&self) -> usize;

    /// Length of the region, in bytes.
    fn size(&self) -> usize;

    /// Classification of the storage that backs the region.
    fn storage_kind(&self) -> StorageKind;

    /// Returns an [`Any`] reference that callers can downcast to a concrete type.
    fn as_any(&self) -> &dyn Any;

    /// NIXL descriptor for the region, if one is available.
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}

83
84
/// Errors that can occur during storage operations.
#[derive(Debug, Error)]
85
#[allow(missing_docs)]
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
pub enum StorageError {
    #[error("allocation failed: {0}")]
    AllocationFailed(String),

    #[error("registration failed: {0}")]
    RegistrationFailed(String),

    #[error("operation failed: {0}")]
    OperationFailed(String),

    #[error("unsupported operation: {0}")]
    Unsupported(String),

    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    // #[cfg(feature = "cuda")]
    #[error("CUDA error: {0}")]
    Cuda(#[from] cudarc::driver::DriverError),

    #[error("NIXL error: {0}")]
    Nixl(#[from] nixl_sys::NixlError),
}

/// Classification of where a memory region's storage lives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageKind {
    /// Ordinary host memory from the system allocator (malloc).
    System,

    /// Pinned (page-locked) host memory allocated through CUDA.
    Pinned,

    /// CUDA device memory; the payload is the device index.
    Device(u32),

    /// Disk-backed memory (mmap).
    ///
    /// NOTE(review): the meaning of the `u64` payload is not established in
    /// this file — confirm before relying on it.
    Disk(u64),
}

128
129
130
131
132
133
134
135
impl StorageKind {
    /// Returns the CUDA device index if this is device memory.
    pub fn cuda_device_index(&self) -> Option<u32> {
        match self {
            StorageKind::Device(idx) => Some(*idx),
            _ => None,
        }
    }
136

137
138
139
140
    /// Returns true if this is CUDA device memory.
    pub fn is_cuda(&self) -> bool {
        matches!(self, StorageKind::Device(_))
    }
141

142
143
144
145
    /// Returns true if this is system memory (malloc).
    pub fn is_system(&self) -> bool {
        matches!(self, StorageKind::System)
    }
146

147
148
149
150
    /// Returns true if this is CUDA pinned host memory.
    pub fn is_pinned(&self) -> bool {
        matches!(self, StorageKind::Pinned)
    }
151

152
153
154
155
    /// Returns true if this is disk-backed memory.
    pub fn is_disk(&self) -> bool {
        matches!(self, StorageKind::Disk(_))
    }
156
157
158
159
}

/// Type-erased memory region for use in layouts.
#[derive(Clone)]
160
161
162
163
164
165
166
167
168
169
170
pub struct Buffer(Arc<dyn MemoryDescriptor>);

impl Buffer {
    /// Moves `memory` behind a shared, type-erased pointer and returns the [`Buffer`].
    ///
    /// This is the main constructor: any [`MemoryDescriptor`] implementor can be
    /// turned into a `Buffer` this way.
    pub fn new<S>(memory: S) -> Self
    where
        S: MemoryDescriptor + 'static,
    {
        let shared: Arc<dyn MemoryDescriptor> = Arc::new(memory);
        Self(shared)
    }
}
171

172
impl MemoryDescriptor for Buffer {
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    fn addr(&self) -> usize {
        self.0.addr()
    }
    fn size(&self) -> usize {
        self.0.size()
    }
    fn storage_kind(&self) -> StorageKind {
        self.0.storage_kind()
    }
    fn as_any(&self) -> &dyn Any {
        self.0.as_any()
    }
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
        self.0.nixl_descriptor()
    }
}

impl std::ops::Deref for Buffer {
191
    type Target = dyn MemoryDescriptor;
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl std::fmt::Debug for Buffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Buffer")
            .field("addr", &self.addr())
            .field("size", &self.size())
            .field("kind", &self.storage_kind())
            .finish()
    }
}

/// Helper function to convert concrete storage to type-erased form.
209
pub fn create_buffer<S: MemoryDescriptor + 'static>(memory: S) -> Buffer {
210
211
212
    Buffer(Arc::new(memory))
}

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
impl Buffer {
    /// Wraps an already type-erased, shared region without allocating a new `Arc`.
    ///
    /// `arc` is stored as-is, so clones of the returned [`Buffer`] share its
    /// reference count.
    pub fn from_arc(arc: Arc<dyn MemoryDescriptor>) -> Self {
        Self(arc)
    }
}

// `From` conversions so callers can write `arc.into()` to obtain a `Buffer`.
impl From<Arc<dyn MemoryDescriptor>> for Buffer {
    /// Stores the given `Arc` directly; no new allocation is made.
    fn from(arc: Arc<dyn MemoryDescriptor>) -> Self {
        Self(arc)
    }
}

impl From<Arc<dyn nixl::NixlMemory + Send + Sync>> for Buffer {
    /// Wraps a shared NIXL memory handle in a [`Buffer`].
    ///
    /// Relies on `Arc<dyn NixlMemory + Send + Sync>` itself implementing
    /// [`MemoryDescriptor`]. The incoming `Arc` is moved into a fresh `Arc` by
    /// [`Buffer::new`], so accesses go through two levels of indirection.
    fn from(arc: Arc<dyn nixl::NixlMemory + Send + Sync>) -> Self {
        Self::new(arc)
    }
}

234
235
236
237
238
239
240
241
242
243
244
/// A plain address/length pair describing a contiguous chunk of memory.
///
/// It does not own or free the memory it describes, and it carries no
/// information about which kind of storage backs it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryRegion {
    /// Address of the first byte of the region.
    pub addr: usize,

    /// Length of the region, in bytes.
    pub size: usize,
}

impl MemoryRegion {
245
    /// Creates a new memory region with the given base address and size.
246
247
248
249
    pub fn new(addr: usize, size: usize) -> Self {
        Self { addr, size }
    }

250
    /// Returns the base address of this memory region.
251
252
253
254
255
    #[inline]
    pub fn addr(&self) -> usize {
        self.addr
    }

256
    /// Returns the size of this memory region in bytes.
257
258
259
260
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302

    /// Get a slice view of this memory region.
    ///
    /// # Safety
    /// This is unsafe because:
    /// - The caller must ensure the memory region is valid and properly initialized
    /// - The caller must ensure no mutable references exist to this memory
    /// - The caller must ensure the memory remains valid for the lifetime of the slice
    #[cfg(feature = "unsafe-slices")]
    pub unsafe fn as_slice(&self) -> Result<&[u8]> {
        if self.size == 0 {
            return Ok(&[]);
        }
        // SAFETY: Caller guarantees memory is valid
        unsafe {
            Ok(std::slice::from_raw_parts(
                self.addr as *const u8,
                self.size,
            ))
        }
    }

    /// Get a mutable slice view of this memory region.
    ///
    /// # Safety
    /// This is unsafe because:
    /// - The caller must ensure the memory region is valid and properly initialized
    /// - The caller must ensure no other references (mutable or immutable) exist to this memory
    /// - The caller must ensure the memory remains valid for the lifetime of the slice
    #[cfg(feature = "unsafe-slices")]
    pub unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8]> {
        if self.size == 0 {
            return Ok(&mut []);
        }
        // SAFETY: Caller guarantees memory is valid and exclusively accessible
        unsafe {
            Ok(std::slice::from_raw_parts_mut(
                self.addr as *mut u8,
                self.size,
            ))
        }
    }
303
}
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343

/// Returns `true` when the environment variable `env` holds a truthy value.
///
/// Truthy values are `1`, `true`, `on` and `yes`, compared case-insensitively
/// and without trimming whitespace. The function returns `false` when:
/// - the variable is unset,
/// - its value is not valid Unicode,
/// - it holds any other value.
pub fn env_is_truthy(env: &str) -> bool {
    std::env::var(env).is_ok_and(|raw| {
        let lowered = raw.to_lowercase();
        matches!(lowered.as_str(), "1" | "true" | "on" | "yes")
    })
}

/// Parse a string as a boolean value, returning an error if invalid.
/// Do not use this unless you really need to support vague values. Prefer only allowing a specific
/// value.
///
/// This function strictly validates that the input is a valid boolean representation.
/// The input is lowercased once and matched against both value sets; surrounding
/// whitespace is not trimmed.
///
/// # Arguments
/// * `val` - The string value to parse
///
/// # Returns
/// * `Ok(true)` - For truthy values: "1", "true", "on", "yes" (case-insensitive)
/// * `Ok(false)` - For falsey values: "0", "false", "off", "no" (case-insensitive)
///
/// # Errors
/// Returns an error for any other value.
///
/// # Example
/// ```ignore
/// assert_eq!(parse_bool("true")?, true);
/// assert_eq!(parse_bool("0")?, false);
/// assert!(parse_bool("maybe").is_err());
/// ```
pub fn parse_bool(val: &str) -> anyhow::Result<bool> {
    // Lowercase once; the previous version allocated a lowercase copy per branch.
    match val.to_lowercase().as_str() {
        "1" | "true" | "on" | "yes" => Ok(true),
        "0" | "false" | "off" | "no" => Ok(false),
        _ => anyhow::bail!(
            "Invalid boolean value: '{val}'. Expected one of: true/false, 1/0, on/off, yes/no",
        ),
    }
}