mod.rs 13.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! NUMA-aware memory allocation utilities.
//!
//! This module provides utilities for NUMA-aware memory allocation, which is critical
//! for optimal performance on multi-socket systems with GPUs. Memory allocated on the
//! NUMA node closest to the target GPU has significantly lower access latency.
//!
//! ## Architecture
//!
//! - [`NumaNode`]: Represents a NUMA node ID
//! - [`topology`]: Reads CPU-to-NUMA mapping from `/sys/devices/system/node`
//! - [`worker_pool`]: Dedicated worker threads pinned to specific NUMA nodes
//!
//! ## Usage
//!
//! NUMA optimization is enabled by default. To disable it:
//! ```bash
//! export DYN_MEMORY_DISABLE_NUMA=1
//! ```
//!
//! When enabled, pinned memory allocations are routed through NUMA workers
//! that are pinned to the target GPU's NUMA node, so that the first-touch policy
//! places pages on the correct node. If the GPU's NUMA node cannot be
//! determined, allocation falls back to the non-NUMA path transparently.

pub mod topology;
pub mod worker_pool;

31
use cudarc::driver::sys::CUdevice_attribute_enum;
32
33
use nix::libc;
use serde::{Deserialize, Serialize};
34
35
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};
36
use std::{fs, mem, process::Command};
37

38
39
40
41
42
/// Cache for GPU PCI address → NUMA node lookups.
///
/// Keys are PCI address strings; a stored `None` records a negative lookup.
/// The mapping never changes at runtime, so we cache results (including negative
/// lookups) to avoid repeated sysfs reads and nvidia-smi subprocesses.
/// The map is created lazily on first use.
static NUMA_NODE_CACHE: OnceLock<Mutex<HashMap<String, Option<NumaNode>>>> = OnceLock::new();

43
/// Check if NUMA optimization is disabled via environment variable.
44
///
45
46
47
48
/// NUMA-aware allocation is enabled by default. Set `DYN_MEMORY_DISABLE_NUMA=1`
/// (or any truthy value) to disable it.
pub fn is_numa_disabled() -> bool {
    dynamo_config::env_is_truthy("DYN_MEMORY_DISABLE_NUMA")
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
}

/// Represents a NUMA node identifier.
///
/// NUMA nodes are typically numbered 0, 1, 2, etc. corresponding to physical
/// CPU sockets. Use [`NumaNode::UNKNOWN`] when the node cannot be determined.
///
/// The wrapped `u32` is the raw node ID; `u32::MAX` is reserved as the
/// [`NumaNode::UNKNOWN`] sentinel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct NumaNode(pub u32);

impl NumaNode {
    /// Sentinel value for unknown NUMA node.
    pub const UNKNOWN: NumaNode = NumaNode(u32::MAX);

    /// Returns true if this represents an unknown NUMA node.
    pub fn is_unknown(&self) -> bool {
        self.0 == u32::MAX
    }
}

impl std::fmt::Display for NumaNode {
    /// Renders a known node as `NumaNode(<id>)` and the sentinel as `UNKNOWN`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if !self.is_unknown() {
            return write!(f, "NumaNode({})", self.0);
        }
        f.write_str("UNKNOWN")
    }
}

/// Look up the NUMA node of the CPU the calling thread is currently running on.
///
/// Issues the Linux `getcpu` syscall and wraps the reported node in a
/// [`NumaNode`]. Returns [`NumaNode::UNKNOWN`] if the syscall reports failure
/// (any non-zero return value).
pub fn get_current_cpu_numa_node() -> NumaNode {
    let mut cpu_index: libc::c_uint = 0;
    let mut node_index: libc::c_uint = 0;

    // int getcpu(unsigned *cpu, unsigned *node, struct getcpu_cache *tcache);
    //
    // SAFETY: both out-pointers refer to live, properly aligned `c_uint`
    // locals, and a null pointer is passed for the cache argument.
    let status = unsafe {
        libc::syscall(
            libc::SYS_getcpu,
            &mut cpu_index,
            &mut node_index,
            std::ptr::null_mut::<libc::c_void>(),
        )
    };

    match status {
        0 => NumaNode(node_index),
        _ => NumaNode::UNKNOWN,
    }
}

