context.rs 17.1 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Transfer context.

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::Result;
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use derive_builder::Builder;
use tokio::sync::mpsc;
use uuid::Uuid;

use dynamo_memory::CudaMemPool;
use dynamo_memory::nixl::{NixlAgent, NixlBackendConfig, XferRequest};
Ryan Olson's avatar
Ryan Olson committed
17
use velo::EventManager;
Ryan Olson's avatar
Ryan Olson committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507

use crate::manager::TransferManager;

// Notifications module is declared in ../mod.rs
// Re-export for convenience
use super::TransferCapabilities;
use notifications::RegisterPollingNotification;

pub(crate) use super::notifications;
pub use super::notifications::TransferCompleteNotification;

#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"), public)]
#[allow(dead_code)] // Fields are used in build() but derive macros confuse dead code analysis
pub struct TransferConfig {
    #[builder(default = "Arc::new(EventManager::local())")]
    event_system: Arc<EventManager>,

    /// Optional custom name for the NIXL agent. If not provided, defaults to "worker-{worker_id}"
    #[builder(default = "None", setter(strip_option))]
    nixl_agent_name: Option<String>,

    /// Backend configuration for NIXL backends to enable
    #[builder(default = "NixlBackendConfig::default()")]
    nixl_backend_config: NixlBackendConfig,

    #[builder(default = "0")]
    cuda_device_id: usize,

    #[builder(default = "get_tokio_runtime()")]
    tokio_runtime: TokioRuntime,

    #[builder(default = "TransferCapabilities::default()")]
    capabilities: TransferCapabilities,

    /// Size in bytes to pre-allocate for the CUDA memory pool (default: 64 MiB)
    #[builder(default = "64 * 1024 * 1024")]
    cuda_pool_reserve_size: usize,

    /// Release threshold for the CUDA memory pool (default: Some(64 MiB))
    /// Memory above this threshold is returned to the system when freed.
    /// If None, no release threshold is set.
    #[builder(default = "Some(64 * 1024 * 1024)")]
    cuda_pool_release_threshold: Option<u64>,
}

impl TransferConfigBuilder {
    /// Initialize builder with event system and tokio handle.
    ///
    /// This sets the event_system and tokio runtime handle, ensuring consistency
    /// with Nova's event system. Use this when the runtime has already been
    /// constructed and you want components to share the same event notification
    /// infrastructure.
    pub fn from_event_system_and_handle(
        self,
        event_system: Arc<EventManager>,
        handle: tokio::runtime::Handle,
    ) -> Self {
        self.event_system(event_system)
            .tokio_runtime(TokioRuntime::Handle(handle))
    }

    /// Directly provide a pre-configured wrapped NIXL agent (mainly for testing).
    ///
    /// This bypasses the agent creation and backend initialization logic,
    /// using the provided agent directly. Useful for tests that need full
    /// control over agent configuration.
    pub fn nixl_agent(self, agent: NixlAgent) -> TransferConfigBuilderWithAgent {
        TransferConfigBuilderWithAgent {
            builder: self,
            agent,
        }
    }

    /// Add a NIXL backend to enable (uses default plugin parameters).
    pub fn nixl_backend(mut self, backend: impl Into<String>) -> Self {
        let config = self
            .nixl_backend_config
            .get_or_insert_with(NixlBackendConfig::default);
        *config = config.clone().with_backend(backend);
        self
    }

    /// Load NIXL backend configuration from environment variables.
    ///
    /// This merges environment-based configuration with any backends already
    /// configured via the builder.
    pub fn with_env_backends(mut self) -> Result<Self> {
        let env_config = NixlBackendConfig::from_env()?;
        let config = self
            .nixl_backend_config
            .get_or_insert_with(NixlBackendConfig::default);
        *config = config.clone().merge(env_config);
        Ok(self)
    }

