// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Transfer context. use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use anyhow::Result; use cudarc::driver::{CudaContext, CudaEvent, CudaStream}; use derive_builder::Builder; use tokio::sync::mpsc; use uuid::Uuid; use dynamo_memory::CudaMemPool; use dynamo_memory::nixl::{NixlAgent, NixlBackendConfig, XferRequest}; use velo::EventManager; use crate::manager::TransferManager; // Notifications module is declared in ../mod.rs // Re-export for convenience use super::TransferCapabilities; use notifications::RegisterPollingNotification; pub(crate) use super::notifications; pub use super::notifications::TransferCompleteNotification; #[derive(Clone, Builder)] #[builder(pattern = "owned", build_fn(private, name = "build_internal"), public)] #[allow(dead_code)] // Fields are used in build() but derive macros confuse dead code analysis pub struct TransferConfig { #[builder(default = "Arc::new(EventManager::local())")] event_system: Arc, /// Optional custom name for the NIXL agent. If not provided, defaults to "worker-{worker_id}" #[builder(default = "None", setter(strip_option))] nixl_agent_name: Option, /// Backend configuration for NIXL backends to enable #[builder(default = "NixlBackendConfig::default()")] nixl_backend_config: NixlBackendConfig, #[builder(default = "0")] cuda_device_id: usize, #[builder(default = "get_tokio_runtime()")] tokio_runtime: TokioRuntime, #[builder(default = "TransferCapabilities::default()")] capabilities: TransferCapabilities, /// Size in bytes to pre-allocate for the CUDA memory pool (default: 64 MiB) #[builder(default = "64 * 1024 * 1024")] cuda_pool_reserve_size: usize, /// Release threshold for the CUDA memory pool (default: Some(64 MiB)) /// Memory above this threshold is returned to the system when freed. /// If None, no release threshold is set. #[builder(default = "Some(64 * 1024 * 1024)")] cuda_pool_release_threshold: Option, } impl TransferConfigBuilder { /// Initialize builder with event system and tokio handle. /// /// This sets the event_system and tokio runtime handle, ensuring consistency /// with Nova's event system. Use this when the runtime has already been /// constructed and you want components to share the same event notification /// infrastructure. pub fn from_event_system_and_handle( self, event_system: Arc, handle: tokio::runtime::Handle, ) -> Self { self.event_system(event_system) .tokio_runtime(TokioRuntime::Handle(handle)) } /// Directly provide a pre-configured wrapped NIXL agent (mainly for testing). /// /// This bypasses the agent creation and backend initialization logic, /// using the provided agent directly. Useful for tests that need full /// control over agent configuration. pub fn nixl_agent(self, agent: NixlAgent) -> TransferConfigBuilderWithAgent { TransferConfigBuilderWithAgent { builder: self, agent, } } /// Add a NIXL backend to enable (uses default plugin parameters). pub fn nixl_backend(mut self, backend: impl Into) -> Self { let config = self .nixl_backend_config .get_or_insert_with(NixlBackendConfig::default); *config = config.clone().with_backend(backend); self } /// Load NIXL backend configuration from environment variables. /// /// This merges environment-based configuration with any backends already /// configured via the builder. pub fn with_env_backends(mut self) -> Result { let env_config = NixlBackendConfig::from_env()?; let config = self .nixl_backend_config .get_or_insert_with(NixlBackendConfig::default); *config = config.clone().merge(env_config); Ok(self) } pub fn build(self) -> Result { let mut config = self.build_internal()?; let worker_id = config.event_system.system_id(); // Merge environment backends if not explicitly configured if config.nixl_backend_config.backends().is_empty() { config.nixl_backend_config = NixlBackendConfig::from_env()?; } // Derive agent name from worker_id if not provided let agent_name = config .nixl_agent_name .unwrap_or_else(|| format!("worker-{}", worker_id)); let nixl_agent = NixlAgent::from_nixl_backend_config(&agent_name, config.nixl_backend_config)?; let cuda_context = CudaContext::new(config.cuda_device_id)?; let context = TransferContext::new( nixl_agent, config.event_system, cuda_context, config.tokio_runtime, config.capabilities, config.cuda_pool_reserve_size, config.cuda_pool_release_threshold, )?; Ok(TransferManager::from_context(context)) } } /// Builder that already has a pre-configured NIXL agent. /// /// This is generally used for testing when you want to pass in an agent directly /// rather than having it created by the builder. pub struct TransferConfigBuilderWithAgent { builder: TransferConfigBuilder, agent: NixlAgent, } impl TransferConfigBuilderWithAgent { /// Build the TransferManager using the pre-configured agent. pub fn build(self) -> Result { let config = self.builder.build_internal()?; let cuda_context = CudaContext::new(config.cuda_device_id)?; let context = TransferContext::new( self.agent, config.event_system, cuda_context, config.tokio_runtime, config.capabilities, config.cuda_pool_reserve_size, config.cuda_pool_release_threshold, )?; Ok(TransferManager::from_context(context)) } pub fn cuda_device_id(mut self, cuda_device_id: usize) -> Self { self.builder = self.builder.cuda_device_id(cuda_device_id); self } } fn get_tokio_runtime() -> TokioRuntime { match tokio::runtime::Handle::try_current() { Ok(handle) => TokioRuntime::Handle(handle), Err(_) => { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .max_blocking_threads(4) .worker_threads(2) .build() .expect("failed to build tokio runtime"); TokioRuntime::Shared(Arc::new(rt)) } } } #[derive(Debug, Clone)] #[doc(hidden)] pub enum TokioRuntime { Handle(tokio::runtime::Handle), Shared(Arc), } impl TokioRuntime { pub fn handle(&self) -> &tokio::runtime::Handle { match self { TokioRuntime::Handle(handle) => handle, TokioRuntime::Shared(runtime) => runtime.handle(), } } } #[derive(Clone)] #[doc(hidden)] pub struct TransferContext { worker_id: u64, nixl_agent: NixlAgent, #[allow(dead_code)] cuda_context: Arc, d2h_stream: Arc, h2d_stream: Arc, d2h_streams: Vec>, h2d_streams: Vec>, current_d2h_stream: Arc, current_h2d_stream: Arc, #[allow(dead_code)] tokio_runtime: TokioRuntime, capabilities: TransferCapabilities, event_system: Arc, // CUDA memory pool for kernel allocations cuda_pool: Arc, // Channels for background notification handlers tx_nixl_status: mpsc::Sender>, tx_cuda_event: mpsc::Sender>, #[allow(dead_code)] tx_nixl_events: mpsc::Sender, } impl TransferContext { pub fn builder() -> TransferConfigBuilder { TransferConfigBuilder::default() } pub(crate) fn new( nixl_agent: NixlAgent, event_system: Arc, cuda_context: Arc, tokio_runtime: TokioRuntime, capabilities: TransferCapabilities, cuda_pool_reserve_size: usize, cuda_pool_release_threshold: Option, ) -> Result { unsafe { cuda_context.disable_event_tracking() }; // Create CUDA memory pool for kernel allocations let mut pool_builder = CudaMemPool::builder(cuda_context.clone(), cuda_pool_reserve_size); if let Some(threshold) = cuda_pool_release_threshold { pool_builder = pool_builder.release_threshold(threshold); } let cuda_pool = Arc::new(pool_builder.build()?); // Create channels for background notification handlers let (tx_nixl_status, rx_nixl_status) = mpsc::channel(64); let (tx_cuda_event, rx_cuda_event) = mpsc::channel(64); let (tx_nixl_events, rx_nixl_events) = mpsc::channel(64); // Spawn background handlers let handle = tokio_runtime.handle(); // Spawn NIXL status polling handler handle.spawn(notifications::process_polling_notifications( rx_nixl_status, event_system.clone(), )); // Spawn CUDA event polling handler handle.spawn(notifications::process_polling_notifications( rx_cuda_event, event_system.clone(), )); // Spawn NIXL notification events handler handle.spawn(notifications::process_nixl_notification_events( nixl_agent.raw_agent().clone(), rx_nixl_events, event_system.clone(), )); let d2h_streams: Vec> = (0..4) .map(|_| cuda_context.new_stream()) .collect::, _>>()?; let h2d_streams: Vec> = (0..4) .map(|_| cuda_context.new_stream()) .collect::, _>>()?; let d2h_stream = d2h_streams[0].clone(); let h2d_stream = h2d_streams[0].clone(); let current_d2h_stream = Arc::new(AtomicUsize::new(0)); let current_h2d_stream = Arc::new(AtomicUsize::new(0)); Ok(Self { worker_id: event_system.system_id(), nixl_agent, cuda_context: cuda_context.clone(), d2h_stream, h2d_stream, d2h_streams, h2d_streams, current_d2h_stream, current_h2d_stream, tokio_runtime, capabilities, event_system, cuda_pool, tx_nixl_status, tx_cuda_event, tx_nixl_events, }) } pub(crate) fn nixl_agent(&self) -> &NixlAgent { &self.nixl_agent } #[allow(dead_code)] pub(crate) fn cuda_context(&self) -> &Arc { &self.cuda_context } // Provides the same d2h stream per invocation #[allow(dead_code)] pub(crate) fn d2h_stream(&self) -> &Arc { &self.d2h_stream } // Provides the same h2d stream per invocation #[allow(dead_code)] pub(crate) fn h2d_stream(&self) -> &Arc { &self.h2d_stream } // Provides the next d2h stream in a round-robin fashion pub(crate) fn next_d2h_streams(&self) -> Arc { let current_d2h_stream = self.current_d2h_stream.fetch_add(1, Ordering::Relaxed); self.d2h_streams[current_d2h_stream % self.d2h_streams.len()].clone() } // Provides the next h2d stream in a round-robin fashion pub(crate) fn next_h2d_streams(&self) -> Arc { let current_h2d_stream = self.current_h2d_stream.fetch_add(1, Ordering::Relaxed); self.h2d_streams[current_h2d_stream % self.h2d_streams.len()].clone() } /// Acquire an H2D stream for use by caller. /// /// This returns a stream from the pool that the caller can use for multiple /// sequential operations. The caller is responsible for all synchronization /// (e.g., recording events after operations). /// /// Used for layer-wise transfers where all layers must execute on the same stream. pub fn acquire_h2d_stream(&self) -> Arc { self.next_h2d_streams() } /// Acquire a D2H stream for use by caller. /// /// This returns a stream from the pool that the caller can use for multiple /// sequential operations. The caller is responsible for all synchronization /// (e.g., recording events after operations). /// /// Used for layer-wise transfers where all layers must execute on the same stream. pub fn acquire_d2h_stream(&self) -> Arc { self.next_d2h_streams() } #[allow(dead_code)] #[doc(hidden)] pub fn tokio(&self) -> &tokio::runtime::Handle { self.tokio_runtime.handle() } pub(crate) fn capabilities(&self) -> &TransferCapabilities { &self.capabilities } #[doc(hidden)] pub fn event_system(&self) -> &Arc { &self.event_system } /// Get the CUDA memory pool for kernel allocations. pub(crate) fn cuda_pool(&self) -> &Arc { &self.cuda_pool } /// Register a NIXL transfer request for status polling completion. /// /// This method enqueues the transfer request to be polled for completion /// using `agent.get_xfer_status()`. Returns a notification object that /// can be awaited for completion. pub(crate) fn register_nixl_status( &self, xfer_req: XferRequest, ) -> TransferCompleteNotification { let event = self .event_system .new_event() .expect("Failed to allocate event"); let handle = event.into_handle(); let awaiter = self .event_system .awaiter(handle) .expect("Failed to get awaiter"); let notification = notifications::RegisterPollingNotification { uuid: Uuid::new_v4(), checker: notifications::NixlStatusChecker::new( self.nixl_agent.raw_agent().clone(), xfer_req, ), event_handle: handle, }; // Send to background handler — log error if channel is full or closed if let Err(e) = self.tx_nixl_status.try_send(notification) { tracing::error!( "Failed to enqueue NIXL status notification: channel full or closed: {}", e ); } TransferCompleteNotification::from_awaiter(awaiter) } /// Register a CUDA event for polling completion. /// /// This method enqueues the CUDA event to be polled for completion. /// Returns a notification object that can be awaited for completion. pub(crate) fn register_cuda_event(&self, event: CudaEvent) -> TransferCompleteNotification { let new_event = self .event_system .new_event() .expect("Failed to allocate event"); let handle = new_event.into_handle(); let awaiter = self .event_system .awaiter(handle) .expect("Failed to get awaiter"); let notification = notifications::RegisterPollingNotification { uuid: Uuid::new_v4(), checker: notifications::CudaEventChecker::new(event), event_handle: handle, }; // Send to background handler — log error if channel is full or closed if let Err(e) = self.tx_cuda_event.try_send(notification) { tracing::error!( "Failed to enqueue CUDA event notification: channel full or closed: {}", e ); } TransferCompleteNotification::from_awaiter(awaiter) } /// Register a NIXL transfer request for notification-based completion. /// /// This method enqueues the transfer request to be completed via NIXL /// notification events. Returns a notification object that can be awaited /// for completion. #[allow(dead_code)] pub(crate) fn register_nixl_event( &self, xfer_req: XferRequest, ) -> TransferCompleteNotification { let event = self .event_system .new_event() .expect("Failed to allocate event"); let handle = event.into_handle(); let awaiter = self .event_system .awaiter(handle) .expect("Failed to get awaiter"); let notification = notifications::RegisterNixlNotification { uuid: Uuid::new_v4(), xfer_req, event_handle: handle, }; // Send to background handler — log error if channel is full or closed if let Err(e) = self.tx_nixl_events.try_send(notification) { tracing::error!( "Failed to enqueue NIXL event notification: channel full or closed: {}", e ); } TransferCompleteNotification::from_awaiter(awaiter) } /// Get the worker ID for this context. pub(crate) fn worker_id(&self) -> u64 { self.worker_id } }