102
/// Format a PCI bus address from domain, bus, and device IDs.
103
///
104
105
106
107
108
109
/// Returns a string in the format `"DDDD:BB:DD.0"` (lower-case hex, zero-padded,
/// function number fixed at `0`) suitable for sysfs lookups.
fn format_pci_bus_address(domain: i32, bus: i32, device: i32) -> String {
    format!("{domain:04x}:{bus:02x}:{device:02x}.0")
}

/// Query the PCI bus address for a CUDA device from the CUDA driver API.
110
///
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/// Uses `CudaContext::attribute()` to read PCI domain, bus, and device IDs.
/// This transparently handles `CUDA_VISIBLE_DEVICES` remapping since
/// `CudaContext::new(ordinal)` operates on the process-local device index.
fn get_pci_bus_address_from_cuda(device_id: u32) -> Option<String> {
    let ctx = crate::device::cuda_context(device_id).ok()?;
    let domain = ctx
        .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID)
        .ok()?;
    let bus = ctx
        .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID)
        .ok()?;
    let device = ctx
        .attribute(CUdevice_attribute_enum::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID)
        .ok()?;
    Some(format_pci_bus_address(domain, bus, device))
}
127

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/// Look up a PCI device's NUMA node through sysfs.
///
/// Reads `/sys/bus/pci/devices/<pci_address>/numa_node` and parses its trimmed
/// contents as a signed integer. Returns `None` if the file is missing or
/// unreadable, if the contents do not parse, or if the value is negative
/// (`-1` means no NUMA affinity information is available).
fn read_numa_node_from_sysfs(pci_address: &str) -> Option<NumaNode> {
    let raw =
        fs::read_to_string(format!("/sys/bus/pci/devices/{}/numa_node", pci_address)).ok()?;
    let value = raw.trim().parse::<i32>().ok()?;

    // Only non-negative values name a real node, so the cast cannot wrap.
    (value >= 0).then(|| NumaNode(value as u32))
}

/// Fallback: query NUMA node from nvidia-smi using PCI bus address.
///
/// Uses the PCI BDF address (not env-var-based device index) so it is
/// correct regardless of `CUDA_VISIBLE_DEVICES` remapping.
fn get_numa_node_from_nvidia_smi(pci_address: &str) -> Option<NumaNode> {
    let output = Command::new("nvidia-smi")
        .args(["topo", "--get-numa-id-of-nearby-cpu", "-i", pci_address])
        .output()
        .ok()?;

    if !output.status.success() {
        return None;
156
157
    }

158
159
160
161
162
    let stdout = std::str::from_utf8(&output.stdout).ok()?;
    let line = stdout.lines().next()?;
    let numa_str = line.split(':').nth(1)?;
    let node: u32 = numa_str.trim().parse().ok()?;
    Some(NumaNode(node))
163
164
}

165
166
/// Get NUMA node for a GPU device.
///
167
168
169
170
/// Queries the PCI bus address from the CUDA driver API, then reads the NUMA
/// node from sysfs. Falls back to nvidia-smi with the PCI address. Returns
/// `None` if the NUMA node cannot be determined, signaling the caller to skip
/// NUMA-aware allocation entirely rather than guessing wrong.
171
///
172
173
/// `CUDA_VISIBLE_DEVICES` is handled transparently because `CudaContext::new(ordinal)`
/// operates on the process-local device index.
174
///
175
/// # Arguments
176
/// * `device_id` - CUDA device index (0, 1, 2, ...) as seen by the process
177
178
///
/// # Returns
179
180
181
182
183
184
/// The NUMA node closest to the specified GPU, or `None` if it cannot be determined.
pub fn get_device_numa_node(device_id: u32) -> Option<NumaNode> {
    // Step 1: Get PCI bus address from CUDA driver
    let pci_address = match get_pci_bus_address_from_cuda(device_id) {
        Some(addr) => addr,
        None => {
185
            tracing::warn!(
186
187
                "Failed to get PCI address from CUDA for device {}, skipping NUMA optimization",
                device_id
188
            );
189
            return None;
190
191
192
        }
    };

193
194
195
196
197
198
199
    // Step 2: Check cache (includes negative lookups)
    let cache = NUMA_NODE_CACHE.get_or_init(|| Mutex::new(HashMap::new()));
    {
        let guard = cache.lock().unwrap();
        if let Some(cached) = guard.get(&pci_address) {
            return *cached;
        }
200
201
    }

202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
    // Step 3: Read NUMA node from sysfs
    let result = read_numa_node_from_sysfs(&pci_address)
        .or_else(|| get_numa_node_from_nvidia_smi(&pci_address));

    match result {
        Some(node) => {
            tracing::trace!(
                "GPU {} (PCI {}) on NUMA node {}",
                device_id,
                pci_address,
                node.0
            );
        }
        None => {
            tracing::warn!(
                "Could not determine NUMA node for GPU {} (PCI {}), skipping NUMA optimization",
                device_id,
                pci_address
            );
        }
222
    }
223

224
225
226
    // Cache result (including None for negative lookups)
    cache.lock().unwrap().insert(pci_address, result);
    result
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
}