    pub fn build(self) -> Result<TransferManager> {
        let mut config = self.build_internal()?;

        let worker_id = config.event_system.system_id();

        // Merge environment backends if not explicitly configured
        if config.nixl_backend_config.backends().is_empty() {
            config.nixl_backend_config = NixlBackendConfig::from_env()?;
        }

        // Derive agent name from worker_id if not provided
        let agent_name = config
            .nixl_agent_name
            .unwrap_or_else(|| format!("worker-{}", worker_id));

        let nixl_agent =
            NixlAgent::from_nixl_backend_config(&agent_name, config.nixl_backend_config)?;

        let cuda_context = CudaContext::new(config.cuda_device_id)?;
        let context = TransferContext::new(
            nixl_agent,
            config.event_system,
            cuda_context,
            config.tokio_runtime,
            config.capabilities,
            config.cuda_pool_reserve_size,
            config.cuda_pool_release_threshold,
        )?;
        Ok(TransferManager::from_context(context))
    }
}

/// Builder that already has a pre-configured NIXL agent.
///
/// This is generally used for testing when you want to pass in an agent directly
/// rather than having it created by the builder.
pub struct TransferConfigBuilderWithAgent {
    builder: TransferConfigBuilder,
    agent: NixlAgent,
}

impl TransferConfigBuilderWithAgent {
    /// Build the TransferManager using the pre-configured agent.
    pub fn build(self) -> Result<TransferManager> {
        let config = self.builder.build_internal()?;
        let cuda_context = CudaContext::new(config.cuda_device_id)?;
        let context = TransferContext::new(
            self.agent,
            config.event_system,
            cuda_context,
            config.tokio_runtime,
            config.capabilities,
            config.cuda_pool_reserve_size,
            config.cuda_pool_release_threshold,
        )?;
        Ok(TransferManager::from_context(context))
    }

    pub fn cuda_device_id(mut self, cuda_device_id: usize) -> Self {
        self.builder = self.builder.cuda_device_id(cuda_device_id);
        self
    }
}

fn get_tokio_runtime() -> TokioRuntime {
    match tokio::runtime::Handle::try_current() {
        Ok(handle) => TokioRuntime::Handle(handle),
        Err(_) => {
            let rt = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .max_blocking_threads(4)
                .worker_threads(2)
                .build()
                .expect("failed to build tokio runtime");

            TokioRuntime::Shared(Arc::new(rt))
        }
    }
}

#[derive(Debug, Clone)]
#[doc(hidden)]
pub enum TokioRuntime {
    Handle(tokio::runtime::Handle),
    Shared(Arc<tokio::runtime::Runtime>),
}

impl TokioRuntime {
    pub fn handle(&self) -> &tokio::runtime::Handle {
        match self {
            TokioRuntime::Handle(handle) => handle,
            TokioRuntime::Shared(runtime) => runtime.handle(),
        }
    }
}

#[derive(Clone)]
#[doc(hidden)]
pub struct TransferContext {
    worker_id: u64,
    nixl_agent: NixlAgent,
    #[allow(dead_code)]
    cuda_context: Arc<CudaContext>,
    d2h_stream: Arc<CudaStream>,
    h2d_stream: Arc<CudaStream>,
    d2h_streams: Vec<Arc<CudaStream>>,
    h2d_streams: Vec<Arc<CudaStream>>,
    current_d2h_stream: Arc<AtomicUsize>,
    current_h2d_stream: Arc<AtomicUsize>,
    #[allow(dead_code)]
    tokio_runtime: TokioRuntime,
    capabilities: TransferCapabilities,
    event_system: Arc<EventManager>,
    // CUDA memory pool for kernel allocations
    cuda_pool: Arc<CudaMemPool>,
    // Channels for background notification handlers
    tx_nixl_status: mpsc::Sender<RegisterPollingNotification<notifications::NixlStatusChecker>>,
    tx_cuda_event: mpsc::Sender<RegisterPollingNotification<notifications::CudaEventChecker>>,
    #[allow(dead_code)]
    tx_nixl_events: mpsc::Sender<notifications::RegisterNixlNotification>,
}

impl TransferContext {
    pub fn builder() -> TransferConfigBuilder {
        TransferConfigBuilder::default()
    }

