bootstrap.rs 9.1 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! NCCL bootstrap utilities for creating communicators from scratch.
//!
//! This module provides helpers for initializing NCCL communicators in standalone
//! Rust applications and tests, where no external launcher (like PyTorch) provides
//! pre-initialized communicators.
//!
//! # Two Construction Paths
//!
//! NCCL communicators can be created via two paths:
//!
//! 1. **Bootstrap (this module)**: For tests and standalone Rust applications.
//!    Rank 0 generates a unique ID, distributes it to other ranks, and all
//!    ranks collectively call `ncclCommInitRank`.
//!
//! 2. **Borrowed handles**: For production use with PyTorch, vLLM, or TensorRT-LLM.
//!    The external runtime creates the communicator, and Rust code borrows it
//!    via FFI. See [`NcclCollectives::from_borrowed`].
//!
//! # Example: Multi-process Bootstrap
//!
//! ```rust,ignore
//! use kvbm::v2::distributed::collectives::NcclBootstrap;
//!
//! // All ranks: obtain the same bootstrap data.
//! // Rank 0 must call `generate` exactly once and keep the result: each call
//! // produces a different unique ID, and every rank has to use the same one.
//! let bootstrap = if rank == 0 {
//!     let bootstrap = NcclBootstrap::generate(world_size)?;
//!     let bytes = bootstrap.serialize();
//!     // Send `bytes` to other ranks via your IPC mechanism
//!     send_to_other_ranks(&bytes);
//!     bootstrap
//! } else {
//!     let bytes = receive_from_rank_0();
//!     NcclBootstrap::deserialize(&bytes)?
//! };
//!
//! // All ranks: Initialize communicator (collective call)
//! let comm = bootstrap.init_communicator(rank, stream)?;
//! ```

use std::ffi::c_char;
use std::mem::MaybeUninit;

/// Element type of NCCL's `ncclUniqueId::internal` array.
///
/// Aliased to `c_char` because that type's signedness is target-dependent
/// (`i8` on x86_64 and `u8` on aarch64). Casting through this alias keeps the
/// byte conversions in this module compiling unchanged on both.
type NcclByte = c_char;

use anyhow::{Context, Result};
use cudarc::driver::sys::CUstream;
use cudarc::nccl::sys::{
    ncclComm_t, ncclCommInitRank, ncclGetUniqueId, ncclResult_t, ncclUniqueId,
};

/// Bootstrap for creating NCCL communicators from scratch.
///
/// Bundles the NCCL unique ID that every rank of a communicator must share
/// with the number of ranks (`world_size`) in the group. Used by tests and
/// standalone Rust applications where NCCL communicators need to be created
/// without an external launcher.
///
/// # Workflow
///
/// 1. Rank 0 calls [`NcclBootstrap::generate`] to create the unique ID
/// 2. Rank 0 serializes via [`NcclBootstrap::serialize`] and sends to other ranks
/// 3. Other ranks deserialize via [`NcclBootstrap::deserialize`]
/// 4. All ranks collectively call [`NcclBootstrap::init_communicator`]
///
/// # Thread Safety
///
/// The bootstrap object itself is not thread-safe, but multiple threads can
/// each have their own bootstrap object with the same unique ID to initialize
/// communicators on different devices. The type is `Clone`, so a per-thread
/// copy can be made from a single generated or deserialized instance.
#[derive(Clone)]
pub struct NcclBootstrap {
    /// Opaque NCCL unique ID; generated on rank 0 and shared with every rank.
    nccl_id: ncclUniqueId,
    /// Total number of ranks in the collective group.
    world_size: usize,
}

impl NcclBootstrap {
    /// Generate a new bootstrap on rank 0.
    ///
    /// This creates a unique NCCL ID that must be shared with all other ranks
    /// before they can initialize their communicators.
    ///
    /// # Arguments
    /// * `world_size` - Total number of ranks in the collective group
    ///
    /// # Returns
    /// A bootstrap object that can be serialized and distributed to other ranks.
    ///
    /// # Errors
    /// Returns an error if NCCL fails to generate a unique ID.
    pub fn generate(world_size: usize) -> Result<Self> {
        anyhow::ensure!(
            world_size > 0 && world_size <= i32::MAX as usize,
            "world_size must be in 1..={}, got {}",
            i32::MAX,
            world_size
        );
        let mut nccl_id = MaybeUninit::<ncclUniqueId>::uninit();

        // SAFETY: ncclGetUniqueId initializes the ncclUniqueId struct
        let result = unsafe { ncclGetUniqueId(nccl_id.as_mut_ptr()) };
        check_nccl_result(result).context("Failed to generate NCCL unique ID")?;

        // SAFETY: ncclGetUniqueId has initialized the struct
        let nccl_id = unsafe { nccl_id.assume_init() };

        Ok(Self {
            nccl_id,
            world_size,
        })
    }