/// Pin the current thread to a specific NUMA node's CPUs.
///
/// This sets the CPU affinity for the calling thread to only run on CPUs
/// that the NUMA topology reports for the specified node. This is critical for
/// ensuring that memory allocations follow the first-touch policy on the
/// correct node.
///
/// # Arguments
/// * `node` - The NUMA node to pin the thread to
///
/// # Errors
/// Returns an error if:
/// - NUMA topology cannot be read
/// - No CPUs are found for the specified node
/// - A CPU index does not fit into a `cpu_set_t` (it is `>= CPU_SETSIZE`)
/// - The `sched_setaffinity` syscall fails
pub fn pin_thread_to_numa_node(node: NumaNode) -> Result<(), String> {
    let topology =
        topology::get_numa_topology().map_err(|e| format!("Can not get NUMA topology: {}", e))?;

    let cpus = topology
        .cpus_for_node(node.0)
        .ok_or_else(|| format!("No CPUs found for NUMA node {}", node.0))?;

    if cpus.is_empty() {
        return Err(format!("No CPUs found for NUMA node {}", node.0));
    }

    // `CPU_SET` writes into a fixed-size bitmask. An index at or beyond
    // `CPU_SETSIZE` must be rejected here so it surfaces as an `Err` instead of
    // an out-of-bounds failure inside the helper.
    let capacity = libc::CPU_SETSIZE as usize;

    // SAFETY: `cpu_set_t` is a plain bitmask, for which all-zero bytes are a
    // valid (empty) value.
    let mut cpu_set: libc::cpu_set_t = unsafe { mem::zeroed() };

    for cpu in cpus {
        if *cpu >= capacity {
            return Err(format!(
                "CPU {} of NUMA node {} does not fit in cpu_set_t (CPU_SETSIZE = {})",
                *cpu, node.0, capacity
            ));
        }
        // SAFETY: `*cpu` was just checked to be within the bitmask's capacity.
        unsafe { libc::CPU_SET(*cpu, &mut cpu_set) };
    }

    // SAFETY: the pointer and size describe the fully initialized local `cpu_set`.
    let result = unsafe {
        libc::sched_setaffinity(
            0, // current thread
            mem::size_of::<libc::cpu_set_t>(),
            &cpu_set,
        )
    };

    if result != 0 {
        let err = std::io::Error::last_os_error();
        return Err(format!("Failed to set CPU affinity: {}", err));
    }

    Ok(())
}