    pub(crate) fn new(
        nixl_agent: NixlAgent,
        event_system: Arc<EventManager>,
        cuda_context: Arc<CudaContext>,
        tokio_runtime: TokioRuntime,
        capabilities: TransferCapabilities,
        cuda_pool_reserve_size: usize,
        cuda_pool_release_threshold: Option<u64>,
    ) -> Result<Self> {
        unsafe { cuda_context.disable_event_tracking() };

        // Create CUDA memory pool for kernel allocations
        let mut pool_builder = CudaMemPool::builder(cuda_context.clone(), cuda_pool_reserve_size);
        if let Some(threshold) = cuda_pool_release_threshold {
            pool_builder = pool_builder.release_threshold(threshold);
        }
        let cuda_pool = Arc::new(pool_builder.build()?);

        // Create channels for background notification handlers
        let (tx_nixl_status, rx_nixl_status) = mpsc::channel(64);
        let (tx_cuda_event, rx_cuda_event) = mpsc::channel(64);
        let (tx_nixl_events, rx_nixl_events) = mpsc::channel(64);

        // Spawn background handlers
        let handle = tokio_runtime.handle();

        // Spawn NIXL status polling handler
        handle.spawn(notifications::process_polling_notifications(
            rx_nixl_status,
            event_system.clone(),
        ));

        // Spawn CUDA event polling handler
        handle.spawn(notifications::process_polling_notifications(
            rx_cuda_event,
            event_system.clone(),
        ));

        // Spawn NIXL notification events handler
        handle.spawn(notifications::process_nixl_notification_events(
            nixl_agent.raw_agent().clone(),
            rx_nixl_events,
            event_system.clone(),
        ));

        let d2h_streams: Vec<Arc<CudaStream>> = (0..4)
            .map(|_| cuda_context.new_stream())
            .collect::<Result<Vec<_>, _>>()?;

        let h2d_streams: Vec<Arc<CudaStream>> = (0..4)
            .map(|_| cuda_context.new_stream())
            .collect::<Result<Vec<_>, _>>()?;

        let d2h_stream = d2h_streams[0].clone();
        let h2d_stream = h2d_streams[0].clone();

        let current_d2h_stream = Arc::new(AtomicUsize::new(0));
        let current_h2d_stream = Arc::new(AtomicUsize::new(0));

        Ok(Self {
            worker_id: event_system.system_id(),
            nixl_agent,
            cuda_context: cuda_context.clone(),
            d2h_stream,
            h2d_stream,
            d2h_streams,
            h2d_streams,
            current_d2h_stream,
            current_h2d_stream,
            tokio_runtime,
            capabilities,
            event_system,
            cuda_pool,
            tx_nixl_status,
            tx_cuda_event,
            tx_nixl_events,
        })
    }

    pub(crate) fn nixl_agent(&self) -> &NixlAgent {
        &self.nixl_agent
    }

    #[allow(dead_code)]
    pub(crate) fn cuda_context(&self) -> &Arc<CudaContext> {
        &self.cuda_context
    }

    // Provides the same d2h stream per invocation
    #[allow(dead_code)]
    pub(crate) fn d2h_stream(&self) -> &Arc<CudaStream> {
        &self.d2h_stream
    }

    // Provides the same h2d stream per invocation
    #[allow(dead_code)]
    pub(crate) fn h2d_stream(&self) -> &Arc<CudaStream> {
        &self.h2d_stream
    }

    // Provides the next d2h stream in a round-robin fashion
    pub(crate) fn next_d2h_streams(&self) -> Arc<CudaStream> {
        let current_d2h_stream = self.current_d2h_stream.fetch_add(1, Ordering::Relaxed);
        self.d2h_streams[current_d2h_stream % self.d2h_streams.len()].clone()
    }

    // Provides the next h2d stream in a round-robin fashion
    pub(crate) fn next_h2d_streams(&self) -> Arc<CudaStream> {
        let current_h2d_stream = self.current_h2d_stream.fetch_add(1, Ordering::Relaxed);
        self.h2d_streams[current_h2d_stream % self.h2d_streams.len()].clone()
    }

    /// Acquire an H2D stream for use by caller.
    ///
    /// This returns a stream from the pool that the caller can use for multiple
    /// sequential operations. The caller is responsible for all synchronization
    /// (e.g., recording events after operations).
    ///
    /// Used for layer-wise transfers where all layers must execute on the same stream.
    pub fn acquire_h2d_stream(&self) -> Arc<CudaStream> {
        self.next_h2d_streams()
    }