    /// Get the world size for this bootstrap.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Serialize the bootstrap for transmission to other ranks.
    ///
    /// The serialized format is:
    /// - 8 bytes: world_size as little-endian u64
    /// - 128 bytes: NCCL unique ID internal data
    ///
    /// # Returns
    /// A byte vector that can be transmitted via any IPC mechanism.
    pub fn serialize(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(8 + 128);
        bytes.extend_from_slice(&(self.world_size as u64).to_le_bytes());
        // Convert NcclByte array to u8 for serialization
        for &byte in &self.nccl_id.internal {
            bytes.push(byte as u8);
        }
        bytes
    }

    /// Deserialize a bootstrap received from rank 0.
    ///
    /// # Arguments
    /// * `bytes` - Serialized bootstrap data from [`NcclBootstrap::serialize`]
    ///
    /// # Returns
    /// A bootstrap object that can be used to initialize a communicator.
    ///
    /// # Errors
    /// Returns an error if the byte array has incorrect length.
    pub fn deserialize(bytes: &[u8]) -> Result<Self> {
        if bytes.len() != 8 + 128 {
            anyhow::bail!(
                "Invalid bootstrap data length: expected {}, got {}",
                8 + 128,
                bytes.len()
            );
        }

        let world_size = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;

        let mut nccl_id = ncclUniqueId {
            internal: [0 as NcclByte; 128],
        };
        // Copy bytes into internal array
        for (i, &byte) in bytes[8..].iter().enumerate() {
            nccl_id.internal[i] = byte as NcclByte;
        }

        Ok(Self {
            nccl_id,
            world_size,
        })
    }

    /// Initialize an NCCL communicator for this rank.
    ///
    /// This is a **collective operation** - all ranks must call this method
    /// simultaneously with the same bootstrap data for initialization to succeed.
    ///
    /// # Arguments
    /// * `rank` - The rank of this worker (0 to world_size-1)
    /// * `stream` - The CUDA stream to associate with NCCL operations
    ///
    /// # Returns
    /// An NCCL communicator handle that can be used for collective operations.
    ///
    /// # Safety
    /// The returned communicator must be destroyed with `ncclCommDestroy` when
    /// no longer needed. The caller is responsible for lifetime management.
    ///
    /// # Errors
    /// Returns an error if:
    /// - `rank` is >= `world_size`
    /// - NCCL initialization fails (e.g., network issues, GPU errors)
    /// - Not all ranks call this method (will hang)
    pub fn init_communicator(&self, rank: usize, _stream: CUstream) -> Result<ncclComm_t> {
        if rank >= self.world_size {
            anyhow::bail!(
                "Rank {} is invalid for world_size {}",
                rank,
                self.world_size
            );
        }
        anyhow::ensure!(
            self.world_size <= i32::MAX as usize,
            "world_size {} exceeds i32::MAX",
            self.world_size
        );

        let mut comm = MaybeUninit::<ncclComm_t>::uninit();

        // SAFETY: ncclCommInitRank is a collective call that initializes the communicator.
        // All ranks must call this with the same nccl_id for it to complete.
        let result = unsafe {
            ncclCommInitRank(
                comm.as_mut_ptr(),
                self.world_size as i32,
                self.nccl_id,
                rank as i32,
            )
        };
        check_nccl_result(result).context("Failed to initialize NCCL communicator")?;

        // SAFETY: ncclCommInitRank has initialized the communicator
        let comm = unsafe { comm.assume_init() };

        tracing::debug!(
            rank,
            world_size = self.world_size,
            "NCCL communicator initialized"
        );

        Ok(comm)
    }
}

/// Translate an NCCL status code into an `anyhow::Result`.
///
/// Yields `Ok(())` when `result` is `ncclSuccess`. Any other status becomes an
/// error whose message embeds the status's `Debug` representation.
pub(crate) fn check_nccl_result(result: ncclResult_t) -> Result<()> {
    // Guard form: bail out on anything that is not success, then fall through.
    anyhow::ensure!(
        result == ncclResult_t::ncclSuccess,
        "NCCL error: {:?}",
        result
    );
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing and then deserializing must preserve both the world size
    /// and the raw unique-ID bytes.
    #[test]
    fn test_bootstrap_serialization_roundtrip() {
        // Only the wire format is exercised here; no NCCL function is called.
        // The ID is therefore a hand-built dummy rather than the output of
        // ncclGetUniqueId.
        const RANKS: usize = 4;

        let dummy_id = ncclUniqueId {
            internal: [42 as NcclByte; 128],
        };
        let source = NcclBootstrap {
            nccl_id: dummy_id,
            world_size: RANKS,
        };

        let wire = source.serialize();
        assert_eq!(wire.len(), 8 + 128);

        let restored = NcclBootstrap::deserialize(&wire).unwrap();
        assert_eq!(restored.world_size, RANKS);
        assert_eq!(restored.nccl_id.internal, source.nccl_id.internal);
    }

    /// A buffer of the wrong size must be rejected.
    #[test]
    fn test_deserialize_invalid_length() {
        let too_short = [0u8; 10]; // Wrong length
        assert!(NcclBootstrap::deserialize(&too_short).is_err());
    }
}