// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Clean, minimal storage API for v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryDescription`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)

pub mod actions;
pub mod arena;
pub mod nixl;
pub mod offset;
pub mod pool;
pub mod prelude;

mod device;
mod disk;
mod pinned;
mod system;
mod torch;

#[cfg(test)]
mod tests;

pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub use device::DeviceStorage;
pub use disk::DiskStorage;
pub use pinned::PinnedStorage;
pub use system::SystemStorage;
pub use torch::{TorchDevice, TorchTensor};

use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;

/// Result type for storage operations.
///
/// Shorthand for `std::result::Result<T, StorageError>`, so fallible storage
/// functions can simply return `Result<T>`.
pub type Result<T> = std::result::Result<T, StorageError>;

/// Errors that can occur during storage operations.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. Variants marked `#[from]` can be produced from the wrapped
/// error type via `From`, which lets `?` convert those errors automatically.
#[derive(Debug, Error)]
pub enum StorageError {
    /// Allocation failure; the string is appended to the display message.
    #[error("allocation failed: {0}")]
    AllocationFailed(String),

    /// Registration failure; the string is appended to the display message.
    #[error("registration failed: {0}")]
    RegistrationFailed(String),

    /// Generic operation failure; the string is appended to the display message.
    #[error("operation failed: {0}")]
    OperationFailed(String),

    /// The requested operation is not supported; the string says which one.
    #[error("unsupported operation: {0}")]
    Unsupported(String),

    /// Wrapper around a standard library I/O error.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapper around a `cudarc` driver error.
    // NOTE(review): the feature gate below is commented out, so this variant
    // (and the `cudarc` dependency) is compiled unconditionally — confirm
    // that this is intended.
    // #[cfg(feature = "cuda")]
    #[error("CUDA error: {0}")]
    Cuda(#[from] cudarc::driver::DriverError),

    /// Wrapper around an error reported by `nixl_sys`.
    #[error("NIXL error: {0}")]
    Nixl(#[from] nixl_sys::NixlError),
}

/// Storage type classification.
///
/// A small `Copy` tag that can be compared for equality and (de)serialized
/// with serde. It is the value returned by
/// `MemoryDescription::storage_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageKind {
    /// System memory (malloc)
    System,

    /// CUDA pinned host memory
    // NOTE(review): feature gate is commented out, so this variant always exists.
    // #[cfg(feature = "cuda")]
    Pinned,

    /// CUDA device memory with device ID
    // NOTE(review): feature gate is commented out, so this variant always exists.
    // #[cfg(feature = "cuda")]
    Device(u32),

    /// Disk-backed memory (mmap)
    // NOTE(review): the meaning of the `u64` payload is not documented in this
    // file (identifier? descriptor?) — confirm against `DiskStorage`.
    Disk(u64),
}

/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescription>`.
///
/// Implementors must be `Send + Sync` and implement `fmt::Debug`, as required
/// by the supertrait bounds.
pub trait MemoryDescription: Send + Sync + fmt::Debug {
    /// Base address of the memory region.
    fn addr(&self) -> usize;

    /// Size of the memory region in bytes.
    fn size(&self) -> usize;

    /// Type of storage backing this region.
    fn storage_kind(&self) -> StorageKind;

    /// Enable downcasting to concrete type.
    ///
    /// The returned `&dyn Any` can be passed to `Any::downcast_ref` to
    /// recover a reference to the concrete storage type.
    fn as_any(&self) -> &dyn Any;

    /// Get the NIXL descriptor for this memory region.
    ///
    /// Returns `None` when the implementor has no descriptor to offer.
    // NOTE(review): presumably `None` means "not registered with NIXL" —
    // confirm against the implementations in the `nixl` module.
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}

/// Type-erased memory region for use in layouts.
///
/// Wraps an `Arc<dyn MemoryDescription>`. Because `Clone` is derived,
/// cloning a `Buffer` clones the `Arc`, so all clones refer to the same
/// underlying description object.
#[derive(Clone)]
pub struct Buffer(Arc<dyn MemoryDescription>);

impl MemoryDescription for Buffer {
    fn addr(&self) -> usize {
        self.0.addr()
    }
    fn size(&self) -> usize {
        self.0.size()
    }
    fn storage_kind(&self) -> StorageKind {
        self.0.storage_kind()
    }
    fn as_any(&self) -> &dyn Any {
        self.0.as_any()
    }
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
        self.0.nixl_descriptor()
    }
}

/// Lets a `&Buffer` be used wherever a `&dyn MemoryDescription` is expected.
impl std::ops::Deref for Buffer {
    type Target = dyn MemoryDescription;

    /// Borrows the trait object stored inside the inner `Arc`.
    fn deref(&self) -> &Self::Target {
        // Deref coercion turns `&Arc<dyn MemoryDescription>` into
        // `&dyn MemoryDescription` at the return site.
        &self.0
    }
}

impl std::fmt::Debug for Buffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Buffer")
            .field("addr", &self.addr())
            .field("size", &self.size())
            .field("kind", &self.storage_kind())
            .finish()
    }
}

/// Wraps a concrete storage value in a type-erased [`Buffer`].
///
/// `memory` is moved into a freshly allocated `Arc`, which is then stored
/// as an `Arc<dyn MemoryDescription>` inside the returned `Buffer`.
pub fn create_buffer<S: MemoryDescription + 'static>(memory: S) -> Buffer {
    let erased: Arc<dyn MemoryDescription> = Arc::new(memory);
    Buffer(erased)
}

/// An unowned contiguous chunk of memory, not storage specific.
///
/// This is a plain `Copy` pair of address and length. It holds no reference
/// to the memory it describes, and it can be compared for equality and
/// (de)serialized with serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryRegion {
    /// Start address of the memory region.
    pub addr: usize,

    /// Size of the memory region in bytes.
    pub size: usize,
}

impl MemoryRegion {
    /// Creates a region starting at `addr` and spanning `size` bytes.
    ///
    /// No validation is performed: the values are stored as given. The
    /// function is `const`, so regions can be built in `const` and `static`
    /// items as well as at runtime.
    pub const fn new(addr: usize, size: usize) -> Self {
        Self { addr, size }
    }

    /// Start address of the memory region.
    #[inline]
    pub const fn addr(&self) -> usize {
        self.addr
    }

    /// Size of the memory region in bytes.
    #[inline]
    pub const fn size(&self) -> usize {
        self.size
    }
}