    /// Acquire a D2H stream for use by caller.
    ///
    /// This returns a stream from the pool that the caller can use for multiple
    /// sequential operations. The caller is responsible for all synchronization
    /// (e.g., recording events after operations).
    ///
    /// Used for layer-wise transfers where all layers must execute on the same stream.
    pub fn acquire_d2h_stream(&self) -> Arc<CudaStream> {
        self.next_d2h_streams()
    }

    #[allow(dead_code)]
    #[doc(hidden)]
    pub fn tokio(&self) -> &tokio::runtime::Handle {
        self.tokio_runtime.handle()
    }

    pub(crate) fn capabilities(&self) -> &TransferCapabilities {
        &self.capabilities
    }

    #[doc(hidden)]
    pub fn event_system(&self) -> &Arc<EventManager> {
        &self.event_system
    }

    /// Get the CUDA memory pool for kernel allocations.
    pub(crate) fn cuda_pool(&self) -> &Arc<CudaMemPool> {
        &self.cuda_pool
    }

    /// Register a NIXL transfer request for status polling completion.
    ///
    /// This method enqueues the transfer request to be polled for completion
    /// using `agent.get_xfer_status()`. Returns a notification object that
    /// can be awaited for completion.
    pub(crate) fn register_nixl_status(
        &self,
        xfer_req: XferRequest,
    ) -> TransferCompleteNotification {
        let event = self
            .event_system
            .new_event()
            .expect("Failed to allocate event");
        let handle = event.into_handle();
        let awaiter = self
            .event_system
            .awaiter(handle)
            .expect("Failed to get awaiter");

        let notification = notifications::RegisterPollingNotification {
            uuid: Uuid::new_v4(),
            checker: notifications::NixlStatusChecker::new(
                self.nixl_agent.raw_agent().clone(),
                xfer_req,
            ),
            event_handle: handle,
        };

        // Send to background handler — log error if channel is full or closed
        if let Err(e) = self.tx_nixl_status.try_send(notification) {
            tracing::error!(
                "Failed to enqueue NIXL status notification: channel full or closed: {}",
                e
            );
        }

        TransferCompleteNotification::from_awaiter(awaiter)
    }

    /// Register a CUDA event for polling completion.
    ///
    /// This method enqueues the CUDA event to be polled for completion.
    /// Returns a notification object that can be awaited for completion.
    pub(crate) fn register_cuda_event(&self, event: CudaEvent) -> TransferCompleteNotification {
        let new_event = self
            .event_system
            .new_event()
            .expect("Failed to allocate event");
        let handle = new_event.into_handle();
        let awaiter = self
            .event_system
            .awaiter(handle)
            .expect("Failed to get awaiter");

        let notification = notifications::RegisterPollingNotification {
            uuid: Uuid::new_v4(),
            checker: notifications::CudaEventChecker::new(event),
            event_handle: handle,
        };

        // Send to background handler — log error if channel is full or closed
        if let Err(e) = self.tx_cuda_event.try_send(notification) {
            tracing::error!(
                "Failed to enqueue CUDA event notification: channel full or closed: {}",
                e
            );
        }

        TransferCompleteNotification::from_awaiter(awaiter)
    }

    /// Register a NIXL transfer request for notification-based completion.
    ///
    /// This method enqueues the transfer request to be completed via NIXL
    /// notification events. Returns a notification object that can be awaited
    /// for completion.
    #[allow(dead_code)]
    pub(crate) fn register_nixl_event(
        &self,
        xfer_req: XferRequest,
    ) -> TransferCompleteNotification {
        let event = self
            .event_system
            .new_event()
            .expect("Failed to allocate event");
        let handle = event.into_handle();
        let awaiter = self
            .event_system
            .awaiter(handle)
            .expect("Failed to get awaiter");

        let notification = notifications::RegisterNixlNotification {
            uuid: Uuid::new_v4(),
            xfer_req,
            event_handle: handle,
        };

        // Send to background handler — log error if channel is full or closed
        if let Err(e) = self.tx_nixl_events.try_send(notification) {
            tracing::error!(
                "Failed to enqueue NIXL event notification: channel full or closed: {}",
                e
            );
        }

        TransferCompleteNotification::from_awaiter(awaiter)
    }

    /// Get the worker ID for this context.
    pub(crate) fn worker_id(&self) -> u64 {
        self.worker_id
    }
}