/// Unit tests that need neither a GPU nor a specific NUMA layout.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numa_node_equality() {
        let node0a = NumaNode(0);
        let node0b = NumaNode(0);
        let node1 = NumaNode(1);

        assert_eq!(node0a, node0b);
        assert_ne!(node0a, node1);
    }

    #[test]
    fn test_numa_node_unknown() {
        let unknown = NumaNode::UNKNOWN;
        assert!(unknown.is_unknown());
        assert_eq!(unknown.0, u32::MAX);

        let valid = NumaNode(0);
        assert!(!valid.is_unknown());
    }

    #[test]
    fn test_numa_node_display() {
        assert_eq!(format!("{}", NumaNode(0)), "NumaNode(0)");
        assert_eq!(format!("{}", NumaNode(7)), "NumaNode(7)");
        assert_eq!(format!("{}", NumaNode::UNKNOWN), "UNKNOWN");
    }

    #[test]
    fn test_numa_node_serialization() {
        let node = NumaNode(1);
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NumaNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, deserialized);
    }

    #[test]
    fn test_get_current_cpu_numa_node() {
        // Sanity bound only: it is meant to catch garbage values, not to encode
        // the node count of any particular machine. A tight limit makes the
        // test fail spuriously on hosts with many NUMA nodes.
        const MAX_PLAUSIBLE_NODE: u32 = 1024;

        let node = get_current_cpu_numa_node();
        if !node.is_unknown() {
            assert!(
                node.0 < MAX_PLAUSIBLE_NODE,
                "NUMA node {} seems unreasonably high",
                node.0
            );
        }
    }

    #[test]
    fn test_numa_node_hash() {
        use std::collections::HashMap;

        let mut map = HashMap::new();
        map.insert(NumaNode(0), "node0");
        map.insert(NumaNode(1), "node1");

        assert_eq!(map.get(&NumaNode(0)), Some(&"node0"));
        assert_eq!(map.get(&NumaNode(1)), Some(&"node1"));
        assert_eq!(map.get(&NumaNode(2)), None);
    }

    #[test]
    // The explicit `.clone()` on a `Copy` type is the point of this test.
    #[allow(clippy::clone_on_copy)]
    fn test_numa_node_copy_clone() {
        let node1 = NumaNode(5);
        let node2 = node1; // exercises `Copy`: `node1` stays usable below
        let node3 = node1.clone(); // exercises `Clone`

        assert_eq!(node1, node2);
        assert_eq!(node1, node3);
        assert_eq!(node2, node3);
    }

    #[test]
    fn test_format_pci_bus_address() {
        assert_eq!(format_pci_bus_address(0, 0, 0), "0000:00:00.0");
        assert_eq!(format_pci_bus_address(0, 0x3b, 0), "0000:3b:00.0");
        assert_eq!(format_pci_bus_address(0, 0xaf, 0), "0000:af:00.0");
        assert_eq!(format_pci_bus_address(0x10, 0x1a, 0x03), "0010:1a:03.0");
    }

    #[test]
    fn test_read_numa_node_from_sysfs_nonexistent() {
        assert!(read_numa_node_from_sysfs("ffff:ff:ff.0").is_none());
    }
}

/// Tests that need a real CUDA device; compiled only with the `testing-cuda` feature.
#[cfg(all(test, feature = "testing-cuda"))]
mod cuda_tests {
    use super::*;

    /// The address reported for GPU 0 must look like a BDF string (`DDDD:BB:DD.0`).
    #[test]
    fn test_get_pci_bus_address_from_cuda() {
        let addr = get_pci_bus_address_from_cuda(0).expect("should get PCI address for GPU 0");

        let fields: Vec<&str> = addr.split(':').collect();
        assert_eq!(
            fields.len(),
            3,
            "PCI address should have 3 colon-separated parts: {}",
            addr
        );

        // Safe to index: the length was asserted just above.
        let (domain, device_and_function) = (fields[0], fields[2]);
        assert_eq!(domain.len(), 4, "domain should be 4 hex chars: {}", addr);
        assert!(
            device_and_function.ends_with(".0"),
            "should end with .0: {}",
            addr
        );
        println!("GPU 0 PCI address: {}", addr);
    }

    /// sysfs may or may not expose a NUMA node for GPU 0; both outcomes are accepted.
    #[test]
    fn test_read_numa_node_from_sysfs_real_gpu() {
        let addr = get_pci_bus_address_from_cuda(0).expect("should get PCI address for GPU 0");

        match read_numa_node_from_sysfs(&addr) {
            Some(node) => {
                assert!(node.0 < 16, "NUMA node {} seems unreasonably high", node.0);
                println!("GPU 0 (PCI {}) sysfs NUMA node: {}", addr, node.0);
            }
            None => {
                println!(
                    "GPU 0 (PCI {}) has no sysfs NUMA info (single-socket?)",
                    addr
                );
            }
        }
    }

    /// The full lookup yields either a plausible node or `None`, never `UNKNOWN` in `Some`.
    #[test]
    fn test_get_device_numa_node_returns_some_or_none() {
        if let Some(node) = get_device_numa_node(0) {
            assert!(node.0 < 16, "NUMA node {} seems unreasonably high", node.0);
            assert!(
                !node.is_unknown(),
                "should never return UNKNOWN inside Some"
            );
            println!("GPU 0 detected on NUMA node: {}", node.0);
        } else {
            println!("GPU 0 has no determinable NUMA node (single-socket or no sysfs info)");
        }
    }
}