// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 //! Typestate builder for NIXL transfers. //! //! This module provides a compile-time safe builder for NIXL transfers that ensures //! all required parameters are set before execution. use super::{PhysicalLayout, TransferContext, TransferStrategy}; use crate::BlockId; use crate::transfer::context::TransferCompleteNotification; use crate::transfer::{can_use_whole_block_transfer, validate_layout_compatibility}; use anyhow::{Result, anyhow}; use dynamo_memory::nixl::{XferDescList, XferOp}; use std::marker::PhantomData; use std::ops::Range; /// Marker type for unset builder fields. pub struct Unset; /// Marker type for set builder fields. pub struct Set; /// Typestate builder for NIXL transfers. /// /// This builder uses the typestate pattern to ensure all required parameters are set /// at compile time. The type parameters track which fields have been set: /// - `TSrc`: Source layout state /// - `TDst`: Destination layout state /// - `TSrcBlocks`: Source block IDs state /// - `TDstBlocks`: Destination block IDs state /// - `TStrategy`: Transfer strategy state pub struct NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy> { src: Option<&'a PhysicalLayout>, dst: Option<&'a PhysicalLayout>, src_block_ids: Option<&'a [BlockId]>, dst_block_ids: Option<&'a [BlockId]>, strategy: Option, layer_range: Option>, write_notif: Option, _phantom: PhantomData<(TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy)>, } impl<'a> NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> { /// Creates a new NIXL transfer builder with all fields unset. pub fn new() -> Self { Self { src: None, dst: None, src_block_ids: None, dst_block_ids: None, strategy: None, layer_range: None, write_notif: None, _phantom: PhantomData, } } } impl<'a> Default for NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> { fn default() -> Self { Self::new() } } // Required field setters - these consume self and return a new builder with the field marked as Set impl<'a, TDst, TSrcBlocks, TDstBlocks, TStrategy> NixlTransferBuilder<'a, Unset, TDst, TSrcBlocks, TDstBlocks, TStrategy> { /// Sets the source physical layout. pub fn src( self, src: &'a PhysicalLayout, ) -> NixlTransferBuilder<'a, Set, TDst, TSrcBlocks, TDstBlocks, TStrategy> { NixlTransferBuilder { src: Some(src), dst: self.dst, src_block_ids: self.src_block_ids, dst_block_ids: self.dst_block_ids, strategy: self.strategy, layer_range: self.layer_range, write_notif: self.write_notif, _phantom: PhantomData, } } } impl<'a, TSrc, TSrcBlocks, TDstBlocks, TStrategy> NixlTransferBuilder<'a, TSrc, Unset, TSrcBlocks, TDstBlocks, TStrategy> { /// Sets the destination physical layout. pub fn dst( self, dst: &'a PhysicalLayout, ) -> NixlTransferBuilder<'a, TSrc, Set, TSrcBlocks, TDstBlocks, TStrategy> { NixlTransferBuilder { src: self.src, dst: Some(dst), src_block_ids: self.src_block_ids, dst_block_ids: self.dst_block_ids, strategy: self.strategy, layer_range: self.layer_range, write_notif: self.write_notif, _phantom: PhantomData, } } } impl<'a, TSrc, TDst, TDstBlocks, TStrategy> NixlTransferBuilder<'a, TSrc, TDst, Unset, TDstBlocks, TStrategy> { /// Sets the source block IDs to transfer. pub fn src_blocks( self, src_block_ids: &'a [BlockId], ) -> NixlTransferBuilder<'a, TSrc, TDst, Set, TDstBlocks, TStrategy> { NixlTransferBuilder { src: self.src, dst: self.dst, src_block_ids: Some(src_block_ids), dst_block_ids: self.dst_block_ids, strategy: self.strategy, layer_range: self.layer_range, write_notif: self.write_notif, _phantom: PhantomData, } } } impl<'a, TSrc, TDst, TSrcBlocks, TStrategy> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Unset, TStrategy> { /// Sets the destination block IDs to transfer. pub fn dst_blocks( self, dst_block_ids: &'a [BlockId], ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Set, TStrategy> { NixlTransferBuilder { src: self.src, dst: self.dst, src_block_ids: self.src_block_ids, dst_block_ids: Some(dst_block_ids), strategy: self.strategy, layer_range: self.layer_range, write_notif: self.write_notif, _phantom: PhantomData, } } } impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Unset> { /// Sets the NIXL transfer strategy (Read or Write). pub fn strategy( self, strategy: TransferStrategy, ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Set> { NixlTransferBuilder { src: self.src, dst: self.dst, src_block_ids: self.src_block_ids, dst_block_ids: self.dst_block_ids, strategy: Some(strategy), layer_range: self.layer_range, write_notif: self.write_notif, _phantom: PhantomData, } } } // Optional field setters - these can be called at any point in the builder chain impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy> { /// Sets an optional range of layers to transfer. /// If not called, all layers will be transferred. pub fn layer_range(mut self, layer_range: Range) -> Self { self.layer_range = Some(layer_range); self } /// Sets an optional write notification UUID. #[expect(dead_code)] pub fn write_notif(mut self, write_notif: uuid::Uuid) -> Self { self.write_notif = Some(write_notif); self } } // Execute method - only available when all required fields are Set impl<'a> NixlTransferBuilder<'a, Set, Set, Set, Set, Set> { /// Executes the NIXL transfer with the configured parameters. /// /// This method is only available when all required fields have been set, /// enforced at compile time by the typestate pattern. pub(crate) fn execute(self, ctx: &TransferContext) -> Result { // Unwrap all required fields (safe because typestate guarantees they're set) let src = self.src.unwrap(); let dst = self.dst.unwrap(); let src_block_ids = self.src_block_ids.unwrap(); let dst_block_ids = self.dst_block_ids.unwrap(); let strategy = self.strategy.unwrap(); let layer_range = self.layer_range; let _write_notif = self.write_notif; // Validate layouts let src_layout = src.layout(); let dst_layout = dst.layout(); if src_layout.num_layers() != dst_layout.num_layers() { return Err(anyhow!( "Layouts have incompatible layer counts: src={}, dst={}", src_layout.num_layers(), dst_layout.num_layers() )); } if src_layout.outer_dim() != dst_layout.outer_dim() { return Err(anyhow!( "Layouts have incompatible outer dimensions: src={}, dst={}", src_layout.outer_dim(), dst_layout.outer_dim() )); } // Validate layout compatibility (errors if transform would be needed) validate_layout_compatibility(src, dst)?; // Get NIXL agent let nixl_agent = ctx.nixl_agent(); // Determine layer range let layers = layer_range.clone().unwrap_or(0..src_layout.num_layers()); // Check if we can use optimized whole-block transfer let use_whole_block = can_use_whole_block_transfer(src, dst, layer_range.as_ref()); // Determine NIXL operation type let xfer_op = match strategy { TransferStrategy::NixlRead | TransferStrategy::NixlReadFlipped => XferOp::Read, TransferStrategy::NixlWrite | TransferStrategy::NixlWriteFlipped => XferOp::Write, _ => { return Err(anyhow!("Invalid NIXL transfer strategy: {:?}", strategy)); } }; // Validate locality constraints based on operation type: // - For Write operations (push): source must be local, we're writing FROM local TO remote // - For Read operations (pull): destination must be local, we're reading FROM remote INTO local let src_is_local = nixl_agent.name() == src.nixl_metadata().agent_name(); let dst_is_local = nixl_agent.name() == dst.nixl_metadata().agent_name(); // These are invariant assertions — a violation means a bug in `select_strategy`, // not a user error. The strategy selection guarantees locality constraints. match xfer_op { XferOp::Write => { assert!( src_is_local, "For NIXL Write (push), the source must be local. src_agent='{}', local_agent='{}'", src.nixl_metadata().agent_name(), nixl_agent.name() ); } XferOp::Read => { assert!( dst_is_local, "For NIXL Read (pull), the destination must be local. dst_agent='{}', local_agent='{}'", dst.nixl_metadata().agent_name(), nixl_agent.name() ); } } // Capture NIXL metadata for both layouts let src_metadata = src.nixl_metadata(); let dst_metadata = dst.nixl_metadata(); let src_mem_type = src_metadata.mem_type(); let dst_mem_type = dst_metadata.mem_type(); let src_device_id = src_metadata.device_id(); let dst_device_id = dst_metadata.device_id(); // Build XferDescLists for source and destination let mut src_dl = XferDescList::new(src_mem_type)?; let mut dst_dl = XferDescList::new(dst_mem_type)?; // Build descriptor lists - use whole-block or layer-wise depending on layout if use_whole_block { let bytes_per_block = src_layout.config().bytes_per_block(); tracing::debug!( num_blocks = src_block_ids.len(), bytes_per_block, "Building whole-block NIXL descriptors" ); for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) { let src_region = src.memory_region(src_block_id, 0, 0)?; let dst_region = dst.memory_region(dst_block_id, 0, 0)?; src_dl.add_desc(src_region.addr(), bytes_per_block, src_device_id); dst_dl.add_desc(dst_region.addr(), bytes_per_block, dst_device_id); } } else { tracing::debug!( num_blocks = src_block_ids.len(), layer_range = ?layers, src_fc = src_layout.is_fully_contiguous(), dst_fc = dst_layout.is_fully_contiguous(), "Building layer-wise NIXL descriptors" ); for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) { for layer_id in layers.clone() { for outer_id in 0..src_layout.outer_dim() { let src_region = src.memory_region(src_block_id, layer_id, outer_id)?; let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?; if src_region.size() != dst_region.size() { return Err(anyhow!( "Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}", src_block_id, dst_block_id, layer_id, outer_id, src_region.size(), dst_region.size() )); } src_dl.add_desc(src_region.addr(), src_region.size(), src_device_id); dst_dl.add_desc(dst_region.addr(), dst_region.size(), dst_device_id); } } } } // Note: Overlap detection was removed from nixl-sys 0.6.1 // The NIXL library now handles overlap detection internally if matches!( strategy, TransferStrategy::NixlReadFlipped | TransferStrategy::NixlWriteFlipped ) { std::mem::swap(&mut src_dl, &mut dst_dl); } // Create transfer request // The remote agent depends on operation type: // - For Write (push): remote is the destination // - For Read (pull): remote is the source let remote_agent = match xfer_op { XferOp::Write => dst_metadata.agent_name(), XferOp::Read => src_metadata.agent_name(), }; let xfer_req = nixl_agent.create_xfer_req( xfer_op, &src_dl, &dst_dl, remote_agent, None, // opt_args )?; // Post transfer request // Note: Notification handling via OptArgs can be added later if needed let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?; if still_pending { // Register for async completion via status polling Ok(ctx.register_nixl_status(xfer_req)) } else { // Transfer completed synchronously Ok(TransferCompleteNotification::completed()) } } }