"vscode:/vscode.git/clone" did not exist on "6b8545fc4d931e535dd3c8f15d4d5a4a51d245b6"
Unverified commit 9ab148dc, authored by Ryan Olson, committed by GitHub
Browse files

feat: kvbm-physical (#6490)


Signed-off-by: Ryan Olson <rolson@nvidia.com>
parent 7546c193
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Physical layout types that combine abstract layouts with storage location metadata.
use crate::BlockId;
use super::{
FullyContiguousLayout, InnerShape, LayerSeparateLayout, Layout, MemoryRegion,
builder::{PhysicalLayoutBuilder, PhysicalLayoutBuilderDefault},
serialize::{LayoutDescriptor, LayoutTypeDetails},
};
use anyhow::{Result, anyhow};
use dynamo_memory::{
Buffer, MemoryDescriptor, StorageKind,
nixl::{MemType, NixlAgent, NixlDescriptor},
};
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::sync::Arc;
/// Runtime representation of a layout with its physical storage location.
///
/// A `PhysicalLayout` wraps an abstract [`Layout`] with information about where
/// its memory physically resides (GPU, host, disk) together with the NIXL
/// metadata for that memory. This enables the transfer system to select
/// appropriate copy strategies and build NIXL transfer descriptors.
///
/// Note: there is no dedicated local/remote field on this struct; a layout
/// rebuilt from a `LayoutDescriptor` simply carries the NIXL metadata that was
/// stored in that descriptor.
///
/// The layout is held behind an `Arc`, so cloning a `PhysicalLayout` shares the
/// underlying layout object rather than copying it.
#[derive(Debug, Clone)]
pub struct PhysicalLayout {
/// The abstract layout defining memory organization
layout: Arc<dyn Layout>,
/// Physical storage location (System, Device, Pinned, Disk)
location: StorageKind,
/// NIXL registration metadata
nixl_metadata: NixlMetadata,
}
/// NIXL metadata attached to a [`PhysicalLayout`].
///
/// The type is serializable so it can be embedded in a `LayoutDescriptor`.
/// The memory type and device id are the values copied into the
/// `NixlDescriptor`s produced for layouts reconstructed from a descriptor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlMetadata {
/// Name of the NIXL agent associated with the memory.
agent_name: String,
/// NIXL memory type of the backing memory.
mem_type: MemType,
/// Device identifier reported to NIXL.
device_id: u64,
}
impl NixlMetadata {
pub fn new(agent_name: String, mem_type: MemType, device_id: u64) -> Self {
Self {
agent_name,
mem_type,
device_id,
}
}
pub fn agent_name(&self) -> &str {
&self.agent_name
}
#[inline(always)]
pub fn mem_type(&self) -> MemType {
self.mem_type
}
#[inline(always)]
pub fn device_id(&self) -> u64 {
self.device_id
}
}
impl PhysicalLayout {
    /// Create a typed builder that enforces NIXL registration.
    ///
    /// # Arguments
    /// * `agent` - NIXL agent handed to the builder
    pub fn builder(agent: NixlAgent) -> PhysicalLayoutBuilderDefault {
        PhysicalLayoutBuilder::new(agent)
    }

    /// Create a new local physical layout.
    ///
    /// # Arguments
    /// * `layout` - The abstract layout to wrap
    /// * `location` - Where the layout's memory resides
    /// * `nixl_metadata` - NIXL metadata stored alongside the layout
    pub(crate) fn new_local(
        layout: Arc<dyn Layout>,
        location: StorageKind,
        nixl_metadata: NixlMetadata,
    ) -> Self {
        Self {
            layout,
            location,
            nixl_metadata,
        }
    }

    // NOTE(review): the commented-out constructor below is stale. It assigns
    // `locality` and `registered` fields and wraps the metadata in `Some(..)`,
    // none of which match the current `PhysicalLayout` definition. Kept only for
    // reference; consider deleting it.
    //
    // /// Create a new remote physical layout from a descriptor.
    // ///
    // /// # Arguments
    // /// * `layout` - The abstract layout to wrap
    // /// * `location` - Where the layout's memory resides (on remote node)
    // /// * `remote_agent` - Name of the NIXL agent on the remote node
    // pub fn new_remote(
    // layout: Arc<dyn Layout>,
    // location: StorageKind,
    // remote_agent: String,
    // ) -> Self {
    // let metadata = NixlMetadata::new(
    // remote_agent.clone(),
    // location.to_nixl_mem_type(),
    // location.device_id(),
    // );
    // let registrations = vec![RegisteredStorageMetadata::new(
    // metadata.agent_name().to_string(),
    // location,
    // )];
    // Self {
    // layout,
    // location,
    // locality: Locality::Remote(remote_agent),
    // nixl_metadata: Some(metadata),
    // registered: registrations,
    // }
    // }

    /// Get the underlying layout.
    pub fn layout(&self) -> &Arc<dyn Layout> {
        &self.layout
    }

    /// Get the storage location.
    pub(crate) fn location(&self) -> StorageKind {
        self.location
    }

    /// Get the NIXL metadata.
    pub(crate) fn nixl_metadata(&self) -> &NixlMetadata {
        &self.nixl_metadata
    }

    /// Get the memory region for one (block, layer, outer) coordinate.
    ///
    /// This is a direct delegation to the wrapped layout's `memory_region`.
    ///
    /// # Arguments
    /// * `block_id` - Block identifier
    /// * `layer_id` - Layer identifier
    /// * `outer_id` - Outer dimension identifier
    ///
    /// # Errors
    /// Propagates whatever error the wrapped layout returns.
    pub fn memory_region(
        &self,
        block_id: BlockId,
        layer_id: usize,
        outer_id: usize,
    ) -> Result<MemoryRegion> {
        self.layout.memory_region(block_id, layer_id, outer_id)
    }

    /// Serialize this physical layout for transmission to remote nodes.
    ///
    /// This converts the runtime `PhysicalLayout` into a `LayoutDescriptor` that
    /// contains all information needed to reconstruct the layout on a remote node,
    /// including layout configuration, memory descriptors, NIXL metadata, and
    /// layout-type-specific details.
    ///
    /// # Returns
    /// A serializable representation of this layout. The current implementation
    /// has no failure path; the `Result` return type is kept for callers.
    pub(crate) fn to_descriptor(&self) -> Result<LayoutDescriptor> {
        // Record only the address and size of every backing region.
        let memory_descriptors = self
            .layout
            .memory_regions()
            .iter()
            .map(|region| MemoryRegion {
                addr: region.addr(),
                size: region.size(),
            })
            .collect();

        // Get layout type details from the layout itself
        let layout_type_details = self.layout.serialization_details();

        Ok(LayoutDescriptor {
            version: LayoutDescriptor::CURRENT_VERSION,
            layout_config: self.layout.config().clone(),
            location: self.location,
            nixl_metadata: self.nixl_metadata.clone(),
            memory_descriptors,
            layout_type_details,
        })
    }

    /// Reconstruct a physical layout from serialized data received from a remote node.
    ///
    /// This creates a new `PhysicalLayout` from a `LayoutDescriptor`. The reconstructed
    /// layout will have memory descriptors that point to the remote node's memory,
    /// allowing NIXL to build RDMA descriptors for remote access.
    ///
    /// # Arguments
    /// * `serialized` - Serialized layout data from a remote node
    ///
    /// # Returns
    /// A new `PhysicalLayout` representing the remote layout
    ///
    /// # Errors
    /// * the descriptor version is greater than `LayoutDescriptor::CURRENT_VERSION`
    /// * the number of memory descriptors does not match the layout type
    ///   (exactly one for fully contiguous, `num_layers` for layer-separate)
    /// * the underlying layout constructor/builder rejects the inputs
    ///
    /// # Note
    /// The memory regions in the reconstructed layout are not valid for local access;
    /// they represent remote memory addresses and are used to build NIXL transfer descriptors.
    pub(crate) fn from_descriptor(serialized: LayoutDescriptor) -> Result<Self> {
        // The descriptor is owned, so take it apart once and move each field to
        // where it is needed instead of cloning out of the struct.
        let LayoutDescriptor {
            version,
            layout_config,
            location,
            nixl_metadata,
            memory_descriptors,
            layout_type_details,
        } = serialized;

        // Validate version
        if version > LayoutDescriptor::CURRENT_VERSION {
            return Err(anyhow!(
                "Unsupported serialization version: {}. Maximum supported: {}",
                version,
                LayoutDescriptor::CURRENT_VERSION
            ));
        }

        // Create remote memory regions from descriptors. Each region keeps its own
        // copy of the NIXL metadata so it can answer `nixl_descriptor()` by itself.
        let remote_regions: Vec<Arc<dyn MemoryDescriptor>> = memory_descriptors
            .iter()
            .map(|desc| {
                Arc::new(RemoteMemoryDescriptor {
                    addr: desc.addr,
                    size: desc.size,
                    storage_kind: location,
                    nixl_metadata: nixl_metadata.clone(),
                }) as Arc<dyn MemoryDescriptor>
            })
            .collect();

        // Reconstruct the layout based on type
        let layout: Arc<dyn Layout> = match layout_type_details {
            LayoutTypeDetails::FullyContiguous(details) => {
                if remote_regions.len() != 1 {
                    return Err(anyhow!(
                        "FullyContiguous layout requires exactly 1 memory region, got {}",
                        remote_regions.len()
                    ));
                }
                let layout = FullyContiguousLayout::new_with_format(
                    layout_config,
                    Buffer::from_arc(remote_regions[0].clone()),
                    details.block_format,
                    details.kv_block_layout,
                )?;
                Arc::new(layout)
            }
            LayoutTypeDetails::LayerSeparate(details) => {
                if remote_regions.len() != layout_config.num_layers {
                    return Err(anyhow!(
                        "LayerSeparate layout requires {} memory regions (one per layer), got {}",
                        layout_config.num_layers,
                        remote_regions.len()
                    ));
                }
                // Fall back to `Unknown` when the KV block layout has no inner shape.
                let inner_shape = details
                    .kv_block_layout
                    .to_inner_shape()
                    .unwrap_or(InnerShape::Unknown);
                let layout = LayerSeparateLayout::builder()
                    .config(layout_config)
                    .memory(remote_regions.into_iter().map(Buffer::from_arc).collect())
                    .block_dim(details.block_dim)
                    .inner_shape(inner_shape)
                    .build()?;
                Arc::new(layout)
            }
        };

        Ok(Self {
            layout,
            location,
            nixl_metadata,
        })
    }
}
/// A memory region that represents remote memory addresses.
///
/// This type is used when reconstructing layouts from serialized data.
/// The addresses are not valid for local access but can be used to
/// build NIXL transfer descriptors for remote memory access.
#[derive(Debug)]
struct RemoteMemoryDescriptor {
/// Start address of the region, as recorded in the serialized descriptor.
addr: usize,
/// Size of the region, as recorded in the serialized descriptor.
size: usize,
/// Storage kind copied from the descriptor's `location`.
storage_kind: StorageKind,
/// NIXL metadata copied from the descriptor; supplies `mem_type` and `device_id`.
nixl_metadata: NixlMetadata,
}
impl MemoryDescriptor for RemoteMemoryDescriptor {
    /// Recorded start address of the region.
    fn addr(&self) -> usize {
        self.addr
    }

    /// Recorded size of the region.
    fn size(&self) -> usize {
        self.size
    }

    /// Storage kind recorded for the region.
    fn storage_kind(&self) -> StorageKind {
        self.storage_kind
    }

    /// Expose `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Build a NIXL descriptor from the recorded address/size and the stored
    /// NIXL metadata. Always returns `Some`.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        let metadata = &self.nixl_metadata;
        let descriptor = NixlDescriptor {
            addr: self.addr as u64,
            size: self.size,
            mem_type: metadata.mem_type(),
            device_id: metadata.device_id(),
        };
        Some(descriptor)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Serialization types for physical layouts.
//!
//! This module provides types for serializing and deserializing physical layouts
//! so they can be transmitted to remote nodes and reconstructed there for RDMA operations.
use super::physical::NixlMetadata;
use super::{BlockDimension, KvBlockLayout, LayoutConfig};
use anyhow::Result;
use dynamo_memory::{MemoryRegion, StorageKind};
use serde::{Deserialize, Serialize};
/// Format of blocks in a fully contiguous layout.
///
/// This enum describes how the blocks are organized and formatted in memory.
/// Currently only `Operational` is supported, but future variants may include
/// different compression schemes or memory layouts.
///
/// `Default` is derived; the `#[default]` variant is [`BlockFormat::Operational`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockFormat {
    /// Standard operational format - blocks are stored in their normal, uncompressed form.
    #[default]
    Operational,
}
/// Details specific to fully contiguous layouts.
///
/// Carried by [`LayoutTypeDetails::FullyContiguous`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullyContiguousDetails {
/// Format of the blocks in memory
pub block_format: BlockFormat,
/// KV block layout describing dimension ordering within blocks.
///
/// Marked `#[serde(default)]`: when the field is missing from the input,
/// deserialization uses `KvBlockLayout::default()` instead of failing.
#[serde(default)]
pub kv_block_layout: KvBlockLayout,
}
/// Details specific to layer-separate layouts.
///
/// Carried by [`LayoutTypeDetails::LayerSeparate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerSeparateDetails {
/// Block dimension ordering (block-first or block-second)
pub block_dim: BlockDimension,
/// KV block layout for the inner tensor format (must be operational: NHD or HND).
///
/// Marked `#[serde(default)]`: when the field is missing from the input,
/// deserialization uses `KvBlockLayout::default()` instead of failing.
#[serde(default)]
pub kv_block_layout: KvBlockLayout,
}
/// Layout-type-specific details.
///
/// This enum captures the information that differs between layout types
/// and is needed to reconstruct the layout on a remote node.
///
/// No serde tagging attribute is set, so serde's default enum representation
/// (externally tagged) is used on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayoutTypeDetails {
/// Fully contiguous layout details
FullyContiguous(FullyContiguousDetails),
/// Layer-separate layout details
LayerSeparate(LayerSeparateDetails),
}
/// Serializable representation of a physical layout.
///
/// This structure contains all information needed to reconstruct a layout
/// on a remote node, including:
/// - Layout configuration (dimensions, sizes, etc.)
/// - Storage location and NIXL metadata
/// - Memory descriptors for all regions
/// - Layout-type-specific details
///
/// The serialized form can be transmitted over the network and used to
/// build NIXL transfer descriptors for remote memory access.
///
/// All fields are public; the accessor methods on this type return the same data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayoutDescriptor {
/// Serialization format version (for future compatibility).
/// See [`LayoutDescriptor::CURRENT_VERSION`] for the value written by this code.
pub version: u32,
/// Layout configuration
pub layout_config: LayoutConfig,
/// Storage location
pub location: StorageKind,
/// NIXL metadata from the source node
pub nixl_metadata: NixlMetadata,
/// Memory descriptors for all regions backing this layout (address and size only)
pub memory_descriptors: Vec<MemoryRegion>,
/// Layout-type-specific details
pub layout_type_details: LayoutTypeDetails,
}
impl LayoutDescriptor {
    /// Version number written into newly created descriptors.
    pub const CURRENT_VERSION: u32 = 1;

    /// Encode this descriptor as a JSON string.
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if encoding fails.
    pub fn to_json(&self) -> Result<String> {
        let encoded = serde_json::to_string(self);
        encoded.map_err(|err| anyhow::anyhow!("failed to serialize layout to JSON: {}", err))
    }

    /// Encode this descriptor as UTF-8 JSON bytes.
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if encoding fails.
    pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
        let encoded = serde_json::to_vec(self);
        encoded.map_err(|err| anyhow::anyhow!("failed to serialize layout to JSON bytes: {}", err))
    }

    /// Decode a descriptor from a JSON string.
    ///
    /// # Arguments
    /// * `json` - JSON text produced by [`LayoutDescriptor::to_json`] or equivalent
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if decoding fails.
    pub fn from_json(json: &str) -> Result<Self> {
        let decoded = serde_json::from_str(json);
        decoded.map_err(|err| anyhow::anyhow!("failed to deserialize layout from JSON: {}", err))
    }

    /// Decode a descriptor from UTF-8 JSON bytes.
    ///
    /// # Arguments
    /// * `bytes` - UTF-8 encoded JSON bytes
    ///
    /// # Errors
    /// Returns an error carrying the `serde_json` failure text if decoding fails.
    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
        let decoded = serde_json::from_slice(bytes);
        decoded
            .map_err(|err| anyhow::anyhow!("failed to deserialize layout from JSON bytes: {}", err))
    }

    /// Borrow the layout configuration.
    pub fn layout_config(&self) -> &LayoutConfig {
        let Self { layout_config, .. } = self;
        layout_config
    }

    /// Return the storage location.
    pub fn location(&self) -> StorageKind {
        self.location
    }

    /// Borrow the NIXL metadata recorded by the source node.
    pub fn nixl_metadata(&self) -> &NixlMetadata {
        let Self { nixl_metadata, .. } = self;
        nixl_metadata
    }

    /// Borrow the memory descriptors as a slice.
    pub fn memory_descriptors(&self) -> &[MemoryRegion] {
        self.memory_descriptors.as_slice()
    }

    /// Borrow the layout-type-specific details.
    pub fn layout_type_details(&self) -> &LayoutTypeDetails {
        let Self {
            layout_type_details,
            ..
        } = self;
        layout_type_details
    }
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
    use dynamo_memory::nixl::MemType;

    use super::*;

    /// Fixture: 10 blocks, 4 layers, outer dim 2, page size 16, inner dim 128, 2-byte dtype.
    fn make_test_config() -> LayoutConfig {
        let built = LayoutConfig::builder()
            .num_blocks(10)
            .num_layers(4)
            .outer_dim(2)
            .page_size(16)
            .inner_dim(128)
            .dtype_width_bytes(2)
            .build();
        built.unwrap()
    }

    /// The default block format is `Operational`.
    #[test]
    fn test_block_format_default() {
        let default_format = BlockFormat::default();
        assert_eq!(default_format, BlockFormat::Operational);
    }

    /// A fully-contiguous descriptor survives a JSON string round trip.
    #[test]
    fn test_serialized_layout_json_roundtrip() {
        let details = FullyContiguousDetails {
            block_format: BlockFormat::Operational,
            kv_block_layout: KvBlockLayout::OperationalNHD,
        };
        let layout = LayoutDescriptor {
            version: LayoutDescriptor::CURRENT_VERSION,
            layout_config: make_test_config(),
            location: StorageKind::System,
            nixl_metadata: NixlMetadata::new("test_agent".to_string(), MemType::Dram, 0),
            memory_descriptors: vec![MemoryRegion::new(0x1000, 4096)],
            layout_type_details: LayoutTypeDetails::FullyContiguous(details),
        };

        // Test to_json/from_json
        let text = layout.to_json().unwrap();
        let decoded = LayoutDescriptor::from_json(&text).unwrap();

        assert_eq!(decoded.version, layout.version);
        assert_eq!(decoded.layout_config, layout.layout_config);
        assert_eq!(decoded.location, layout.location);
        assert_eq!(
            decoded.nixl_metadata.agent_name(),
            layout.nixl_metadata.agent_name()
        );
        assert_eq!(decoded.memory_descriptors.len(), 1);
    }

    /// A layer-separate descriptor survives a JSON bytes round trip.
    #[test]
    fn test_serialized_layout_json_bytes_roundtrip() {
        let details = LayerSeparateDetails {
            block_dim: BlockDimension::BlockIsFirstDim,
            kv_block_layout: KvBlockLayout::OperationalNHD,
        };
        let layout = LayoutDescriptor {
            version: LayoutDescriptor::CURRENT_VERSION,
            layout_config: make_test_config(),
            location: StorageKind::System,
            nixl_metadata: NixlMetadata::new("test_agent".to_string(), MemType::Vram, 5),
            memory_descriptors: vec![
                MemoryRegion::new(0x1000, 2048),
                MemoryRegion::new(0x2000, 2048),
            ],
            layout_type_details: LayoutTypeDetails::LayerSeparate(details),
        };

        // Test to_json_bytes/from_json_bytes
        let raw = layout.to_json_bytes().unwrap();
        let decoded = LayoutDescriptor::from_json_bytes(&raw).unwrap();

        assert_eq!(decoded.version, layout.version);
        assert_eq!(decoded.nixl_metadata.device_id(), 5);
        assert_eq!(decoded.memory_descriptors.len(), 2);
    }

    /// `LayoutTypeDetails::FullyContiguous` keeps its payload through serde_json.
    #[test]
    fn test_fully_contiguous_details_serialization() {
        let details = LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
            block_format: BlockFormat::Operational,
            kv_block_layout: KvBlockLayout::UniversalTP,
        });

        let text = serde_json::to_string(&details).unwrap();
        let decoded: LayoutTypeDetails = serde_json::from_str(&text).unwrap();

        let LayoutTypeDetails::FullyContiguous(d) = decoded else {
            panic!("Expected FullyContiguous variant")
        };
        assert_eq!(d.block_format, BlockFormat::Operational);
        assert_eq!(d.kv_block_layout, KvBlockLayout::UniversalTP);
    }

    /// `LayoutTypeDetails::LayerSeparate` keeps its payload through serde_json.
    #[test]
    fn test_layer_separate_details_serialization() {
        let details = LayoutTypeDetails::LayerSeparate(LayerSeparateDetails {
            block_dim: BlockDimension::BlockIsSecondDim,
            kv_block_layout: KvBlockLayout::OperationalHND,
        });

        let text = serde_json::to_string(&details).unwrap();
        let decoded: LayoutTypeDetails = serde_json::from_str(&text).unwrap();

        let LayoutTypeDetails::LayerSeparate(d) = decoded else {
            panic!("Expected LayerSeparate variant")
        };
        assert_eq!(d.block_dim, BlockDimension::BlockIsSecondDim);
        assert_eq!(d.kv_block_layout, KvBlockLayout::OperationalHND);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for layout serialization.
//!
//! These tests verify the complete serialization and deserialization flow,
//! ensuring that layouts can be transmitted to remote nodes and reconstructed
//! with all necessary metadata intact.
use crate::layout::physical::PhysicalLayout;
use crate::layout::{BlockDimension, LayoutConfig, LayoutDescriptor};
use dynamo_memory::nixl::{MemType, NixlAgent, NixlDescriptor};
use dynamo_memory::{Buffer, MemoryDescriptor, MemoryRegion, StorageKind};
use std::any::Any;
use std::sync::Arc;
/// Simple mock implementation for testing.
///
/// Holds only an address and a size; its `MemoryDescriptor` impl reports
/// `StorageKind::System` and provides no NIXL descriptor.
#[derive(Debug)]
pub struct MockMemory {
/// Address reported by `addr()`.
addr: usize,
/// Size reported by `size()`.
size: usize,
}
impl MockMemory {
pub fn new(addr: usize, size: usize) -> Arc<Self> {
Arc::new(Self { addr, size })
}
}
impl MemoryDescriptor for MockMemory {
    /// Address supplied at construction.
    fn addr(&self) -> usize {
        let Self { addr, .. } = self;
        *addr
    }

    /// Size supplied at construction.
    fn size(&self) -> usize {
        let Self { size, .. } = self;
        *size
    }

    /// Always reports system memory.
    fn storage_kind(&self) -> StorageKind {
        StorageKind::System
    }

    /// Expose `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The mock has no NIXL descriptor.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        Option::None
    }
}
/// Mock memory region for testing serialization.
///
/// Unlike `MockMemory`, this region carries a `NixlDescriptor` and a
/// caller-chosen `StorageKind`.
#[derive(Debug)]
struct TestMemoryRegion {
/// Address reported by `addr()`.
addr: usize,
/// Size reported by `size()`.
size: usize,
/// Storage kind reported by `storage_kind()`.
kind: StorageKind,
/// Descriptor cloned out by `nixl_descriptor()`.
descriptor: NixlDescriptor,
}
impl TestMemoryRegion {
fn new(addr: usize, size: usize, kind: StorageKind) -> Arc<Self> {
Arc::new(Self {
addr,
size,
kind,
descriptor: NixlDescriptor {
addr: addr as u64,
size,
mem_type: MemType::Dram,
device_id: 0,
},
})
}
}
impl MemoryDescriptor for TestMemoryRegion {
    /// Address supplied at construction.
    fn addr(&self) -> usize {
        let Self { addr, .. } = self;
        *addr
    }

    /// Size supplied at construction.
    fn size(&self) -> usize {
        let Self { size, .. } = self;
        *size
    }

    /// Storage kind supplied at construction.
    fn storage_kind(&self) -> StorageKind {
        self.kind
    }

    /// Expose `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Return a clone of the stored descriptor; always `Some`.
    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
        let copy = self.descriptor.clone();
        Some(copy)
    }
}
/// Fixture: 10 blocks, 4 layers, outer dim 2, page size 16, inner dim 128, 2-byte dtype.
///
/// Panics (via `unwrap`) if the builder rejects these values.
fn make_test_config() -> LayoutConfig {
    let built = LayoutConfig::builder()
        .num_blocks(10)
        .num_layers(4)
        .outer_dim(2)
        .page_size(16)
        .inner_dim(128)
        .dtype_width_bytes(2)
        .build();
    built.unwrap()
}
/// Build a fully contiguous layout, push it through descriptor -> JSON string ->
/// descriptor -> layout, and check that the configuration survives.
#[test]
fn test_fully_contiguous_layout_serialization_roundtrip() {
    let agent = NixlAgent::new("test-fc-serialize").expect("failed to create agent");
    let config = make_test_config();

    // One buffer must hold every block/layer/outer/page/inner element.
    let dims = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ];
    let required_size: usize = dims.iter().product();

    // Single backing region for the whole layout.
    let backing = TestMemoryRegion::new(0x10000, required_size, StorageKind::System);
    let regions = vec![Buffer::from_arc(backing as Arc<dyn MemoryDescriptor>)];

    // Build physical layout
    let original_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .fully_contiguous()
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Layout -> descriptor
    let serialized = original_layout
        .to_descriptor()
        .expect("failed to serialize layout");

    assert_eq!(serialized.version, LayoutDescriptor::CURRENT_VERSION);
    assert_eq!(serialized.layout_config, config);
    assert_eq!(serialized.location, StorageKind::System);
    assert_eq!(serialized.memory_descriptors.len(), 1);
    let first = &serialized.memory_descriptors[0];
    assert_eq!(first.addr, 0x10000);
    assert_eq!(first.size, required_size);

    // Descriptor -> JSON
    let json = serialized.to_json().expect("failed to serialize to JSON");
    assert!(json.contains("\"version\":1"));
    assert!(json.contains("\"num_blocks\":10"));

    // JSON -> descriptor
    let deserialized = LayoutDescriptor::from_json(&json).expect("failed to deserialize from JSON");

    assert_eq!(deserialized.version, serialized.version);
    assert_eq!(deserialized.layout_config, serialized.layout_config);
    assert_eq!(deserialized.location, serialized.location);
    assert_eq!(
        deserialized.memory_descriptors.len(),
        serialized.memory_descriptors.len()
    );

    // Descriptor -> layout
    let reconstructed =
        PhysicalLayout::from_descriptor(deserialized).expect("failed to reconstruct layout");

    let rebuilt = reconstructed.layout();
    assert_eq!(rebuilt.config(), &config);
    assert_eq!(reconstructed.location(), StorageKind::System);
    assert_eq!(rebuilt.num_blocks(), 10);
    assert_eq!(rebuilt.num_layers(), 4);
    assert!(rebuilt.is_fully_contiguous());
}
/// Build a layer-separate layout, push it through descriptor -> JSON bytes ->
/// descriptor -> layout, and check that the configuration survives.
#[test]
fn test_layer_separate_layout_serialization_roundtrip() {
    let agent = NixlAgent::new("test-ls-serialize").expect("failed to create agent");
    let config = make_test_config();

    // Each layer gets its own buffer, so the layer count is left out of the product.
    let per_layer_dims = [
        config.num_blocks,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ];
    let per_layer_size: usize = per_layer_dims.iter().product();

    // Back-to-back regions starting at 0x10000, one per layer.
    let layer_region = |layer: usize| {
        Buffer::from_arc(TestMemoryRegion::new(
            0x10000 + layer * per_layer_size,
            per_layer_size,
            StorageKind::System,
        ) as Arc<dyn MemoryDescriptor>)
    };
    let regions: Vec<Buffer> = (0..config.num_layers).map(layer_region).collect();

    // Build physical layout
    let original_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .layer_separate(BlockDimension::BlockIsFirstDim)
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Layout -> descriptor
    let serialized = original_layout
        .to_descriptor()
        .expect("failed to serialize layout");

    assert_eq!(serialized.version, LayoutDescriptor::CURRENT_VERSION);
    assert_eq!(serialized.layout_config, config);
    assert_eq!(serialized.memory_descriptors.len(), 4); // One per layer

    // Every descriptor must mirror the region it was built from.
    for (layer, descriptor) in serialized.memory_descriptors.iter().enumerate() {
        assert_eq!(descriptor.addr, 0x10000 + layer * per_layer_size);
        assert_eq!(descriptor.size, per_layer_size);
    }

    // Descriptor -> JSON bytes
    let json_bytes = serialized
        .to_json_bytes()
        .expect("failed to serialize to JSON bytes");

    // JSON bytes -> descriptor
    let deserialized = LayoutDescriptor::from_json_bytes(&json_bytes)
        .expect("failed to deserialize from JSON bytes");

    assert_eq!(deserialized.version, serialized.version);
    assert_eq!(deserialized.layout_config, serialized.layout_config);
    assert_eq!(
        deserialized.memory_descriptors.len(),
        serialized.memory_descriptors.len()
    );

    // Descriptor -> layout
    let reconstructed =
        PhysicalLayout::from_descriptor(deserialized).expect("failed to reconstruct layout");

    let rebuilt = reconstructed.layout();
    assert_eq!(rebuilt.config(), &config);
    assert_eq!(reconstructed.location(), StorageKind::System);
    assert_eq!(rebuilt.num_blocks(), 10);
    assert_eq!(rebuilt.num_layers(), 4);
    assert!(!rebuilt.is_fully_contiguous());
}
/// After a descriptor round trip, per-(block, layer, outer) regions must still be
/// computed from the original base address with the expected strides.
#[test]
fn test_memory_region_calculation_after_deserialization() {
    let agent = NixlAgent::new("test-memory-calc").expect("failed to create agent");
    let config = LayoutConfig::builder()
        .num_blocks(2)
        .num_layers(2)
        .outer_dim(2)
        .page_size(4)
        .inner_dim(8)
        .dtype_width_bytes(2)
        .build()
        .unwrap();

    let dims = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ];
    let required_size: usize = dims.iter().product();

    let backing = TestMemoryRegion::new(0x1000, required_size, StorageKind::System);
    let regions = vec![Buffer::from_arc(backing as Arc<dyn MemoryDescriptor>)];

    let original_layout = PhysicalLayout::builder(agent)
        .with_config(config.clone())
        .fully_contiguous()
        .with_registered_regions(regions)
        .expect("failed to provide regions")
        .build()
        .expect("failed to build layout");

    // Serialize and deserialize
    let serialized = original_layout
        .to_descriptor()
        .expect("failed to serialize");
    let reconstructed = PhysicalLayout::from_descriptor(serialized).expect("failed to reconstruct");

    // Size of one (block, layer, outer) slice.
    let region_size = config.page_size * config.inner_dim * config.dtype_width_bytes;

    // The very first slice starts at the base address.
    let origin = reconstructed
        .memory_region(0, 0, 0)
        .expect("failed to get memory region");
    assert_eq!(origin.addr, 0x1000);
    assert_eq!(origin.size, region_size);

    // Index 1 in every dimension: base + block_stride + layer_stride + outer_stride
    let shifted = reconstructed
        .memory_region(1, 1, 1)
        .expect("failed to get memory region");
    let layer_stride = config.outer_dim * region_size;
    let block_stride = config.num_layers * layer_stride;
    let expected_addr = 0x1000 + block_stride + layer_stride + region_size;
    assert_eq!(shifted.addr, expected_addr);
}
/// A descriptor with a version newer than `CURRENT_VERSION` is rejected; the same
/// descriptor with the current version and one correctly sized region is accepted.
#[test]
fn test_version_check_on_deserialization() {
    let config = make_test_config();

    // Size needed by a fully contiguous layout for this config.
    let dims = [
        config.num_blocks,
        config.num_layers,
        config.outer_dim,
        config.page_size,
        config.inner_dim,
        config.dtype_width_bytes,
    ];
    let required_size: usize = dims.iter().product();

    let from_the_future = LayoutDescriptor {
        version: 999, // Future version
        layout_config: config.clone(),
        location: StorageKind::System,
        nixl_metadata: crate::layout::physical::NixlMetadata::new(
            "test".to_string(),
            MemType::Dram,
            0,
        ),
        memory_descriptors: vec![],
        layout_type_details: crate::layout::LayoutTypeDetails::FullyContiguous(
            crate::layout::FullyContiguousDetails {
                block_format: crate::layout::BlockFormat::Operational,
                kv_block_layout: crate::layout::KvBlockLayout::OperationalNHD,
            },
        ),
    };

    // Should fail with unsupported version
    let rejected = PhysicalLayout::from_descriptor(from_the_future.clone());
    assert!(rejected.is_err());
    let message = rejected.unwrap_err().to_string();
    assert!(message.contains("Unsupported serialization version"));

    // Should succeed with supported version
    let supported = LayoutDescriptor {
        version: LayoutDescriptor::CURRENT_VERSION,
        memory_descriptors: vec![MemoryRegion {
            addr: 0x1000,
            size: required_size,
        }],
        ..from_the_future
    };
    let result = PhysicalLayout::from_descriptor(supported);
    if let Err(ref e) = result {
        eprintln!("Error during deserialization: {}", e);
    }
    assert!(
        result.is_ok(),
        "Expected successful deserialization, got error: {:?}",
        result.err()
    );

    let layout = result.unwrap();
    assert_eq!(
        layout.layout().block_layout(),
        crate::layout::KvBlockLayout::OperationalNHD
    );
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tensor validation utilities for layout creation.
use anyhow::{Result, anyhow};
use std::sync::Arc;
use dynamo_memory::TensorDescriptor;
/// Format of tensor layout (for future TP translation).
///
/// The variant names are upper-case acronyms, which is why the clippy
/// `upper_case_acronyms` lint is allowed on this enum.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorFormat {
/// NHD format: [N, H, D] where N=block_size, H=heads, D=hidden
NHD,
/// HND format: [H, N, D] where H=heads, N=block_size, D=hidden
HND,
/// Unknown or ambiguous format
Unknown,
}
/// Validate tensor strides and detect format.
///
/// Every tensor must report at least two strides, and no stride may be larger
/// than the one before it (equal neighbouring strides are accepted).
///
/// For tensors whose shape has three or more dimensions, the first two strides
/// are then compared to guess the format: first > second yields NHD,
/// first < second yields HND. The value from the last such tensor wins.
///
/// NOTE(review): because the ordering check above already rejects any stride
/// that exceeds its predecessor, `stride[0] < stride[1]` can never hold when the
/// format guess runs, so `HND` is currently never returned. Either the ordering
/// check or the heuristic needs to change for HND detection to work.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The detected tensor format (NHD, HND, or Unknown)
///
/// # Errors
/// * `tensors` is empty
/// * a tensor reports fewer than two strides
/// * a stride is larger than the stride preceding it
#[expect(dead_code)]
pub fn validate_tensor_strides(tensors: &[Arc<dyn TensorDescriptor>]) -> Result<TensorFormat> {
    if tensors.is_empty() {
        return Err(anyhow!("Cannot validate empty tensor list"));
    }

    let mut detected = TensorFormat::Unknown;

    for tensor in tensors {
        let stride = tensor.stride();
        let shape = tensor.shape();

        if stride.len() < 2 {
            return Err(anyhow!(
                "Tensor must have at least 2 dimensions, got stride: {:?}",
                stride
            ));
        }

        // Pair each stride with its successor and look for the first place where
        // the successor is larger. The reported position is the successor's index.
        let first_increase = stride
            .iter()
            .zip(stride.iter().skip(1))
            .position(|(outer, inner)| inner > outer);
        if let Some(pair_index) = first_increase {
            return Err(anyhow!(
                "Tensor strides must be monotonically decreasing (until inner dimension). \
                 Got stride: {:?} at position {}",
                stride,
                pair_index + 1
            ));
        }

        // Heuristic NHD/HND guess from the two outermost strides.
        if shape.len() >= 3 {
            match stride[0].cmp(&stride[1]) {
                // Unreachable today: see the NOTE(review) in the function docs.
                std::cmp::Ordering::Less => detected = TensorFormat::HND,
                std::cmp::Ordering::Greater => detected = TensorFormat::NHD,
                std::cmp::Ordering::Equal => {}
            }
        }
    }

    Ok(detected)
}
/// Validate that all tensors have consistent shapes.
///
/// The first tensor's shape is the reference; every other tensor is compared to it.
///
/// # Arguments
/// * `tensors` - Slice of tensors to validate
///
/// # Returns
/// The common shape shared by all tensors
///
/// # Errors
/// * `tensors` is empty
/// * any tensor's shape differs from the first tensor's shape
#[expect(dead_code)]
pub fn validate_tensor_shapes(tensors: &[Arc<dyn TensorDescriptor>]) -> Result<Vec<usize>> {
    let Some((reference, others)) = tensors.split_first() else {
        return Err(anyhow!("Cannot validate empty tensor list"));
    };

    let first_shape = reference.shape();

    // Report the first tensor whose shape disagrees with the reference.
    if let Some(mismatch) = others.iter().find(|other| other.shape() != first_shape) {
        return Err(anyhow!(
            "All tensors must have the same shape. Expected {:?}, got {:?}",
            first_shape,
            mismatch.shape()
        ));
    }

    Ok(first_shape.to_vec())
}
/// Collapse a shape into a single element count by multiplying all dimensions.
///
/// An empty shape yields 1 (the empty product).
///
/// # Arguments
/// * `shape` - Dimension sizes to multiply together
// The allow is kept because the function may have no callers inside the crate.
#[allow(dead_code)]
pub fn determine_compressed_shape(shape: &[usize]) -> usize {
    shape.iter().fold(1, |total, &dim| total * dim)
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
// Intentionally empty placeholder module.
// Note: These tests would require mock TorchTensor implementations
// which we can add if needed for testing infrastructure
// NOTE(review): the validators in this file take `TensorDescriptor` trait
// objects, so a mock of that trait is what a test here would need.
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod layout;
pub mod manager;
pub mod transfer;
pub use manager::TransferManager;
pub use transfer::{TransferConfig, TransferOptions};
pub use kvbm_common::BlockId;
pub type SequenceHash = kvbm_common::SequenceHash;
// Compiled only for test builds that do NOT enable `testing-kvbm`. The single
// test's name and message tell whoever ran `cargo test` how to enable the
// feature-gated functional tests.
#[cfg(all(test, not(feature = "testing-kvbm")))]
mod sentinel {
    /// Always passes; exists only to print the hint below.
    #[test]
    #[allow(non_snake_case)]
    fn all_functional_tests_skipped___enable_testing_kvbm() {
        eprintln!(
            "kvbm-physical functional tests require feature `testing-kvbm`. \
             Run with: cargo test -p kvbm-physical --features testing-kvbm"
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Layout handle type encoding worker ID and layout ID.
use bincode::{Decode, Encode};
use serde::{Deserialize, Serialize};
/// Unique handle for a layout combining worker_id and layout_id.
///
/// The handle encodes:
/// - Bits 0-63: worker_id (u64)
/// - Bits 64-79: layout_id (u16)
/// - Bits 80-127: Reserved (48 bits, currently unused)
///
/// Equality and hashing are derived over the whole `u128`, so the reserved bits
/// take part in comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encode, Decode, Serialize, Deserialize)]
pub struct LayoutHandle(u128);
impl LayoutHandle {
/// Create a new layout handle from worker_id and layout_id.
///
/// # Arguments
/// * `worker_id` - Unique identifier for the worker (0-63 bits)
/// * `layout_id` - Layout identifier within the worker (64-79 bits)
pub fn new(worker_id: u64, layout_id: u16) -> Self {
let handle = (worker_id as u128) | ((layout_id as u128) << 64);
Self(handle)
}
/// Extract the worker_id from this handle.
pub fn worker_id(&self) -> u64 {
(self.0 & 0xFFFF_FFFF_FFFF_FFFF) as u64
}
/// Extract the layout_id from this handle.
pub fn layout_id(&self) -> u16 {
((self.0 >> 64) & 0xFFFF) as u16
}
/// Get the raw u128 value.
pub fn as_u128(&self) -> u128 {
self.0
}
/// Reconstruct a handle from a raw u128 value.
///
/// This preserves all bits including reserved bits, and is intended for
/// deserialization roundtrips with `as_u128()`.
pub fn from_u128(value: u128) -> Self {
Self(value)
}
}
impl std::fmt::Display for LayoutHandle {
    /// Render the handle as `LayoutHandle(worker=<worker_id>, layout=<layout_id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let worker = self.worker_id();
        let layout = self.layout_id();
        f.write_fmt(format_args!(
            "LayoutHandle(worker={}, layout={})",
            worker, layout
        ))
    }
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
    use super::*;

    /// Both fields come back unchanged after being packed into one handle.
    #[test]
    fn test_handle_encoding() {
        let (worker_id, layout_id) = (0x1234_5678_9ABC_DEF0u64, 0x4242u16);
        let handle = LayoutHandle::new(worker_id, layout_id);
        assert_eq!(handle.worker_id(), worker_id);
        assert_eq!(handle.layout_id(), layout_id);
    }

    /// `as_u128` followed by `from_u128` yields an equal handle.
    #[test]
    fn test_handle_roundtrip() {
        let original = LayoutHandle::new(42, 100);
        let restored = LayoutHandle::from_u128(original.as_u128());
        assert_eq!(original, restored);
        assert_eq!(restored.worker_id(), 42);
        assert_eq!(restored.layout_id(), 100);
    }

    /// The largest value of each field is stored without loss.
    #[test]
    fn test_handle_max_values() {
        let handle = LayoutHandle::new(u64::MAX, u16::MAX);
        assert_eq!(handle.worker_id(), u64::MAX);
        assert_eq!(handle.layout_id(), u16::MAX);
    }

    /// Encoding with bincode and decoding again returns an equal handle.
    #[test]
    fn test_handle_bincode_roundtrip() {
        let original = LayoutHandle::new(999, 42);
        let bytes = bincode::encode_to_vec(original, bincode::config::standard()).unwrap();
        let (decoded, _consumed): (LayoutHandle, _) =
            bincode::decode_from_slice(&bytes, bincode::config::standard()).unwrap();
        assert_eq!(original, decoded);
    }

    /// The `Display` output mentions both the worker ID and the layout ID.
    #[test]
    fn test_handle_display() {
        let rendered = LayoutHandle::new(123, 456).to_string();
        for needle in ["123", "456"] {
            assert!(rendered.contains(needle));
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Local layout wrapper with handle and metadata.
use std::ops::Deref;
use super::handle::LayoutHandle;
use crate::layout::PhysicalLayout;
/// A local physical layout with an assigned handle.
///
/// This wraps a `PhysicalLayout` that exists on the local worker,
/// associating it with a unique handle that combines the worker_id
/// and a locally-assigned layout_id.
///
/// This type is cheap to clone as `PhysicalLayout` contains `Arc` internally.
#[derive(Debug, Clone)]
pub struct LocalLayout {
/// Handle identifying this layout (worker_id + layout_id)
handle: LayoutHandle,
/// The wrapped physical layout
layout: PhysicalLayout,
}
#[allow(dead_code)]
impl LocalLayout {
/// Create a new local layout.
///
/// # Arguments
/// * `handle` - Unique handle for this layout
/// * `layout` - The physical layout
pub fn new(handle: LayoutHandle, layout: PhysicalLayout) -> Self {
Self { handle, layout }
}
/// Get the handle for this layout.
pub fn handle(&self) -> LayoutHandle {
self.handle
}
/// Get a reference to the physical layout.
pub fn layout(&self) -> &PhysicalLayout {
&self.layout
}
/// Get the worker_id from the handle.
pub fn worker_id(&self) -> u64 {
self.handle.worker_id()
}
/// Get the layout_id from the handle.
pub fn layout_id(&self) -> u16 {
self.handle.layout_id()
}
/// Consume this local layout and return the physical layout.
pub fn into_layout(self) -> PhysicalLayout {
self.layout
}
}
impl Deref for LocalLayout {
type Target = PhysicalLayout;
fn deref(&self) -> &Self::Target {
&self.layout
}
}
// Functional tests for `LocalLayout`. They are compiled only for test builds
// that enable the `testing-kvbm` feature.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::*;
use crate::layout::{LayoutConfig, PhysicalLayout};
use dynamo_memory::nixl::NixlAgent;
/// Create a NIXL agent with the given name, panicking if creation fails.
fn create_test_agent(name: &str) -> NixlAgent {
NixlAgent::new(name).expect("failed to create agent")
}
/// Build a small fully-contiguous layout for the tests below
/// (2 blocks, 2 layers, outer_dim 2, page_size 4, inner_dim 8, 2-byte dtype),
/// allocated through the builder's `allocate_system()` step.
fn make_test_layout() -> PhysicalLayout {
let agent = create_test_agent("test-local");
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(4)
.inner_dim(8)
.dtype_width_bytes(2)
.build()
.unwrap();
PhysicalLayout::builder(agent)
.with_config(config)
.fully_contiguous()
.allocate_system()
.build()
.unwrap()
}
/// The wrapper reports the handle it was built with, and the worker and
/// layout IDs encoded in that handle.
#[test]
fn test_local_layout_creation() {
let handle = LayoutHandle::new(42, 100);
let layout = make_test_layout();
let local = LocalLayout::new(handle, layout);
assert_eq!(local.handle(), handle);
assert_eq!(local.worker_id(), 42);
assert_eq!(local.layout_id(), 100);
}
/// `into_layout` consumes the wrapper; the test only checks that this
/// compiles and runs, it does not inspect the returned layout.
#[test]
fn test_local_layout_into_layout() {
let handle = LayoutHandle::new(1, 2);
let layout = make_test_layout();
let local = LocalLayout::new(handle, layout);
let _recovered = local.into_layout();
// Successfully consumed and returned the layout
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Serialization types for exporting/importing layout metadata with NIXL integration.
use super::handle::LayoutHandle;
use crate::layout::LayoutDescriptor;
use anyhow::Result;
use bincode::{Decode, Encode};
use serde::{Deserialize, Serialize};
use kvbm_common::LogicalLayoutHandle;
/// Worker identification combining worker_id and NIXL agent name.
///
/// NOTE(review): this type derives bincode `Encode`/`Decode`; the derived
/// encoding presumably follows field declaration order, so confirm wire
/// compatibility before reordering or adding fields.
#[derive(Debug, Clone, Encode, Decode, PartialEq, Eq)]
pub struct WorkerAddress {
/// Unique identifier for this worker
pub worker_id: u64,
/// NIXL agent name on this worker
pub nixl_agent_name: String,
}
impl WorkerAddress {
/// Create a new worker address.
pub fn new(worker_id: u64, nixl_agent_name: String) -> Self {
Self {
worker_id,
nixl_agent_name,
}
}
}
/// Layout descriptor with its assigned handle and logical type for RDMA metadata exchange.
///
/// This includes the logical layout type (G1, G2, G3, G4) so that remote instances
/// know which physical handle corresponds to which tier.
///
/// `handle` is encoded with its own bincode implementation, while the two
/// fields marked `#[bincode(with_serde)]` are encoded through their serde
/// implementations.
#[derive(Debug, Clone, Encode, Decode)]
pub struct LogicalLayoutDescriptor {
/// Unique handle for this layout
pub handle: LayoutHandle,
/// The logical layout type (G1, G2, G3, G4)
#[bincode(with_serde)]
pub logical_type: LogicalLayoutHandle,
/// Serialized layout data (uses Serde, bridged via bincode)
#[bincode(with_serde)]
pub layout: LayoutDescriptor,
}
impl LogicalLayoutDescriptor {
/// Create a new layout descriptor with handle and logical type.
pub fn new(
handle: LayoutHandle,
logical_type: LogicalLayoutHandle,
layout: LayoutDescriptor,
) -> Self {
Self {
handle,
logical_type,
layout,
}
}
/// Create a layout descriptor with G2 as the default logical type.
///
/// This is provided for backwards compatibility with code that doesn't
/// track logical types. G2 is used as the default since it's the most
/// common tier for RDMA transfers (GPU memory for KV cache).
///
/// For proper RDMA transfers between instances, use `new()` with the
/// correct logical type from the Worker's registered handles.
pub fn new_with_default_type(handle: LayoutHandle, layout: LayoutDescriptor) -> Self {
Self {
handle,
logical_type: LogicalLayoutHandle::G2,
layout,
}
}
}
/// Type alias for backwards compatibility.
///
/// `LocalLayoutDescriptor` and [`LogicalLayoutDescriptor`] are the same type;
/// the alias only provides a second name for it.
pub type LocalLayoutDescriptor = LogicalLayoutDescriptor;
/// The set of [`LogicalLayoutDescriptor`] that are RDMA enabled. This object packages the detail
/// about the layouts and the NIXL RDMA metadata required to reconstruct the layouts and access
/// the memory via NIXL RDMA.
///
/// This is the structure that `SerializedLayout::pack` encodes into bytes and
/// that `SerializedLayout::unpack` decodes back.
#[derive(Debug, Encode, Decode)]
pub struct RdmaLayoutDescriptors {
/// Worker identification
pub worker_address: WorkerAddress,
/// Exported NIXL metadata from nixl_sys::Agent::get_local_md()
pub nixl_metadata: Vec<u8>,
/// Serialized layouts (handle + logical type + layout data)
pub layouts: Vec<LogicalLayoutDescriptor>,
}
/// Managed memory metadata package for export/import.
///
/// This is the wire format for transmitting layout metadata between workers.
/// It contains everything needed to reconstruct remote layouts and load their
/// NIXL registration data.
///
/// Internally this is a newtype over the bytes produced by `pack`, i.e. the
/// bincode encoding of an `RdmaLayoutDescriptors`. `#[serde(transparent)]`
/// makes serde treat the wrapper as the inner byte vector.
#[derive(Clone, Serialize, Deserialize, Encode, Decode)]
#[serde(transparent)]
pub struct SerializedLayout(Vec<u8>);
impl SerializedLayout {
/// Pack metadata into a serialized form.
///
/// # Arguments
/// * `worker_address` - Worker identification
/// * `nixl_metadata` - NIXL metadata blob from get_local_md()
/// * `layouts` - Vector of layouts with handles and logical types to export
///
/// # Returns
/// Packed metadata ready for transmission
pub fn pack(
worker_address: WorkerAddress,
nixl_metadata: Vec<u8>,
layouts: Vec<LogicalLayoutDescriptor>,
) -> Result<Self> {
let inner = RdmaLayoutDescriptors {
worker_address,
nixl_metadata,
layouts,
};
let bytes = bincode::encode_to_vec(&inner, bincode::config::standard())
.map_err(|e| anyhow::anyhow!("failed to encode managed memory metadata: {}", e))?;
Ok(Self(bytes))
}
/// Unpack metadata from serialized form.
///
/// # Returns
/// Unpacked metadata structure
pub fn unpack(&self) -> Result<RdmaLayoutDescriptors> {
let (inner, _) = bincode::decode_from_slice(&self.0, bincode::config::standard())
.map_err(|e| anyhow::anyhow!("failed to decode managed memory metadata: {}", e))?;
Ok(inner)
}
/// Get the raw bytes.
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Create from raw bytes.
pub fn from_bytes(bytes: Vec<u8>) -> Self {
Self(bytes)
}
/// Get the size in bytes.
pub fn len(&self) -> usize {
self.0.len()
}
/// Check if empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}
impl std::fmt::Debug for SerializedLayout {
    /// Print only the size of the payload, never the payload bytes themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let size_bytes = self.len();
        let mut out = f.debug_struct("SerializedLayout");
        out.field("size_bytes", &size_bytes);
        out.finish()
    }
}
// Functional tests for the metadata wire types. They are compiled only for
// test builds that enable the `testing-kvbm` feature.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::*;
use crate::layout::{
BlockFormat, FullyContiguousDetails, KvBlockLayout, LayoutConfig, LayoutDescriptor,
LayoutTypeDetails, NixlMetadata,
};
use dynamo_memory::{MemoryRegion, StorageKind, nixl};
use kvbm_common::LogicalLayoutHandle;
/// Build a hand-written `LayoutDescriptor` for the tests below. No memory is
/// allocated: the single memory region uses a made-up address and size.
fn make_test_serialized_layout() -> LayoutDescriptor {
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(4)
.inner_dim(8)
.dtype_width_bytes(2)
.build()
.unwrap();
LayoutDescriptor {
version: 1,
layout_config: config,
location: StorageKind::System,
nixl_metadata: NixlMetadata::new("test".to_string(), nixl::MemType::Dram, 0),
memory_descriptors: vec![MemoryRegion {
addr: 0x1000,
size: 4096,
}],
layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
block_format: BlockFormat::Operational,
kv_block_layout: KvBlockLayout::OperationalNHD,
}),
}
}
/// `WorkerAddress::new` stores both arguments in the public fields.
#[test]
fn test_worker_address() {
let addr = WorkerAddress::new(42, "test_agent".to_string());
assert_eq!(addr.worker_id, 42);
assert_eq!(addr.nixl_agent_name, "test_agent");
}
/// `LogicalLayoutDescriptor::new` keeps the handle and logical type it was given.
#[test]
fn test_serialized_layout_with_handle() {
let handle = LayoutHandle::new(1, 2);
let layout = make_test_serialized_layout();
let with_handle = LogicalLayoutDescriptor::new(handle, LogicalLayoutHandle::G2, layout);
assert_eq!(with_handle.handle, handle);
assert_eq!(with_handle.logical_type, LogicalLayoutHandle::G2);
}
/// A package with one layout survives `pack` followed by `unpack`.
#[test]
fn test_metadata_pack_unpack() {
let worker_address = WorkerAddress::new(100, "worker_100".to_string());
let nixl_metadata = vec![1, 2, 3, 4, 5];
let layouts = vec![LogicalLayoutDescriptor::new(
LayoutHandle::new(100, 1),
LogicalLayoutHandle::G2,
make_test_serialized_layout(),
)];
let packed =
SerializedLayout::pack(worker_address.clone(), nixl_metadata.clone(), layouts).unwrap();
assert!(!packed.is_empty());
let unpacked = packed.unpack().unwrap();
assert_eq!(unpacked.worker_address, worker_address);
assert_eq!(unpacked.nixl_metadata, nixl_metadata);
assert_eq!(unpacked.layouts.len(), 1);
assert_eq!(unpacked.layouts[0].handle.worker_id(), 100);
assert_eq!(unpacked.layouts[0].handle.layout_id(), 1);
assert_eq!(unpacked.layouts[0].logical_type, LogicalLayoutHandle::G2);
}
/// Three layouts with different logical types come back in the same order
/// and with the same handles and types after a pack/unpack roundtrip.
#[test]
fn test_metadata_multiple_layouts() {
let worker_address = WorkerAddress::new(200, "worker_200".to_string());
let nixl_metadata = vec![10, 20, 30];
let layouts = vec![
LogicalLayoutDescriptor::new(
LayoutHandle::new(200, 1),
LogicalLayoutHandle::G1,
make_test_serialized_layout(),
),
LogicalLayoutDescriptor::new(
LayoutHandle::new(200, 2),
LogicalLayoutHandle::G2,
make_test_serialized_layout(),
),
LogicalLayoutDescriptor::new(
LayoutHandle::new(200, 3),
LogicalLayoutHandle::G3,
make_test_serialized_layout(),
),
];
let packed =
SerializedLayout::pack(worker_address, nixl_metadata, layouts.clone()).unwrap();
let unpacked = packed.unpack().unwrap();
assert_eq!(unpacked.layouts.len(), 3);
let expected_logical_types = [
LogicalLayoutHandle::G1,
LogicalLayoutHandle::G2,
LogicalLayoutHandle::G3,
];
// Layout IDs were assigned as 1, 2, 3, so position i must carry ID i + 1.
for (i, layout) in unpacked.layouts.iter().enumerate() {
assert_eq!(layout.handle.worker_id(), 200);
assert_eq!(layout.handle.layout_id(), (i + 1) as u16);
assert_eq!(layout.logical_type, expected_logical_types[i]);
}
}
/// Bytes taken with `as_bytes` can be re-wrapped with `from_bytes` and
/// unpacked again; the package here carries no layouts.
#[test]
fn test_metadata_from_bytes() {
let worker_address = WorkerAddress::new(42, "test".to_string());
let nixl_metadata = vec![1, 2, 3];
let layouts = vec![];
let packed = SerializedLayout::pack(worker_address, nixl_metadata, layouts).unwrap();
let bytes = packed.as_bytes().to_vec();
let restored = SerializedLayout::from_bytes(bytes);
let unpacked = restored.unpack().unwrap();
assert_eq!(unpacked.worker_address.worker_id, 42);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transport manager for local and remote physical layouts with transfer execution.
mod handle;
mod local;
mod metadata;
mod remote;
pub use handle::LayoutHandle;
pub use metadata::{LogicalLayoutDescriptor, SerializedLayout, WorkerAddress};
pub(crate) use local::LocalLayout;
pub(crate) use metadata::LocalLayoutDescriptor;
pub(crate) use remote::RemoteLayout;
use crate::layout::PhysicalLayout;
use crate::transfer::BounceBufferInternal;
use crate::transfer::TransferContext;
use crate::transfer::context::TransferCompleteNotification;
use crate::transfer::executor::TransferOptionsInternal;
use crate::transfer::options::TransferOptions;
use crate::{BlockId, SequenceHash};
use anyhow::{Result, anyhow, bail};
use dynamo_memory::StorageKind;
use dynamo_memory::nixl::NixlAgent;
use kvbm_common::LogicalLayoutHandle;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Arc, RwLock};
/// Public entry point for layout and transfer management.
///
/// TransferManager combines layout registration/metadata management with
/// transfer execution capabilities, providing a unified API for:
/// - Registering local layouts and obtaining handles
/// - Exporting/importing layout metadata for remote workers
/// - Executing transfers between layouts using handles
/// - Managing CUDA, NIXL, and other execution resources
///
/// Cloning is shallow: both fields are `Arc`s, so every clone shares the same
/// registry and the same transfer context.
#[derive(Clone)]
pub struct TransferManager {
/// Layout registry shared between clones, guarded by a read/write lock
registry: Arc<RwLock<LayoutRegistry>>,
/// Transfer execution context shared between clones
context: Arc<TransferContext>,
}
impl TransferManager {
/// Create a new TransferManager builder.
///
/// The builder configures the worker ID, NIXL agent, CUDA device,
/// and other execution parameters before creating the manager.
///
/// # Returns
/// The builder returned by `TransferContext::builder()`.
///
/// # Example
/// ```ignore
/// let manager = TransferManager::builder()
/// .worker_id(0) // NIXL agent name defaults to "worker-0"
/// .nixl_backend("ucx") // Optional: defaults to UCX from env
/// .cuda_device_id(0)
/// .build()?;
///
/// // Or with custom agent name:
/// let manager = TransferManager::builder()
/// .worker_id(0)
/// .nixl_agent_name("custom-agent")
/// .build()?;
/// ```
pub fn builder() -> crate::transfer::context::TransferConfigBuilder {
TransferContext::builder()
}
/// Create a TransferManager from a built TransferContext.
///
/// This is used internally by the builder to wrap the context
/// and create the associated registry.
pub(crate) fn from_context(context: TransferContext) -> Self {
let worker_id = context.worker_id();
let nixl_agent = context.nixl_agent().clone();
let registry = Arc::new(RwLock::new(LayoutRegistry::new(nixl_agent, worker_id)));
Self {
registry,
context: Arc::new(context),
}
}
// ===== Layout Registration and Metadata Management =====
/// Register a local physical layout and return a unique handle.
///
/// This registers the layout with the embedded layout registry, assigning
/// it a unique handle that can be used for handle-based transfers. The
/// registry's write lock is held for the duration of the call.
///
/// # Arguments
/// * `layout` - Physical layout to register
///
/// # Returns
/// Unique handle for the registered layout
///
/// # Errors
/// Returns an error if layout IDs are exhausted (u16::MAX reached)
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn register_layout(&self, layout: PhysicalLayout) -> Result<LayoutHandle> {
self.registry.write().unwrap().register_local(layout)
}
/// Export layout metadata for transmission to remote workers.
///
/// This delegates to the registry, which exports the NIXL metadata needed
/// for remote memory registration together with the registered local
/// layouts whose storage kind is `System`, `Device` or `Pinned`; layouts of
/// any other storage kind are left out. Every exported layout is tagged
/// with the default logical type (`G2`).
///
/// # Returns
/// Packed metadata ready for transmission to remote workers
///
/// # Errors
/// Returns an error if the NIXL metadata cannot be obtained, a layout
/// cannot be converted to a descriptor, or packing fails.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn export_metadata(&self) -> Result<SerializedLayout> {
self.registry.read().unwrap().export_metadata()
}
/// Import remote layout metadata.
///
/// This loads NIXL metadata and reconstructs physical layouts from a remote
/// worker's exported metadata. The registry's write lock is held for the
/// duration of the call.
///
/// # Arguments
/// * `metadata` - Packed metadata from remote worker
///
/// # Returns
/// Vector of handles for the imported remote layouts
///
/// # Errors
/// Returns an error if the remote worker was already loaded or if metadata
/// loading/reconstruction fails
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn import_metadata(&self, metadata: SerializedLayout) -> Result<Vec<LayoutHandle>> {
self.registry.write().unwrap().import_metadata(metadata)
}
/// Build a logical layout descriptor for a specific handle.
///
/// This creates a descriptor that includes the logical layout type (G1, G2, G3, G4)
/// for use in RDMA metadata exchange. The caller must provide the logical type
/// mapping since only the caller (e.g., DirectWorker) knows which handle corresponds
/// to which logical tier.
///
/// Only locally registered layouts are looked up; a handle that refers to
/// an imported remote layout is reported as not found.
///
/// # Arguments
/// * `handle` - Handle to the local layout
/// * `logical_type` - The logical tier (G1, G2, G3, G4) this handle represents
///
/// # Returns
/// A LogicalLayoutDescriptor ready for serialization
///
/// # Errors
/// Returns an error if the handle is not found or serialization fails
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn build_logical_descriptor(
&self,
handle: LayoutHandle,
logical_type: LogicalLayoutHandle,
) -> Result<LogicalLayoutDescriptor> {
self.registry
.read()
.unwrap()
.build_logical_descriptor(handle, logical_type)
}
/// Get the NIXL metadata for this worker.
///
/// Returns the raw NIXL metadata bytes needed for remote registration.
///
/// # Errors
/// Returns an error if the NIXL agent fails to produce its local metadata.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_nixl_metadata(&self) -> Result<Vec<u8>> {
self.registry.read().unwrap().get_nixl_metadata()
}
/// Get the worker address for this manager.
///
/// The address is built by the registry from its worker ID and the name of
/// its NIXL agent.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn worker_address(&self) -> WorkerAddress {
self.registry.read().unwrap().worker_address()
}
/// Get a reference to the NIXL agent.
///
/// This is useful for building layouts that need to register memory
/// with the same agent that the TransferManager uses. The reference is
/// obtained from the transfer context; the registry lock is not taken.
pub fn nixl_agent(&self) -> &NixlAgent {
self.context.nixl_agent()
}
/// Look up a registered layout and return a copy of its configuration.
///
/// The configuration carries dimensions such as num_blocks, num_layers,
/// page_size, etc.
///
/// # Arguments
/// * `handle` - Handle to a registered layout (local or remote)
///
/// # Returns
/// A clone of the layout's configuration
///
/// # Errors
/// Returns an error if the handle is not found
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_layout_config(&self, handle: LayoutHandle) -> Result<crate::layout::LayoutConfig> {
    let guard = self.registry.read().unwrap();
    match guard.get_layout(handle) {
        Some(physical) => Ok(physical.layout().config().clone()),
        None => Err(anyhow!("invalid handle: {}", handle)),
    }
}
// ===== Handle-Based Transfer API =====
/// Transfer complete blocks between layouts using handles.
///
/// This function copies entire blocks (all layers and outer dimensions) between
/// the source and destination layouts identified by their handles. The transfer
/// strategy (memcpy, CUDA, NIXL) is automatically selected based on storage locations.
///
/// The lock on the registry is held only briefly during layout lookup,
/// then released before executing the actual transfer.
///
/// # Arguments
/// * `src_handle` - Handle to source layout
/// * `src_blocks` - Source block IDs to transfer
/// * `dst_handle` - Handle to destination layout
/// * `dst_blocks` - Destination block IDs to transfer
/// * `options` - Optional settings (layer range, NIXL write notification,
///   bounce buffer, CUDA stream, source/destination KV layouts); only the
///   settings that are present are forwarded to the executor
///
/// # Returns
/// A notification handle that can be awaited for transfer completion
///
/// # Errors
/// Returns an error if:
/// - Either handle is invalid
/// - The bounce buffer handle given in `options` is invalid
/// - The internal transfer options fail to build
/// - Block IDs are out of bounds
/// - Transfer execution fails
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn execute_transfer(
&self,
src_handle: LayoutHandle,
src_blocks: &[BlockId],
dst_handle: LayoutHandle,
dst_blocks: &[BlockId],
options: TransferOptions,
) -> Result<TransferCompleteNotification> {
// Clone layouts inside the lock, then drop lock before transfer
let (src_layout, dst_layout) = {
let registry = self.registry.read().unwrap();
let src = registry
.get_layout(src_handle)
.ok_or_else(|| anyhow!("invalid source handle: {}", src_handle))?
.clone(); // Cheap: just Arc refcount bump
let dst = registry
.get_layout(dst_handle)
.ok_or_else(|| anyhow!("invalid destination handle: {}", dst_handle))?
.clone();
(src, dst)
}; // Lock released here
// Split the public options into their individual optional parts
let (
layer_range,
nixl_write_notification,
bounce_buffer,
cuda_stream,
src_kv_layout,
dst_kv_layout,
) = options.dissolve();
// Copy each part that is present onto the internal options builder
let mut internal_options = TransferOptionsInternal::builder();
if let Some(range) = layer_range {
internal_options = internal_options.layer_range(range);
}
if let Some(notification) = nixl_write_notification {
internal_options = internal_options.nixl_write_notification(notification);
}
if let Some(bounce) = bounce_buffer {
// Resolving the bounce buffer takes the registry read lock again, briefly
let (handle, block_ids) = bounce.into_parts();
let bounce_buffer = self.create_bounce_buffer(handle, block_ids)?;
internal_options = internal_options.bounce_buffer(bounce_buffer);
}
if let Some(stream) = cuda_stream {
internal_options = internal_options.cuda_stream(stream);
}
if let Some(layout) = src_kv_layout {
internal_options = internal_options.src_kv_layout(layout);
}
if let Some(layout) = dst_kv_layout {
internal_options = internal_options.dst_kv_layout(layout);
}
let options = internal_options.build()?;
tracing::debug!(
src_handle = src_handle.to_string(),
dst_handle = dst_handle.to_string(),
"Executing transfer; src_blocks = {:?}; dst_blocks = {:?}",
src_blocks,
dst_blocks,
);
// Execute transfer with no lock held
super::transfer::executor::execute_transfer(
&src_layout,
&dst_layout,
src_blocks,
dst_blocks,
options,
&self.context,
)
}
/// Execute a G4 offload (not implemented yet).
///
/// Takes a LayoutHandle and a slice of block IDs for the source blocks and
/// a slice of SequenceHashes for the destination objects.
///
/// Planned design: use an extension on TransferOptions to pass in the
/// "rank/part" of the object in a multi-worker/multi-tp scenario.
///
/// Note that this is an associated function: it has no `self` receiver.
///
/// # Panics
/// Always panics via `todo!`; the body is a stub.
pub fn execute_g4_offload(
_src_handle: LayoutHandle,
_src_blocks: &[BlockId],
_dst_object: &[SequenceHash],
_options: TransferOptions, // add rank/part to the options
) -> Result<TransferCompleteNotification> {
// Implementation outline:
// check registration cache for the remote object, if it's not found, register it with nixl
// register all non-registered blocks with nixl in parallel
// then extend super::transfer::executor to access the memory regions for the source
// and generate a nixl descriptor
todo!("implement remote offload")
}
/// Execute a G4 onboard (not implemented yet).
///
/// This is a placeholder: it takes no arguments and returns nothing.
///
/// # Panics
/// Always panics via `todo!`; the body is a stub.
pub fn execute_g4_onboard() {
todo!("implement remote onboard")
}
// ===== Query Methods =====
/// Get the worker ID for this manager.
///
/// The value is read from the transfer context; the registry lock is not taken.
pub fn worker_id(&self) -> u64 {
self.context.worker_id()
}
/// Get handles for all locally registered layouts.
///
/// The handles are collected from the keys of the registry's local-layout
/// `HashMap`, so their order is unspecified.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_local_handles(&self) -> Vec<LayoutHandle> {
self.registry.read().unwrap().local_handles()
}
/// Get handles for all imported remote layouts.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_remote_handles(&self) -> Vec<LayoutHandle> {
self.registry.read().unwrap().remote_handles()
}
/// Get a clone of the physical layout for a given handle.
///
/// The registry read lock is held only while the layout is looked up and
/// cloned.
///
/// # Arguments
/// * `handle` - Handle to a registered layout (local or remote)
///
/// # Returns
/// A clone of the physical layout, or None if the handle is not found.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub fn get_physical_layout(&self, handle: LayoutHandle) -> Option<PhysicalLayout> {
self.registry.read().unwrap().get_layout(handle).cloned()
}
/// Resolve a bounce-buffer handle into an internal bounce buffer.
///
/// The layout registered under `handle` (local or remote) is cloned while
/// the registry read lock is held; the lock is released before the clone
/// and `block_ids` are passed to `BounceBufferInternal::from_layout`.
///
/// # Errors
/// Returns an error if `handle` is not registered.
///
/// # Panics
/// Panics if the registry lock is poisoned.
pub(crate) fn create_bounce_buffer(
    &self,
    handle: LayoutHandle,
    block_ids: Vec<BlockId>,
) -> Result<BounceBufferInternal> {
    // The read guard is a temporary of this statement and is dropped at its end.
    let lookup = self.registry.read().unwrap().get_layout(handle).cloned();
    let layout = lookup.ok_or_else(|| anyhow!("invalid bounce buffer handle: {}", handle))?;
    Ok(BounceBufferInternal::from_layout(layout, block_ids))
}
// ===== Internal Methods for Testing =====
/// Get the internal transfer context.
///
/// Hidden from the generated documentation; the surrounding section
/// comment marks it as an internal method for testing.
#[doc(hidden)]
pub fn context(&self) -> &TransferContext {
&self.context
}
/// Get access to the internal layout registry.
///
/// This is primarily for testing utilities that need direct layout access
/// (e.g., fill patterns, checksum computation).
///
/// The lock itself is returned, not a guard: callers decide whether to take
/// it for reading or writing.
#[doc(hidden)]
pub fn registry(&self) -> &RwLock<LayoutRegistry> {
&self.registry
}
/// Get the H2D stream (for testing only).
///
/// Forwards to the transfer context's `h2d_stream()`. Compiled only under
/// `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn h2d_stream(&self) -> &std::sync::Arc<cudarc::driver::CudaStream> {
self.context.h2d_stream()
}
/// Get the D2H stream (for testing only).
///
/// Forwards to the transfer context's `d2h_stream()`. Compiled only under
/// `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn d2h_stream(&self) -> &std::sync::Arc<cudarc::driver::CudaStream> {
self.context.d2h_stream()
}
/// Get the CUDA context (for testing only).
///
/// Forwards to the transfer context's `cuda_context()`. Compiled only under
/// `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn cuda_context(&self) -> &std::sync::Arc<cudarc::driver::CudaContext> {
self.context.cuda_context()
}
/// Register a CUDA event for completion (for testing only).
///
/// Forwards the event to the transfer context's `register_cuda_event()` and
/// returns the notification it produces. Compiled only under `cfg(test)`.
#[cfg(test)]
#[allow(dead_code)]
pub(crate) fn register_cuda_event(
&self,
event: cudarc::driver::CudaEvent,
) -> TransferCompleteNotification {
self.context.register_cuda_event(event)
}
/// Get the CUDA memory pool (for testing only).
///
/// Forwards to the transfer context's `cuda_pool()`. Compiled only under
/// `cfg(test)`.
// NOTE(review): this accessor uses `#[expect(dead_code)]` while the sibling
// test-only accessors use `#[allow(dead_code)]` — confirm the difference is intended.
#[cfg(test)]
#[expect(dead_code)]
pub(crate) fn cuda_pool(&self) -> &std::sync::Arc<dynamo_memory::CudaMemPool> {
self.context.cuda_pool()
}
}
/// Internal registry for local and remote physical layouts with NIXL integration.
///
/// The LayoutRegistry handles:
/// - Registering local layouts with unique handles
/// - Exporting local layout metadata for remote access
/// - Importing remote layout metadata and reconstructing layouts
/// - Managing NIXL metadata for RDMA operations
///
/// The type is `pub` but hidden from the generated documentation;
/// `TransferManager` keeps one instance behind an `RwLock`.
#[derive(Debug)]
#[doc(hidden)]
pub struct LayoutRegistry {
/// NIXL agent for memory registration
nixl_agent: NixlAgent,
/// Worker ID for this manager
worker_id: u64,
/// Next layout ID to assign (monotonically increasing); `register_local`
/// refuses to assign an ID once this counter has reached `u16::MAX`
next_layout_id: AtomicU16,
/// Local layouts registered on this worker
local_layouts: HashMap<LayoutHandle, LocalLayout>,
/// Remote layouts imported from other workers
remote_layouts: HashMap<LayoutHandle, RemoteLayout>,
/// Set of loaded remote workers (agent_name, worker_id) to prevent duplicates
loaded_remotes: HashSet<(String, u64)>,
}
#[expect(dead_code)]
impl LayoutRegistry {
/// Create a new layout manager.
///
/// # Arguments
/// * `nixl_agent` - NIXL agent for memory registration
/// * `worker_id` - Unique identifier for this worker
pub(crate) fn new(nixl_agent: NixlAgent, worker_id: u64) -> Self {
Self {
nixl_agent,
worker_id,
next_layout_id: AtomicU16::new(0),
local_layouts: HashMap::new(),
remote_layouts: HashMap::new(),
loaded_remotes: HashSet::new(),
}
}
/// Register a local physical layout.
///
/// # Arguments
/// * `layout` - Physical layout to register
///
/// # Returns
/// Unique handle for the registered layout
///
/// # Errors
/// Returns an error if layout IDs are exhausted (u16::MAX reached)
pub(crate) fn register_local(&mut self, layout: PhysicalLayout) -> Result<LayoutHandle> {
// Check before incrementing to prevent wrapping
let current = self.next_layout_id.load(Ordering::SeqCst);
if current == u16::MAX {
bail!(
"Layout ID overflow: maximum number of layouts ({}) reached",
u16::MAX
);
}
let layout_id = self.next_layout_id.fetch_add(1, Ordering::SeqCst);
// Create handle
let handle = LayoutHandle::new(self.worker_id, layout_id);
// Wrap in LocalLayout
let local_layout = LocalLayout::new(handle, layout);
// Store
self.local_layouts.insert(handle, local_layout);
Ok(handle)
}
/// Export local layout metadata for transmission to remote workers.
///
/// The exported package contains:
/// - the NIXL agent's local metadata, needed for remote memory registration,
/// - this worker's address,
/// - one descriptor for every local layout whose storage kind is `System`,
///   `Device` or `Pinned` (layouts of any other storage kind are skipped),
///   each tagged with the default logical type.
///
/// Layouts are emitted in ascending order of their raw handle value. All
/// local handles share this registry's worker ID, so this is the order of
/// their layout IDs. Sorting keeps the payload, and the handle order the
/// importing side gets back, independent of `HashMap` iteration order.
///
/// # Returns
/// Packed metadata ready for transmission
///
/// # Errors
/// Returns an error if the NIXL metadata cannot be obtained, a layout
/// cannot be converted to a descriptor, or packing fails.
pub(crate) fn export_metadata(&self) -> Result<SerializedLayout> {
    // Get NIXL metadata from agent
    let nixl_metadata = self
        .nixl_agent
        .get_local_md()
        .map_err(|e| anyhow!("failed to get NIXL local metadata: {:?}", e))?;

    // Select the host and device layouts, then fix their order.
    let mut exportable: Vec<_> = self
        .local_layouts
        .iter()
        .filter(|(_, local_layout)| {
            matches!(
                local_layout.layout().location(),
                StorageKind::System | StorageKind::Device(_) | StorageKind::Pinned
            )
        })
        .collect();
    exportable.sort_unstable_by_key(|(handle, _)| handle.as_u128());

    // Serialize each selected layout; the first failure aborts the export.
    let serialized_layouts = exportable
        .into_iter()
        .map(|(handle, local_layout)| {
            local_layout
                .layout()
                .to_descriptor()
                .map(|descriptor| {
                    LocalLayoutDescriptor::new_with_default_type(*handle, descriptor)
                })
                .map_err(|e| anyhow!("failed to serialize layout {}: {}", handle, e))
        })
        .collect::<Result<Vec<_>>>()?;

    // Pack into managed metadata
    SerializedLayout::pack(self.worker_address(), nixl_metadata, serialized_layouts)
}
/// Import remote layout metadata.
///
/// This:
/// - Validates the remote worker hasn't been loaded already
/// - Loads NIXL metadata into the agent
/// - Reconstructs physical layouts from serialized data
/// - Stores them as remote layouts
///
/// The registry's maps are updated only after *every* layout in the package
/// has been reconstructed. If reconstruction fails part-way through, no
/// layout from the package is stored and the remote worker is not marked as
/// loaded. The NIXL metadata already loaded into the agent is not unloaded
/// in that case.
///
/// # Arguments
/// * `metadata` - Packed metadata from remote worker
///
/// # Returns
/// Vector of handles for the imported layouts, in the order they appear in
/// the package
///
/// # Errors
/// Returns an error if:
/// - The metadata cannot be unpacked
/// - The remote worker was already loaded
/// - NIXL metadata loading fails
/// - Agent name mismatch after loading
/// - Layout reconstruction fails
pub(crate) fn import_metadata(
    &mut self,
    metadata: SerializedLayout,
) -> Result<Vec<LayoutHandle>> {
    // Unpack metadata
    let inner = metadata.unpack()?;

    // Validate not already loaded
    let remote_key = (
        inner.worker_address.nixl_agent_name.clone(),
        inner.worker_address.worker_id,
    );
    if self.loaded_remotes.contains(&remote_key) {
        bail!(
            "Remote worker already loaded: {} (worker_id={})",
            remote_key.0,
            remote_key.1
        );
    }

    // Load NIXL metadata
    let returned_agent_name = self
        .nixl_agent
        .load_remote_md(&inner.nixl_metadata)
        .map_err(|e| anyhow!("failed to load remote NIXL metadata: {:?}", e))?;

    // Verify agent name matches
    if returned_agent_name != inner.worker_address.nixl_agent_name {
        bail!(
            "Agent name mismatch: expected '{}', got '{}'",
            inner.worker_address.nixl_agent_name,
            returned_agent_name
        );
    }

    // Reconstruct every layout first, without touching the registry, so
    // that a failure cannot leave a partially imported package behind.
    let reconstructed = inner
        .layouts
        .into_iter()
        .map(|descriptor| {
            let handle = descriptor.handle;
            PhysicalLayout::from_descriptor(descriptor.layout)
                .map(|layout| (handle, RemoteLayout::new(handle, layout)))
                .map_err(|e| anyhow!("failed to reconstruct layout {}: {}", handle, e))
        })
        .collect::<Result<Vec<_>>>()?;

    // Commit: nothing below can fail.
    let imported_handles: Vec<LayoutHandle> =
        reconstructed.iter().map(|(handle, _)| *handle).collect();
    self.remote_layouts.extend(reconstructed);

    // Mark remote as loaded
    self.loaded_remotes.insert(remote_key);
    Ok(imported_handles)
}
/// Produce the [`LogicalLayoutDescriptor`] for one locally registered layout.
///
/// # Arguments
/// * `handle` - Handle of a layout held in this registry's local layouts
/// * `logical_type` - The logical tier (G1, G2, G3, G4) the handle represents
///
/// # Errors
/// Fails when `handle` is not a registered local layout, or when the layout
/// cannot be converted to its serializable descriptor.
pub(crate) fn build_logical_descriptor(
    &self,
    handle: LayoutHandle,
    logical_type: LogicalLayoutHandle,
) -> Result<LogicalLayoutDescriptor> {
    let Some(local_layout) = self.local_layouts.get(&handle) else {
        return Err(anyhow!("Layout handle not found: {:?}", handle));
    };

    let layout_descriptor = match local_layout.layout().to_descriptor() {
        Ok(descriptor) => descriptor,
        Err(e) => return Err(anyhow!("failed to serialize layout {}: {}", handle, e)),
    };

    Ok(LogicalLayoutDescriptor::new(
        handle,
        logical_type,
        layout_descriptor,
    ))
}
/// Fetch the local NIXL agent's metadata blob.
///
/// # Errors
/// Returns an error (carrying the debug rendering of the agent's error) when
/// the agent cannot produce its local metadata.
pub(crate) fn get_nixl_metadata(&self) -> Result<Vec<u8>> {
    match self.nixl_agent.get_local_md() {
        Ok(local_md) => Ok(local_md),
        Err(e) => Err(anyhow!("failed to get NIXL local metadata: {:?}", e)),
    }
}
/// Get the worker address for this registry.
///
/// The address is built fresh on every call from this registry's worker ID
/// and the name of its NIXL agent.
pub(crate) fn worker_address(&self) -> WorkerAddress {
WorkerAddress::new(self.worker_id, self.nixl_agent.name().to_string())
}
/// Get a local layout by handle.
///
/// Returns `None` when `handle` is not present among the local layouts;
/// remote layouts are not consulted.
pub(crate) fn get_local(&self, handle: LayoutHandle) -> Option<&LocalLayout> {
self.local_layouts.get(&handle)
}
/// Get a remote layout by handle.
///
/// Returns `None` when `handle` is not present among the remote layouts;
/// local layouts are not consulted.
pub(crate) fn get_remote(&self, handle: LayoutHandle) -> Option<&RemoteLayout> {
self.remote_layouts.get(&handle)
}
/// Resolve a handle to its [`PhysicalLayout`], whether local or remote.
///
/// Local layouts take precedence: the remote layouts are only consulted when
/// no local layout is registered under `handle`.
///
/// # Returns
/// A reference to the layout, or `None` if the handle is unknown.
pub fn get_layout(&self, handle: LayoutHandle) -> Option<&PhysicalLayout> {
    if let Some(local) = self.local_layouts.get(&handle) {
        return Some(local.layout());
    }
    self.remote_layouts
        .get(&handle)
        .map(|remote| remote.layout())
}
/// Check if a handle refers to a local layout.
///
/// True exactly when `handle` is a key of the local layouts map.
pub(crate) fn is_local(&self, handle: LayoutHandle) -> bool {
self.local_layouts.contains_key(&handle)
}
/// Check if a handle refers to a remote layout.
///
/// True exactly when `handle` is a key of the remote layouts map.
pub(crate) fn is_remote(&self, handle: LayoutHandle) -> bool {
self.remote_layouts.contains_key(&handle)
}
/// Get the number of local layouts currently held by this registry.
pub(crate) fn local_count(&self) -> usize {
self.local_layouts.len()
}
/// Get the number of remote layouts currently held by this registry.
pub(crate) fn remote_count(&self) -> usize {
self.remote_layouts.len()
}
/// Get the worker ID for this manager.
///
/// This is the same ID that `worker_address` embeds in the address it builds.
pub(crate) fn worker_id(&self) -> u64 {
self.worker_id
}
/// Get all local layout handles.
///
/// The handles are copied out of the local layouts map into a new `Vec`.
/// NOTE(review): ordering follows the map's key iteration order — confirm
/// whether callers may rely on any particular order.
pub(crate) fn local_handles(&self) -> Vec<LayoutHandle> {
self.local_layouts.keys().copied().collect()
}
/// Get all remote layout handles.
///
/// The handles are copied out of the remote layouts map into a new `Vec`.
/// NOTE(review): ordering follows the map's key iteration order — confirm
/// whether callers may rely on any particular order.
pub(crate) fn remote_handles(&self) -> Vec<LayoutHandle> {
self.remote_layouts.keys().copied().collect()
}
}
// Unit tests for `LayoutRegistry`; compiled only for test builds with the
// `testing-kvbm` feature enabled.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::*;
use crate::layout::LayoutConfig;
use dynamo_memory::nixl::NixlAgent;
/// Create a NIXL agent with the given name, panicking on failure.
fn make_test_agent(name: &str) -> NixlAgent {
NixlAgent::new(name).expect("failed to create agent")
}
/// Build a small fully-contiguous, system-memory layout
/// (2 blocks x 2 layers x 2 outer x page size 4 x inner dim 8, 2-byte dtype).
fn make_test_layout(agent: &NixlAgent) -> PhysicalLayout {
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(4)
.inner_dim(8)
.dtype_width_bytes(2)
.build()
.unwrap();
PhysicalLayout::builder(agent.clone())
.with_config(config)
.fully_contiguous()
.allocate_system()
.build()
.unwrap()
}
/// A freshly created registry reports its worker ID and holds no layouts.
#[test]
fn test_manager_creation() {
let agent = make_test_agent("test-manager");
let manager = LayoutRegistry::new(agent, 42);
assert_eq!(manager.worker_id(), 42);
assert_eq!(manager.local_count(), 0);
assert_eq!(manager.remote_count(), 0);
}
/// Registering a local layout yields a handle carrying the registry's worker
/// ID and layout ID 0, and the handle is classified as local only.
#[test]
fn test_register_local() {
let agent = make_test_agent("test-register");
let mut manager = LayoutRegistry::new(agent.clone(), 100);
let layout = make_test_layout(&agent);
let handle = manager.register_local(layout).unwrap();
assert_eq!(handle.worker_id(), 100);
assert_eq!(handle.layout_id(), 0);
assert_eq!(manager.local_count(), 1);
assert!(manager.is_local(handle));
assert!(!manager.is_remote(handle));
}
/// Successive registrations receive consecutive layout IDs starting at 0.
#[test]
fn test_register_multiple_locals() {
let agent = make_test_agent("test-multiple");
let mut manager = LayoutRegistry::new(agent.clone(), 1);
let handle1 = manager.register_local(make_test_layout(&agent)).unwrap();
let handle2 = manager.register_local(make_test_layout(&agent)).unwrap();
let handle3 = manager.register_local(make_test_layout(&agent)).unwrap();
assert_eq!(handle1.layout_id(), 0);
assert_eq!(handle2.layout_id(), 1);
assert_eq!(handle3.layout_id(), 2);
assert_eq!(manager.local_count(), 3);
}
/// Metadata exported by one registry can be imported by another, after which
/// the source handles resolve as remote layouts on the destination.
#[test]
#[ignore] // Requires actual NIXL memory registration
fn test_export_import_roundtrip() {
// Create source manager and register layouts
let source_agent = make_test_agent("source");
let mut source_manager = LayoutRegistry::new(source_agent.clone(), 1);
let handle1 = source_manager
.register_local(make_test_layout(&source_agent))
.unwrap();
let handle2 = source_manager
.register_local(make_test_layout(&source_agent))
.unwrap();
// Export metadata
let metadata = source_manager.export_metadata().unwrap();
assert!(!metadata.is_empty());
// Create destination manager and import
let dest_agent = make_test_agent("dest");
let mut dest_manager = LayoutRegistry::new(dest_agent, 2);
let imported_handles = dest_manager.import_metadata(metadata).unwrap();
// Verify
assert_eq!(imported_handles.len(), 2);
assert_eq!(dest_manager.remote_count(), 2);
assert!(dest_manager.is_remote(handle1));
assert!(dest_manager.is_remote(handle2));
// Can get layouts
assert!(dest_manager.get_remote(handle1).is_some());
assert!(dest_manager.get_remote(handle2).is_some());
assert!(dest_manager.get_layout(handle1).is_some());
}
/// Importing the same remote's metadata a second time is rejected with an
/// "already loaded" error.
#[test]
#[ignore] // Requires actual NIXL memory registration
fn test_import_duplicate_remote_fails() {
let source_agent = make_test_agent("source2");
let mut source_manager = LayoutRegistry::new(source_agent.clone(), 10);
source_manager
.register_local(make_test_layout(&source_agent))
.unwrap();
let metadata = source_manager.export_metadata().unwrap();
let dest_agent = make_test_agent("dest2");
let mut dest_manager = LayoutRegistry::new(dest_agent, 20);
// First import succeeds
// (`import_metadata` consumes its argument, so rebuild a second copy from the bytes first.)
let metadata_clone = SerializedLayout::from_bytes(metadata.as_bytes().to_vec());
dest_manager.import_metadata(metadata).unwrap();
// Second import should fail
let result = dest_manager.import_metadata(metadata_clone);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("already loaded"));
}
/// `local_handles` returns every registered local handle.
#[test]
fn test_get_layout_handles() {
let agent = make_test_agent("test-handles");
let mut manager = LayoutRegistry::new(agent.clone(), 5);
let h1 = manager.register_local(make_test_layout(&agent)).unwrap();
let h2 = manager.register_local(make_test_layout(&agent)).unwrap();
let handles = manager.local_handles();
assert_eq!(handles.len(), 2);
assert!(handles.contains(&h1));
assert!(handles.contains(&h2));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Remote layout wrapper reconstructed from imported metadata.
use super::handle::LayoutHandle;
use crate::layout::PhysicalLayout;
/// A remote physical layout reconstructed from imported metadata.
///
/// This wraps a `PhysicalLayout` that was deserialized from another worker's
/// exported metadata. The layout's memory regions point to addresses on the
/// remote worker and are used for building NIXL RDMA transfer descriptors.
///
/// This type is cheap to clone as `PhysicalLayout` contains `Arc` internally.
#[derive(Debug, Clone)]
pub struct RemoteLayout {
/// Handle identifying this layout; carries the remote worker's ID and the layout ID.
handle: LayoutHandle,
/// The physical layout reconstructed from the remote worker's descriptor.
layout: PhysicalLayout,
}
#[allow(dead_code)]
impl RemoteLayout {
/// Create a new remote layout.
///
/// # Arguments
/// * `handle` - Unique handle for this layout (from remote worker)
/// * `layout` - The reconstructed physical layout
pub fn new(handle: LayoutHandle, layout: PhysicalLayout) -> Self {
Self { handle, layout }
}
/// Get the handle for this layout.
pub fn handle(&self) -> LayoutHandle {
self.handle
}
/// Get a reference to the physical layout.
pub fn layout(&self) -> &PhysicalLayout {
&self.layout
}
/// Get the worker_id from the handle (identifies the remote worker).
pub fn worker_id(&self) -> u64 {
self.handle.worker_id()
}
/// Get the layout_id from the handle.
pub fn layout_id(&self) -> u16 {
self.handle.layout_id()
}
/// Consume this remote layout and return the physical layout.
pub fn into_layout(self) -> PhysicalLayout {
self.layout
}
}
// Unit tests for `RemoteLayout`; compiled only for test builds with the
// `testing-kvbm` feature enabled.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::*;
use crate::layout::{LayoutConfig, LayoutDescriptor, NixlMetadata, PhysicalLayout};
/// Hand-build a descriptor for a fully-contiguous system-memory layout backed
/// by a single memory region at a made-up address (0x1000).
fn make_serialized_layout() -> LayoutDescriptor {
use crate::layout::{BlockFormat, FullyContiguousDetails, LayoutTypeDetails};
use dynamo_memory::{MemoryRegion, StorageKind, nixl};
let config = LayoutConfig::builder()
.num_blocks(2)
.num_layers(2)
.outer_dim(2)
.page_size(4)
.inner_dim(8)
.dtype_width_bytes(2)
.build()
.unwrap();
// Size the single region as the product of every layout dimension and the dtype width.
let required_size = config.num_blocks
* config.num_layers
* config.outer_dim
* config.page_size
* config.inner_dim
* config.dtype_width_bytes;
LayoutDescriptor {
version: 1,
layout_config: config,
location: StorageKind::System,
nixl_metadata: NixlMetadata::new("remote_agent".to_string(), nixl::MemType::Dram, 0),
memory_descriptors: vec![MemoryRegion {
addr: 0x1000,
size: required_size,
}],
layout_type_details: LayoutTypeDetails::FullyContiguous(FullyContiguousDetails {
block_format: BlockFormat::Operational,
kv_block_layout: crate::layout::KvBlockLayout::OperationalNHD,
}),
}
}
/// The accessors expose the handle, its worker/layout IDs, and the wrapped layout.
#[test]
fn test_remote_layout_creation() {
let handle = LayoutHandle::new(999, 42);
let serialized = make_serialized_layout();
let layout = PhysicalLayout::from_descriptor(serialized).unwrap();
let remote = RemoteLayout::new(handle, layout);
assert_eq!(remote.handle(), handle);
assert_eq!(remote.worker_id(), 999);
assert_eq!(remote.layout_id(), 42);
assert_eq!(
remote.layout().layout().block_layout(),
crate::layout::KvBlockLayout::OperationalNHD
);
}
/// `into_layout` consumes the wrapper and hands back the physical layout.
#[test]
fn test_remote_layout_into_layout() {
let handle = LayoutHandle::new(100, 200);
let serialized = make_serialized_layout();
let layout = PhysicalLayout::from_descriptor(serialized).unwrap();
let remote = RemoteLayout::new(handle, layout);
let _recovered = remote.into_layout();
// Successfully consumed and returned the layout
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer capability flags for controlling direct path enablement.
//!
//! By default, the transfer system uses a conservative staging policy where:
//! - Device can only transfer to/from Host
//! - Disk can only transfer to/from Host
//! - Host can transfer to Device, Disk, or Remote
//! - Device ↔ Device is allowed (native CUDA)
//!
//! These capability flags enable optional direct paths that bypass host staging.
use serde::{Deserialize, Serialize};
use std::sync::OnceLock;
use crate::{
layout::LayoutConfig,
transfer::{
PhysicalLayout, TransferManager,
executor::{TransferOptionsInternal, execute_transfer},
},
};
use dynamo_memory::nixl::NixlAgent;
/// Process-wide cache of the GDS probe result.
///
/// Initialized at most once, by the first call to
/// `TransferCapabilities::with_gds_if_supported`; later calls reuse the cached value.
// NOTE: this static is declared ahead of the struct documentation below so that
// the doc comment (and its doctest) attaches to `TransferCapabilities` rather
// than to this private item.
static GDS_SUPPORTED: OnceLock<bool> = OnceLock::new();

/// Transfer capability flags controlling which direct paths are enabled.
///
/// # Default Policy (Conservative)
///
/// With all flags disabled (default), the system uses host staging:
/// - **Device → Remote**: Device → Host → Remote (2 hops)
/// - **Disk → Remote**: Disk → Host → Remote (2 hops)
/// - **Device ↔ Disk**: Device → Host → Disk (2 hops)
///
/// # Optional Direct Paths
///
/// - `allow_gds`: Enables GPU Direct Storage (Disk ↔ Device without host)
/// - `allow_gpu_rdma`: Enables GPU RDMA (Device → Remote without host)
///
/// # Example
///
/// ```
/// # use kvbm_physical::transfer::TransferCapabilities;
/// // Default conservative policy
/// let caps = TransferCapabilities::default();
/// assert!(!caps.allow_gds);
/// assert!(!caps.allow_gpu_rdma);
///
/// // Enable GDS for high-performance disk I/O
/// let caps = TransferCapabilities::default().with_gds(true);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct TransferCapabilities {
    /// Enable GPU Direct Storage (Disk ↔ Device without host staging).
    ///
    /// When enabled:
    /// - Disk → Device: Direct transfer (requires GDS support)
    /// - Device → Disk: Direct transfer (requires GDS support)
    ///
    /// When disabled (default):
    /// - Disk → Device: Disk → Host → Device (2 hops)
    /// - Device → Disk: Device → Host → Disk (2 hops)
    pub allow_gds: bool,
    /// Enable GPU RDMA (Device → Remote without host staging).
    ///
    /// When enabled:
    /// - Device → Remote: Direct NIXL transfer
    ///
    /// When disabled (default):
    /// - Device → Remote: Device → Host → Remote (2 hops)
    ///
    /// Note: This only affects Device → Remote. Host → Remote is always direct.
    pub allow_gpu_rdma: bool,
}
impl TransferCapabilities {
    /// Capabilities under the default conservative policy (every direct path off).
    pub fn new() -> Self {
        Default::default()
    }

    /// Capabilities with every direct path switched on (high performance mode).
    pub fn all_enabled() -> Self {
        Self::new().with_gds(true).with_gpu_rdma(true)
    }

    /// Return a copy with the GDS (GPU Direct Storage) flag set to `enabled`.
    pub fn with_gds(self, enabled: bool) -> Self {
        Self {
            allow_gds: enabled,
            ..self
        }
    }

    /// Probe for GDS by attempting one tiny device-to-disk transfer.
    ///
    /// Creates a dedicated NIXL agent with the `GDS_MT` backend, allocates a
    /// single-block layout on CUDA device 0 and another on disk, and submits a
    /// transfer of block 0 between them. Any failure along the way is returned.
    fn test_gds_transfer(&self) -> anyhow::Result<()> {
        let agent = NixlAgent::with_backends("agent", &["GDS_MT"])?;

        // Smallest useful layout: one block, one layer, 4096 inner elements.
        let config = LayoutConfig::builder()
            .num_blocks(1)
            .num_layers(1)
            .outer_dim(1)
            .page_size(1)
            .inner_dim(4096)
            .build()?;

        let device_layout = PhysicalLayout::builder(agent.clone())
            .with_config(config.clone())
            .fully_contiguous()
            .allocate_device(0)
            .build()?;
        let disk_layout = PhysicalLayout::builder(agent.clone())
            .with_config(config)
            .fully_contiguous()
            .allocate_disk(None)
            .build()?;

        let src_blocks = vec![0];
        let dst_blocks = vec![0];
        let manager = TransferManager::builder()
            .nixl_agent(agent)
            .cuda_device_id(0)
            .build()?;

        execute_transfer(
            &device_layout,
            &disk_layout,
            &src_blocks,
            &dst_blocks,
            TransferOptionsInternal::default(),
            manager.context(),
        )?;
        Ok(())
    }

    /// Enable GDS only if the probe transfer succeeds.
    ///
    /// The probe runs at most once per process; its outcome is cached in
    /// `GDS_SUPPORTED` and reused by every later call.
    pub fn with_gds_if_supported(self) -> Self {
        let supported = *GDS_SUPPORTED.get_or_init(|| self.test_gds_transfer().is_ok());
        self.with_gds(supported)
    }

    /// Return a copy with the GPU RDMA flag set to `enabled`.
    pub fn with_gpu_rdma(self, enabled: bool) -> Self {
        Self {
            allow_gpu_rdma: enabled,
            ..self
        }
    }

    /// Whether a direct Device ↔ Disk path is allowed (mirrors `allow_gds`).
    pub fn allows_device_disk_direct(&self) -> bool {
        self.allow_gds
    }

    /// Whether a direct Device → Remote path is allowed (mirrors `allow_gpu_rdma`).
    pub fn allows_device_remote_direct(&self) -> bool {
        self.allow_gpu_rdma
    }
}
// Unit tests for `TransferCapabilities`; compiled only for test builds with the
// `testing-kvbm` feature enabled.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::*;
/// The default value has both direct paths disabled.
#[test]
fn test_default_capabilities() {
let caps = TransferCapabilities::default();
assert!(!caps.allow_gds);
assert!(!caps.allow_gpu_rdma);
assert!(!caps.allows_device_disk_direct());
assert!(!caps.allows_device_remote_direct());
}
/// `all_enabled` turns on both direct paths.
#[test]
fn test_all_enabled() {
let caps = TransferCapabilities::all_enabled();
assert!(caps.allow_gds);
assert!(caps.allow_gpu_rdma);
assert!(caps.allows_device_disk_direct());
assert!(caps.allows_device_remote_direct());
}
/// The `with_*` setters can be chained and each affects only its own flag.
#[test]
fn test_builder_pattern() {
let caps = TransferCapabilities::new()
.with_gds(true)
.with_gpu_rdma(false);
assert!(caps.allow_gds);
assert!(!caps.allow_gpu_rdma);
}
/// Enabling one capability leaves the other disabled.
#[test]
fn test_selective_enablement() {
// Enable only GDS
let caps = TransferCapabilities::new().with_gds(true);
assert!(caps.allows_device_disk_direct());
assert!(!caps.allows_device_remote_direct());
// Enable only GPU RDMA
let caps = TransferCapabilities::new().with_gpu_rdma(true);
assert!(!caps.allows_device_disk_direct());
assert!(caps.allows_device_remote_direct());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Block checksum computation for verification.
//!
//! This module provides utilities to compute checksums of blocks for
//! round-trip test verification.
use dynamo_memory::StorageKind;
use super::PhysicalLayout;
use aligned_vec::{AVec, avec};
use anyhow::{Result, anyhow};
use blake3::Hasher;
use std::{
collections::HashMap,
fs::File,
io::{Read, Seek},
mem::ManuallyDrop,
ops::Range,
os::fd::FromRawFd,
};
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind};
/// Checksum of a block, stored as the string rendering of a BLAKE3 hash
/// (the checksum functions in this module produce it via `finalize().to_string()`).
pub type BlockChecksum = String;
/// Checksum every listed block across all of its layers.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - Block IDs to checksum; a repeated ID simply yields one entry
///
/// # Returns
/// A map from block ID to that block's checksum.
///
/// # Errors
/// Stops at, and returns, the first error produced while checksumming a block
/// (for example a block ID that is out of range for the layout).
pub fn compute_block_checksums(
    layout: &PhysicalLayout,
    block_ids: &[usize],
) -> Result<HashMap<usize, BlockChecksum>> {
    block_ids
        .iter()
        .map(|&block_id| {
            compute_single_block_checksum(layout, block_id, None)
                .map(|checksum| (block_id, checksum))
        })
        .collect()
}
/// Checksum a sub-range of layers for every listed block.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - Block IDs to checksum
/// * `layer_range` - Layers to fold into each checksum
///
/// # Returns
/// A map from block ID to its checksum, covering only the requested layers.
///
/// # Errors
/// Fails up front when `layer_range` ends beyond the layout's layer count, and
/// otherwise returns the first error hit while checksumming a block.
pub fn compute_layer_checksums(
    layout: &PhysicalLayout,
    block_ids: &[usize],
    layer_range: Range<usize>,
) -> Result<HashMap<usize, BlockChecksum>> {
    let num_layers = layout.layout().config().num_layers;
    if layer_range.end > num_layers {
        return Err(anyhow!(
            "Layer range {:?} exceeds num_layers {}",
            layer_range,
            num_layers
        ));
    }

    block_ids
        .iter()
        .map(|&block_id| {
            compute_single_block_checksum(layout, block_id, Some(layer_range.clone()))
                .map(|checksum| (block_id, checksum))
        })
        .collect()
}
/// Compute checksum for a single block.
///
/// Every `(layer, outer)` memory region of the block is fed, in order, into a
/// single BLAKE3 hasher. How the bytes are obtained depends on where the layout
/// lives: host memory is read in place, device memory is copied to a host
/// buffer with `cudaMemcpy`, and disk-backed regions are read from the file
/// descriptor at the region's offset.
///
/// # Arguments
/// * `layout` - The physical layout containing the block
/// * `block_id` - Block to checksum; must be below the layout's block count
/// * `layer_range` - Layers to include, or `None` for all layers
///
/// # Errors
/// Returns an error if the block ID or layer range is out of bounds, if a
/// memory region cannot be resolved, if the device copy fails, if the disk
/// file descriptor does not fit in a raw fd, or if the disk read fails.
fn compute_single_block_checksum(
    layout: &PhysicalLayout,
    block_id: usize,
    layer_range: Option<Range<usize>>,
) -> Result<String> {
    let config = layout.layout().config();
    if block_id >= config.num_blocks {
        return Err(anyhow!("Block ID {} out of range", block_id));
    }
    let num_layers = config.num_layers;
    let outer_dim = config.outer_dim;
    let layers = layer_range.unwrap_or(0..num_layers);
    // validate layer range
    if layers.end > num_layers {
        return Err(anyhow!(
            "Layer range {:?} exceeds num_layers {}",
            layers,
            num_layers
        ));
    }
    let mut hasher = Hasher::new();
    // Iterate over all layers and outer dimensions
    for layer_id in layers {
        for outer_id in 0..outer_dim {
            let region = layout.memory_region(block_id, layer_id, outer_id)?;
            // An empty region contributes no bytes to the hash, so skip it
            // rather than building a raw slice / issuing a copy for it.
            if region.size() == 0 {
                continue;
            }
            match layout.location() {
                StorageKind::System | StorageKind::Pinned => {
                    // SAFETY: assumes the layout guarantees that `addr..addr+size`
                    // is readable host memory that stays alive for this call —
                    // NOTE(review): confirm against the layout's allocation contract.
                    let slice = unsafe {
                        std::slice::from_raw_parts(region.addr() as *const u8, region.size())
                    };
                    hasher.update(slice);
                }
                StorageKind::Device(_) => {
                    let mut system_region: Vec<u8> = vec![0; region.size()];
                    // SAFETY: the destination is a host buffer of exactly
                    // `region.size()` bytes; the source is assumed to be a valid
                    // device allocation of at least that size.
                    let err = unsafe {
                        cudaMemcpy(
                            system_region.as_mut_ptr() as *mut std::ffi::c_void,
                            region.addr() as *const std::ffi::c_void,
                            region.size(),
                            cudaMemcpyKind::cudaMemcpyDeviceToHost,
                        )
                    };
                    if err != cudarc::runtime::sys::cudaError::cudaSuccess {
                        return Err(anyhow!("cudaMemcpy D2H failed in checksum: {:?}", err));
                    }
                    hasher.update(system_region.as_slice());
                }
                StorageKind::Disk(fd) => {
                    // A plain `as i32` cast would silently truncate an
                    // out-of-range value and read from an unrelated descriptor.
                    let raw_fd = i32::try_from(fd).map_err(|_| {
                        anyhow!("Disk file descriptor {} does not fit in a raw fd", fd)
                    })?;
                    let mut system_region: AVec<u8, _> = avec![[4096]| 0; region.size()];
                    // The descriptor is owned by the layout: wrap the temporary
                    // `File` in `ManuallyDrop` so it is not closed when we are done.
                    // SAFETY: assumes `raw_fd` is an open descriptor for the
                    // duration of this call.
                    let mut file = ManuallyDrop::new(unsafe { File::from_raw_fd(raw_fd) });
                    // For disk layouts the region address is used as the file offset.
                    file.seek(std::io::SeekFrom::Start(region.addr() as u64))?;
                    file.read_exact(&mut system_region)?;
                    hasher.update(system_region.as_slice());
                }
            }
        }
    }
    Ok(hasher.finalize().to_string())
}
// Unit tests for the checksum helpers; compiled only for test builds with the
// `testing-kvbm` feature enabled.
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use super::super::tests::*;
use super::*;
use crate::transfer::{FillPattern, fill_blocks};
/// Two blocks filled with the same constant byte produce equal checksums, and
/// hashing a region's raw bytes matches hashing an equally sized vector of
/// that constant.
#[test]
fn test_checksum_constant_pattern() {
let physical = builder(2)
.fully_contiguous()
.allocate_system()
.build()
.unwrap();
fill_blocks(&physical, &[0, 1], FillPattern::Constant(42)).unwrap();
let checksums = compute_block_checksums(&physical, &[0, 1]).unwrap();
// Both blocks should have the same checksum values (same pattern)
assert_eq!(checksums[&0], checksums[&1]);
let memory_region = physical.memory_region(0, 0, 0).unwrap();
// Read the first region of block 0 directly to confirm the fill actually landed.
let slice = unsafe {
std::slice::from_raw_parts(memory_region.addr() as *const u8, memory_region.size())
};
assert!(slice.iter().all(|&b| b == 42));
let mut hasher = Hasher::new();
hasher.update(slice);
let checksum_mr_slice = hasher.finalize().to_string();
let vec = vec![42; memory_region.size()];
let mut hasher = Hasher::new();
hasher.update(&vec);
let checksum_vec = hasher.finalize().to_string();
assert_eq!(checksum_mr_slice, checksum_vec);
}
// NOTE(review): the tests below are disabled. They call helpers that are not
// visible in this module (`create_test_layout`, `PhysicalLayout::new_local`,
// `PhysicalLayout::new_remote`, `StorageLocation`) and treat the checksum as a
// struct (`byte_count`, `matches`) although `BlockChecksum` is a `String` here —
// presumably they predate an API change; port or delete them.
// #[test]
// fn test_checksum_different_patterns() {
// let (layout, _memory) = create_test_layout(2);
// let physical = PhysicalLayout::new_local(layout, StorageLocation::System);
// // Fill blocks with different patterns
// fill_blocks(&physical, &[0], FillPattern::Constant(42)).unwrap();
// fill_blocks(&physical, &[1], FillPattern::Constant(100)).unwrap();
// let checksums = compute_block_checksums(&physical, &[0, 1]).unwrap();
// // Blocks should have different checksums
// assert_ne!(checksums[&0], checksums[&1]);
// }
// #[test]
// fn test_checksum_matches() {
// let (layout1, _memory1) = create_test_layout(1);
// let (layout2, _memory2) = create_test_layout(1);
// let physical1 = PhysicalLayout::new_local(layout1, StorageLocation::System);
// let physical2 = PhysicalLayout::new_local(layout2, StorageLocation::System);
// // Fill both with same pattern
// fill_blocks(&physical1, &[0], FillPattern::Sequential).unwrap();
// fill_blocks(&physical2, &[0], FillPattern::Sequential).unwrap();
// let checksum1 = compute_block_checksums(&physical1, &[0]).unwrap();
// let checksum2 = compute_block_checksums(&physical2, &[0]).unwrap();
// // Checksums should match (ignoring block_id)
// assert!(checksum1[&0].matches(&checksum2[&0]));
// }
// #[test]
// fn test_layer_checksums() {
// let (layout, _memory) = create_test_layout(1);
// let physical = PhysicalLayout::new_local(layout, StorageLocation::System);
// // Fill entire block
// fill_blocks(&physical, &[0], FillPattern::Sequential).unwrap();
// // Compute checksums for different layer ranges
// let full_checksum = compute_block_checksums(&physical, &[0]).unwrap();
// let layer0_checksum = compute_layer_checksums(&physical, &[0], 0..1).unwrap();
// let layer1_checksum = compute_layer_checksums(&physical, &[0], 1..2).unwrap();
// // Layer checksums should be different from full checksum
// assert_ne!(full_checksum[&0].byte_count, layer0_checksum[&0].byte_count);
// assert_ne!(full_checksum[&0].byte_count, layer1_checksum[&0].byte_count);
// // Layer 0 and Layer 1 should have same byte count (same size)
// assert_eq!(
// layer0_checksum[&0].byte_count,
// layer1_checksum[&0].byte_count
// );
// }
// #[test]
// fn test_checksum_remote_layout_fails() {
// let (layout, _memory) = create_test_layout(1);
// let physical =
// PhysicalLayout::new_remote(layout, StorageLocation::System, "remote".to_string());
// let result = compute_block_checksums(&physical, &[0]);
// assert!(result.is_err());
// assert!(result.unwrap_err().to_string().contains("remote"));
// }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer context.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use anyhow::Result;
use cudarc::driver::{CudaContext, CudaEvent, CudaStream};
use derive_builder::Builder;
use tokio::sync::mpsc;
use uuid::Uuid;
use dynamo_memory::CudaMemPool;
use dynamo_memory::nixl::{NixlAgent, NixlBackendConfig, XferRequest};
use velo_events::EventManager;
use crate::manager::TransferManager;
// Notifications module is declared in ../mod.rs
// Re-export for convenience
use super::TransferCapabilities;
use notifications::RegisterPollingNotification;
pub(crate) use super::notifications;
pub use super::notifications::TransferCompleteNotification;
/// Configuration from which a `TransferManager` is built.
///
/// Values are collected through the generated `TransferConfigBuilder`. The
/// generated build function is private and named `build_internal`; the public
/// `build` methods call it and then construct the transfer context.
#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"), public)]
#[allow(dead_code)] // Fields are used in build() but derive macros confuse dead code analysis
pub struct TransferConfig {
/// Event system shared with the transfer context (default: a local `EventManager`).
/// Its system ID is used as the worker ID.
#[builder(default = "Arc::new(EventManager::local())")]
event_system: Arc<EventManager>,
/// Optional custom name for the NIXL agent. If not provided, defaults to "worker-{worker_id}"
#[builder(default = "None", setter(strip_option))]
nixl_agent_name: Option<String>,
/// Backend configuration for NIXL backends to enable
#[builder(default = "NixlBackendConfig::default()")]
nixl_backend_config: NixlBackendConfig,
/// CUDA device ordinal used to create the CUDA context (default: 0).
#[builder(default = "0")]
cuda_device_id: usize,
/// Tokio runtime used to spawn background handlers (default: the current
/// runtime's handle if one exists, otherwise a newly built runtime).
#[builder(default = "get_tokio_runtime()")]
tokio_runtime: TokioRuntime,
/// Direct-path capability flags (default: all disabled).
#[builder(default = "TransferCapabilities::default()")]
capabilities: TransferCapabilities,
/// Size in bytes to pre-allocate for the CUDA memory pool (default: 64 MiB)
#[builder(default = "64 * 1024 * 1024")]
cuda_pool_reserve_size: usize,
/// Release threshold for the CUDA memory pool (default: Some(64 MiB))
/// Memory above this threshold is returned to the system when freed.
/// If None, no release threshold is set.
#[builder(default = "Some(64 * 1024 * 1024)")]
cuda_pool_release_threshold: Option<u64>,
}
impl TransferConfigBuilder {
    /// Seed the builder with an existing event system and tokio handle.
    ///
    /// Use this when the runtime has already been constructed and components
    /// should share the same event notification infrastructure (for example
    /// Nova's event system) instead of creating their own.
    pub fn from_event_system_and_handle(
        self,
        event_system: Arc<EventManager>,
        handle: tokio::runtime::Handle,
    ) -> Self {
        let runtime = TokioRuntime::Handle(handle);
        self.event_system(event_system).tokio_runtime(runtime)
    }

    /// Supply a pre-configured NIXL agent (mainly for testing).
    ///
    /// The returned builder skips agent creation and backend initialization
    /// and uses `agent` as-is, giving tests full control over its configuration.
    pub fn nixl_agent(self, agent: NixlAgent) -> TransferConfigBuilderWithAgent {
        TransferConfigBuilderWithAgent {
            agent,
            builder: self,
        }
    }

    /// Enable one more NIXL backend, using its default plugin parameters.
    pub fn nixl_backend(mut self, backend: impl Into<String>) -> Self {
        let slot = self
            .nixl_backend_config
            .get_or_insert_with(NixlBackendConfig::default);
        let updated = slot.clone().with_backend(backend);
        *slot = updated;
        self
    }

    /// Merge backend configuration read from environment variables into
    /// whatever backends were already configured on this builder.
    ///
    /// # Errors
    /// Fails when the environment-based configuration cannot be loaded.
    pub fn with_env_backends(mut self) -> Result<Self> {
        let env_config = NixlBackendConfig::from_env()?;
        let slot = self
            .nixl_backend_config
            .get_or_insert_with(NixlBackendConfig::default);
        let merged = slot.clone().merge(env_config);
        *slot = merged;
        Ok(self)
    }

    /// Finish the configuration and construct a [`TransferManager`].
    ///
    /// Creates the NIXL agent and CUDA context from the configuration, then
    /// builds the transfer context around them.
    ///
    /// # Errors
    /// Fails if the configuration cannot be built, the environment backend
    /// configuration cannot be loaded, or the NIXL agent, CUDA context, or
    /// transfer context cannot be created.
    pub fn build(self) -> Result<TransferManager> {
        let mut config = self.build_internal()?;
        let worker_id = config.event_system.system_id();

        // With no backend configured explicitly, take the whole backend
        // configuration from the environment instead.
        if config.nixl_backend_config.backends().is_empty() {
            config.nixl_backend_config = NixlBackendConfig::from_env()?;
        }

        // An unnamed agent is named after the worker it belongs to.
        let agent_name = match config.nixl_agent_name {
            Some(name) => name,
            None => format!("worker-{}", worker_id),
        };

        let nixl_agent =
            NixlAgent::from_nixl_backend_config(&agent_name, config.nixl_backend_config)?;
        let cuda_context = CudaContext::new(config.cuda_device_id)?;
        let context = TransferContext::new(
            nixl_agent,
            config.event_system,
            cuda_context,
            config.tokio_runtime,
            config.capabilities,
            config.cuda_pool_reserve_size,
            config.cuda_pool_release_threshold,
        )?;
        Ok(TransferManager::from_context(context))
    }
}
/// Builder that already has a pre-configured NIXL agent.
///
/// This is generally used for testing when you want to pass in an agent directly
/// rather than having it created by the builder.
pub struct TransferConfigBuilderWithAgent {
/// The underlying configuration builder; supplies every setting except the agent.
builder: TransferConfigBuilder,
/// The caller-provided agent, handed to the transfer context unchanged.
agent: NixlAgent,
}
impl TransferConfigBuilderWithAgent {
    /// Construct the [`TransferManager`] around the agent supplied up front.
    ///
    /// Unlike `TransferConfigBuilder::build`, no NIXL agent is created here and
    /// the agent-name / backend settings of the configuration are not consulted.
    ///
    /// # Errors
    /// Fails if the configuration cannot be built or the CUDA context or
    /// transfer context cannot be created.
    pub fn build(self) -> Result<TransferManager> {
        let Self { builder, agent } = self;
        let config = builder.build_internal()?;
        let cuda_context = CudaContext::new(config.cuda_device_id)?;
        let context = TransferContext::new(
            agent,
            config.event_system,
            cuda_context,
            config.tokio_runtime,
            config.capabilities,
            config.cuda_pool_reserve_size,
            config.cuda_pool_release_threshold,
        )?;
        Ok(TransferManager::from_context(context))
    }

    /// Select the CUDA device ordinal, forwarding to the wrapped builder.
    pub fn cuda_device_id(self, cuda_device_id: usize) -> Self {
        Self {
            builder: self.builder.cuda_device_id(cuda_device_id),
            agent: self.agent,
        }
    }
}
/// Pick the tokio runtime used by default for a transfer configuration.
///
/// When called from inside a tokio runtime, that runtime's handle is reused.
/// Otherwise a dedicated multi-threaded runtime (2 worker threads, at most 4
/// blocking threads) is built and owned through an `Arc`.
///
/// # Panics
/// Panics if the dedicated runtime cannot be built.
fn get_tokio_runtime() -> TokioRuntime {
    if let Ok(handle) = tokio::runtime::Handle::try_current() {
        return TokioRuntime::Handle(handle);
    }

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .max_blocking_threads(4)
        .worker_threads(2)
        .build()
        .expect("failed to build tokio runtime");
    TokioRuntime::Shared(Arc::new(runtime))
}
/// The tokio runtime a transfer context spawns onto: either a handle to a
/// runtime owned elsewhere or a shared reference to a runtime owned here.
#[derive(Debug, Clone)]
#[doc(hidden)]
pub enum TokioRuntime {
/// Handle to an externally owned runtime.
Handle(tokio::runtime::Handle),
/// A runtime kept alive through shared ownership.
Shared(Arc<tokio::runtime::Runtime>),
}
impl TokioRuntime {
    /// Borrow a handle to the underlying runtime, whichever variant holds it.
    pub fn handle(&self) -> &tokio::runtime::Handle {
        match self {
            Self::Shared(runtime) => runtime.handle(),
            Self::Handle(handle) => handle,
        }
    }
}
/// Shared state used when executing transfers: the NIXL agent, CUDA context,
/// stream pools, memory pool, event system, and the channels feeding the
/// background notification handlers.
#[derive(Clone)]
#[doc(hidden)]
pub struct TransferContext {
// ID of this worker; taken from the event system's system ID at construction.
worker_id: u64,
nixl_agent: NixlAgent,
#[allow(dead_code)]
cuda_context: Arc<CudaContext>,
// Fixed D2H / H2D streams: the first entry of the corresponding pool below.
d2h_stream: Arc<CudaStream>,
h2d_stream: Arc<CudaStream>,
// Stream pools handed out round-robin by `next_d2h_streams` / `next_h2d_streams`.
d2h_streams: Vec<Arc<CudaStream>>,
h2d_streams: Vec<Arc<CudaStream>>,
// Round-robin counters for the pools; behind `Arc` so clones of the context share them.
current_d2h_stream: Arc<AtomicUsize>,
current_h2d_stream: Arc<AtomicUsize>,
#[allow(dead_code)]
tokio_runtime: TokioRuntime,
capabilities: TransferCapabilities,
event_system: Arc<EventManager>,
// CUDA memory pool for kernel allocations
cuda_pool: Arc<CudaMemPool>,
// Channels for background notification handlers
tx_nixl_status: mpsc::Sender<RegisterPollingNotification<notifications::NixlStatusChecker>>,
tx_cuda_event: mpsc::Sender<RegisterPollingNotification<notifications::CudaEventChecker>>,
#[allow(dead_code)]
tx_nixl_events: mpsc::Sender<notifications::RegisterNixlNotification>,
}
impl TransferContext {
/// Start a `TransferConfigBuilder` with all settings at their defaults.
pub fn builder() -> TransferConfigBuilder {
TransferConfigBuilder::default()
}
/// Assemble a transfer context from already-constructed components.
///
/// Besides storing its arguments, this builds the CUDA memory pool, spawns
/// three background notification handlers on the given tokio runtime, and
/// creates four D2H and four H2D CUDA streams.
///
/// # Errors
/// Fails if the CUDA memory pool or any CUDA stream cannot be created.
pub(crate) fn new(
    nixl_agent: NixlAgent,
    event_system: Arc<EventManager>,
    cuda_context: Arc<CudaContext>,
    tokio_runtime: TokioRuntime,
    capabilities: TransferCapabilities,
    cuda_pool_reserve_size: usize,
    cuda_pool_release_threshold: Option<u64>,
) -> Result<Self> {
    // NOTE(review): no safety justification is recorded for this call —
    // confirm the contract of `disable_event_tracking` and document it here.
    unsafe { cuda_context.disable_event_tracking() };

    // CUDA memory pool for kernel allocations; the release threshold is optional.
    let cuda_pool = {
        let pool_builder = CudaMemPool::builder(cuda_context.clone(), cuda_pool_reserve_size);
        let pool_builder = match cuda_pool_release_threshold {
            Some(threshold) => pool_builder.release_threshold(threshold),
            None => pool_builder,
        };
        Arc::new(pool_builder.build()?)
    };

    // One bounded channel per background notification handler.
    let (tx_nixl_status, rx_nixl_status) = mpsc::channel(64);
    let (tx_cuda_event, rx_cuda_event) = mpsc::channel(64);
    let (tx_nixl_events, rx_nixl_events) = mpsc::channel(64);

    let runtime_handle = tokio_runtime.handle();
    // NIXL status polling
    runtime_handle.spawn(notifications::process_polling_notifications(
        rx_nixl_status,
        event_system.clone(),
    ));
    // CUDA event polling
    runtime_handle.spawn(notifications::process_polling_notifications(
        rx_cuda_event,
        event_system.clone(),
    ));
    // NIXL notification events
    runtime_handle.spawn(notifications::process_nixl_notification_events(
        nixl_agent.raw_agent().clone(),
        rx_nixl_events,
        event_system.clone(),
    ));

    // Stream pools: D2H streams are created first, then H2D.
    let new_stream_pool = || {
        (0..4)
            .map(|_| cuda_context.new_stream())
            .collect::<Result<Vec<_>, _>>()
    };
    let d2h_streams: Vec<Arc<CudaStream>> = new_stream_pool()?;
    let h2d_streams: Vec<Arc<CudaStream>> = new_stream_pool()?;

    // The fixed streams are simply the first entry of each pool.
    let d2h_stream = Arc::clone(&d2h_streams[0]);
    let h2d_stream = Arc::clone(&h2d_streams[0]);

    let worker_id = event_system.system_id();
    Ok(Self {
        worker_id,
        nixl_agent,
        cuda_context: cuda_context.clone(),
        d2h_stream,
        h2d_stream,
        d2h_streams,
        h2d_streams,
        current_d2h_stream: Arc::new(AtomicUsize::new(0)),
        current_h2d_stream: Arc::new(AtomicUsize::new(0)),
        tokio_runtime,
        capabilities,
        event_system,
        cuda_pool,
        tx_nixl_status,
        tx_cuda_event,
        tx_nixl_events,
    })
}
/// Borrow the NIXL agent this context was constructed with.
pub(crate) fn nixl_agent(&self) -> &NixlAgent {
&self.nixl_agent
}
/// Borrow the CUDA context this context was constructed with.
#[allow(dead_code)]
pub(crate) fn cuda_context(&self) -> &Arc<CudaContext> {
&self.cuda_context
}
/// Returns the fixed device-to-host stream.
///
/// Every call yields the same stream: the constructor stores a clone of the
/// first entry of the D2H stream pool in this field. Use `next_d2h_streams`
/// when round-robin selection across the pool is wanted.
#[allow(dead_code)]
pub(crate) fn d2h_stream(&self) -> &Arc<CudaStream> {
&self.d2h_stream
}
/// Returns the fixed host-to-device stream.
///
/// Every call yields the same stream: the constructor stores a clone of the
/// first entry of the H2D stream pool in this field. Use `next_h2d_streams`
/// when round-robin selection across the pool is wanted.
#[allow(dead_code)]
pub(crate) fn h2d_stream(&self) -> &Arc<CudaStream> {
&self.h2d_stream
}
/// Hands out the next D2H stream, cycling through the pool round-robin.
///
/// A shared atomic counter is bumped on every call and reduced modulo the
/// pool size to pick the slot, so concurrent callers are spread across streams.
pub(crate) fn next_d2h_streams(&self) -> Arc<CudaStream> {
    let pool = &self.d2h_streams;
    let ticket = self.current_d2h_stream.fetch_add(1, Ordering::Relaxed);
    Arc::clone(&pool[ticket % pool.len()])
}
/// Hands out the next H2D stream, cycling through the pool round-robin.
///
/// A shared atomic counter is bumped on every call and reduced modulo the
/// pool size to pick the slot, so concurrent callers are spread across streams.
pub(crate) fn next_h2d_streams(&self) -> Arc<CudaStream> {
    let pool = &self.h2d_streams;
    let ticket = self.current_h2d_stream.fetch_add(1, Ordering::Relaxed);
    Arc::clone(&pool[ticket % pool.len()])
}
/// Acquire an H2D stream for use by caller.
///
/// This returns a stream from the pool that the caller can use for multiple
/// sequential operations. The caller is responsible for all synchronization
/// (e.g., recording events after operations).
///
/// Used for layer-wise transfers where all layers must execute on the same stream.
///
/// Each call delegates to `next_h2d_streams` and therefore advances the pool's
/// round-robin cursor: keep the returned `Arc` and reuse it for every operation
/// that must share a stream instead of calling this method again.
pub fn acquire_h2d_stream(&self) -> Arc<CudaStream> {
self.next_h2d_streams()
}
/// Acquire a D2H stream for use by caller.
///
/// This returns a stream from the pool that the caller can use for multiple
/// sequential operations. The caller is responsible for all synchronization
/// (e.g., recording events after operations).
///
/// Used for layer-wise transfers where all layers must execute on the same stream.
///
/// Each call delegates to `next_d2h_streams` and therefore advances the pool's
/// round-robin cursor: keep the returned `Arc` and reuse it for every operation
/// that must share a stream instead of calling this method again.
pub fn acquire_d2h_stream(&self) -> Arc<CudaStream> {
self.next_d2h_streams()
}
/// Returns a handle to the Tokio runtime stored in this context.
///
/// The constructor spawns the background notification handlers on this same
/// runtime handle.
#[allow(dead_code)]
#[doc(hidden)]
pub fn tokio(&self) -> &tokio::runtime::Handle {
self.tokio_runtime.handle()
}
/// Returns the transfer capabilities stored in this context.
pub(crate) fn capabilities(&self) -> &TransferCapabilities {
&self.capabilities
}
/// Returns the event manager this context uses to allocate completion events
/// and their awaiters (see the `register_*` methods).
#[doc(hidden)]
pub fn event_system(&self) -> &Arc<EventManager> {
&self.event_system
}
/// Get the CUDA memory pool for kernel allocations.
///
/// The pool is built once in the constructor from the configured reserve size
/// and optional release threshold, and is shared behind an `Arc`.
pub(crate) fn cuda_pool(&self) -> &Arc<CudaMemPool> {
&self.cuda_pool
}
/// Register a NIXL transfer request for status polling completion.
///
/// This method enqueues the transfer request to be polled for completion
/// using `agent.get_xfer_status()`. Returns a notification object that
/// can be awaited for completion.
pub(crate) fn register_nixl_status(
&self,
xfer_req: XferRequest,
) -> TransferCompleteNotification {
let event = self
.event_system
.new_event()
.expect("Failed to allocate event");
let handle = event.into_handle();
let awaiter = self
.event_system
.awaiter(handle)
.expect("Failed to get awaiter");
let notification = notifications::RegisterPollingNotification {
uuid: Uuid::new_v4(),
checker: notifications::NixlStatusChecker::new(
self.nixl_agent.raw_agent().clone(),
xfer_req,
),
event_handle: handle,
};
// Send to background handler — log error if channel is full or closed
if let Err(e) = self.tx_nixl_status.try_send(notification) {
tracing::error!(
"Failed to enqueue NIXL status notification: channel full or closed: {}",
e
);
}
TransferCompleteNotification::from_awaiter(awaiter)
}
/// Register a CUDA event for polling completion.
///
/// This method enqueues the CUDA event to be polled for completion.
/// Returns a notification object that can be awaited for completion.
pub(crate) fn register_cuda_event(&self, event: CudaEvent) -> TransferCompleteNotification {
let new_event = self
.event_system
.new_event()
.expect("Failed to allocate event");
let handle = new_event.into_handle();
let awaiter = self
.event_system
.awaiter(handle)
.expect("Failed to get awaiter");
let notification = notifications::RegisterPollingNotification {
uuid: Uuid::new_v4(),
checker: notifications::CudaEventChecker::new(event),
event_handle: handle,
};
// Send to background handler — log error if channel is full or closed
if let Err(e) = self.tx_cuda_event.try_send(notification) {
tracing::error!(
"Failed to enqueue CUDA event notification: channel full or closed: {}",
e
);
}
TransferCompleteNotification::from_awaiter(awaiter)
}
/// Register a NIXL transfer request for notification-based completion.
///
/// This method enqueues the transfer request to be completed via NIXL
/// notification events. Returns a notification object that can be awaited
/// for completion.
#[allow(dead_code)]
pub(crate) fn register_nixl_event(
&self,
xfer_req: XferRequest,
) -> TransferCompleteNotification {
let event = self
.event_system
.new_event()
.expect("Failed to allocate event");
let handle = event.into_handle();
let awaiter = self
.event_system
.awaiter(handle)
.expect("Failed to get awaiter");
let notification = notifications::RegisterNixlNotification {
uuid: Uuid::new_v4(),
xfer_req,
event_handle: handle,
};
// Send to background handler — log error if channel is full or closed
if let Err(e) = self.tx_nixl_events.try_send(notification) {
tracing::error!(
"Failed to enqueue NIXL event notification: channel full or closed: {}",
e
);
}
TransferCompleteNotification::from_awaiter(awaiter)
}
/// Get the worker ID for this context.
///
/// The value is the event system's `system_id()`, captured once when the
/// context was constructed.
pub(crate) fn worker_id(&self) -> u64 {
self.worker_id
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA executor for GPU memory transfers.
use super::TransferContext;
use super::{PhysicalLayout, TransferStrategy};
use crate::BlockId;
use crate::transfer::context::TransferCompleteNotification;
use crate::transfer::{can_use_whole_block_transfer, validate_layout_compatibility};
use anyhow::{Result, anyhow};
use cudarc::driver::{CudaStream, result as cuda_result};
use cudarc::runtime::sys::cudaStream_t;
use dynamo_memory::CudaMemPool;
use kvbm_kernels::MemcpyBatchMode;
use std::ffi::c_void;
use std::ops::Range;
use std::sync::Arc;
// #[cfg(test)]
// mod cuda_kernel_tests;
/// Execute a CUDA transfer between host and device memory.
///
/// This executor handles transfers involving GPU memory using CUDA APIs.
/// Only the asynchronous CUDA strategies (`CudaAsyncH2D`, `CudaAsyncD2H`,
/// `CudaAsyncD2D`) are accepted; any other strategy is rejected with an error
/// before a stream is taken from the pool.
///
/// When both layouts allow it, whole blocks are copied with a batched memcpy;
/// otherwise the per-(block, layer, outer) chunks are copied by the
/// `vectorized_copy` kernel.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs to transfer (same length as `src_block_ids`)
/// * `layer_range` - Optional range of layers to transfer (None = all layers)
/// * `strategy` - CUDA transfer strategy (async H2D, D2H or D2D)
/// * `cuda_stream` - Optional caller-provided stream. If provided, use this stream
///   and skip event recording (caller manages sync). Returns completed() immediately.
/// * `ctx` - Transfer context providing the stream pools, the CUDA memory pool
///   and completion-event registration
///
/// # Errors
/// Returns an error when the block ID slices differ in length, when the layouts
/// disagree on layer count or outer dimension or fail the compatibility check,
/// when `strategy` is not a CUDA async strategy, or when a CUDA operation fails.
#[allow(clippy::too_many_arguments)]
pub fn execute_cuda_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    layer_range: Option<Range<usize>>,
    strategy: TransferStrategy,
    cuda_stream: Option<Arc<CudaStream>>,
    ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
    // The copy paths size their pointer tables from the source slice, so a
    // shorter destination slice must never get that far.
    if src_block_ids.len() != dst_block_ids.len() {
        return Err(anyhow!(
            "Block ID slice length mismatch: src={}, dst={}",
            src_block_ids.len(),
            dst_block_ids.len()
        ));
    }

    // Validate layouts
    let src_layout = src.layout();
    let dst_layout = dst.layout();
    if src_layout.num_layers() != dst_layout.num_layers() {
        return Err(anyhow!(
            "Layouts have incompatible layer counts: src={}, dst={}",
            src_layout.num_layers(),
            dst_layout.num_layers()
        ));
    }
    if src_layout.outer_dim() != dst_layout.outer_dim() {
        return Err(anyhow!(
            "Layouts have incompatible outer dimensions: src={}, dst={}",
            src_layout.outer_dim(),
            dst_layout.outer_dim()
        ));
    }

    // Validate layout compatibility (errors if transform would be needed)
    validate_layout_compatibility(src, dst)?;

    // Direction name for logging; doubles as the strategy gate so that an
    // unsupported strategy fails before a pooled stream is consumed.
    let strategy_name = match strategy {
        TransferStrategy::CudaAsyncH2D => "H2D",
        TransferStrategy::CudaAsyncD2H => "D2H",
        TransferStrategy::CudaAsyncD2D => "D2D",
        _ => {
            return Err(anyhow!("Invalid CUDA transfer strategy: {:?}", strategy));
        }
    };

    // Determine layer range
    let layers = layer_range.clone().unwrap_or(0..src_layout.num_layers());

    // Check if we can use optimized whole-block transfer
    let use_whole_block = can_use_whole_block_transfer(src, dst, layer_range.as_ref());

    // Track whether caller provided stream (affects event recording)
    let caller_manages_sync = cuda_stream.is_some();

    // Use the caller-provided stream, or take the next one from the pool
    let stream = match cuda_stream {
        Some(s) => s,
        None => match strategy {
            TransferStrategy::CudaAsyncD2H => ctx.next_d2h_streams(),
            _ => ctx.next_h2d_streams(), // H2D and D2D use the h2d pool
        },
    };

    if use_whole_block {
        // FC→FC: Use unified whole-block path with batched memcpy
        // Direction auto-detected by cudaMemcpyDefault
        tracing::debug!(
            strategy = strategy_name,
            num_blocks = src_block_ids.len(),
            bytes_per_block = src_layout.config().bytes_per_block(),
            "Using whole-block transfer (auto direction)"
        );
        execute_whole_block_cuda(src, dst, src_block_ids, dst_block_ids, stream.as_ref())?;
    } else {
        // FC↔LW: Use vectorized_copy kernel directly
        tracing::debug!(
            strategy = strategy_name,
            num_blocks = src_block_ids.len(),
            num_layers = layers.len(),
            "Using vectorized_copy for FC↔LW transfer"
        );
        execute_fc_lw_vectorized(
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layers,
            stream.as_ref(),
            ctx.cuda_pool(),
        )?;
    }

    // If caller provided the stream, they manage synchronization - return completed immediately
    if caller_manages_sync {
        return Ok(TransferCompleteNotification::completed());
    }

    // Every accepted strategy is asynchronous: record an event behind the
    // queued work and let the context poll it for completion.
    let event = stream.record_event(None)?;
    Ok(ctx.register_cuda_event(event))
}
// ============================================================================
// Whole-Block Transfer Functions (FC→FC optimization)
// ============================================================================
/// Unified whole-block transfer using batched memcpy.
///
/// NO device pointer allocation needed. Direction is auto-detected by CUDA
/// from pointer types using cudaMemcpyDefault.
///
/// Uses cudaMemcpyBatchAsync when available (CUDA 12.9+), falling back to
/// individual cudaMemcpyAsync calls on older CUDA versions.
///
/// Each block is addressed through `memory_region(block, 0, 0)` and
/// `bytes_per_block` of the source layout is copied from that address.
///
/// # Errors
/// Returns an error when the block ID slices differ in length, when a block's
/// memory region cannot be resolved, or when `memcpy_batch` reports failure.
fn execute_whole_block_cuda(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    stream: &cudarc::driver::CudaStream,
) -> Result<()> {
    // `num_blocks` is passed to the batch copy as the length of BOTH pointer
    // arrays. They are filled by zipping the two slices, so a length mismatch
    // would make the copy read past the end of the shorter array.
    if src_block_ids.len() != dst_block_ids.len() {
        return Err(anyhow!(
            "Block ID slice length mismatch: src={}, dst={}",
            src_block_ids.len(),
            dst_block_ids.len()
        ));
    }

    let bytes_per_block = src.layout().config().bytes_per_block();
    let num_blocks = src_block_ids.len();
    if num_blocks == 0 {
        return Ok(());
    }

    // Build host pointer arrays
    let mut src_ptrs: Vec<*const std::ffi::c_void> = Vec::with_capacity(num_blocks);
    let mut dst_ptrs: Vec<*mut std::ffi::c_void> = Vec::with_capacity(num_blocks);
    for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
        let src_region = src.memory_region(src_block_id, 0, 0)?;
        let dst_region = dst.memory_region(dst_block_id, 0, 0)?;
        src_ptrs.push(src_region.addr() as *const std::ffi::c_void);
        dst_ptrs.push(dst_region.addr() as *mut std::ffi::c_void);
    }

    // Use batched memcpy - handles CUDA 12.9+ batch API with automatic fallback
    // SAFETY: both arrays hold exactly `num_blocks` entries (guarded above).
    // Assumes every entry addresses at least `bytes_per_block` valid bytes —
    // that is the layouts' contract and cannot be checked here.
    let status = unsafe {
        kvbm_kernels::memcpy_batch(
            src_ptrs.as_ptr(),
            dst_ptrs.as_ptr(),
            bytes_per_block,
            num_blocks,
            MemcpyBatchMode::BatchedWithFallback,
            stream.cu_stream() as cudarc::runtime::sys::cudaStream_t,
        )
    };
    if status != cudarc::runtime::sys::cudaError::cudaSuccess {
        return Err(anyhow!("memcpy_batch failed: {:?}", status));
    }

    tracing::debug!(
        num_blocks,
        bytes_per_block,
        batch_available = kvbm_kernels::is_memcpy_batch_available(),
        "Whole-block transfer completed"
    );
    Ok(())
}
// ============================================================================
// FC↔LW Transfer using vectorized_copy kernel
// ============================================================================
/// Execute FC↔LW transfer using vectorized_copy kernel.
///
/// This function builds flat (src, dst) pointer arrays for all chunks across all blocks,
/// uploads them to device memory, and calls the vectorized_copy kernel directly.
/// One chunk is one (block, layer, outer) region of
/// `page_size * inner_dim * dtype_width_bytes` bytes.
///
/// Benefits over the old operational_copy approach:
/// - Simpler: One kernel, no backend selection logic
/// - Faster: 16-byte (int4) loads when aligned (vs 8-byte in operational_copy_vectorized)
/// - All offset math on host: Kernel just copies bytes
/// - Handles any alignment: Falls back gracefully to 8/4/1-byte copies
///
/// The two device-side pointer tables are returned to the pool on every exit
/// path once they have been allocated, and the host blocks until the pointer
/// upload has finished before the host arrays are dropped.
///
/// # Errors
/// Returns an error when the block ID slices differ in length, when the chunk
/// count does not fit the kernel's `i32` count argument, when a memory region
/// cannot be resolved, or when any CUDA call fails.
fn execute_fc_lw_vectorized(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    layers: Range<usize>,
    stream: &CudaStream,
    pool: &CudaMemPool,
) -> Result<()> {
    // The pointer tables are sized from the source slice but filled by zipping
    // both slices; a mismatch would upload uninitialised capacity.
    if src_block_ids.len() != dst_block_ids.len() {
        return Err(anyhow!(
            "Block ID slice length mismatch: src={}, dst={}",
            src_block_ids.len(),
            dst_block_ids.len()
        ));
    }

    // Bind CUDA context to current thread before any CUDA operations.
    stream.context().bind_to_thread()?;

    let src_layout = src.layout();
    let nl = layers.len();
    let no = src_layout.outer_dim();
    let chunk_size =
        src_layout.page_size() * src_layout.inner_dim() * src_layout.dtype_width_bytes();
    let num_blocks = src_block_ids.len();
    let total_chunks = num_blocks * nl * no;
    if total_chunks == 0 {
        return Ok(());
    }

    // The kernel takes its chunk count as i32; refuse rather than truncate.
    let kernel_chunks = i32::try_from(total_chunks).map_err(|_| {
        anyhow!(
            "Too many chunks for vectorized_copy: {} exceeds i32::MAX",
            total_chunks
        )
    })?;

    // Build flat pointer arrays on host
    let mut src_ptrs: Vec<usize> = Vec::with_capacity(total_chunks);
    let mut dst_ptrs: Vec<usize> = Vec::with_capacity(total_chunks);
    for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
        for layer_id in layers.clone() {
            for outer_id in 0..no {
                let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
                let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
                src_ptrs.push(src_region.addr());
                dst_ptrs.push(dst_region.addr());
            }
        }
    }
    debug_assert_eq!(src_ptrs.len(), total_chunks);
    debug_assert_eq!(dst_ptrs.len(), total_chunks);

    let ptr_bytes = total_chunks * std::mem::size_of::<usize>();

    // Allocate device memory for pointer arrays; if the second allocation
    // fails the first one must not leak.
    let src_ptrs_device = pool.alloc_async(ptr_bytes, stream)?;
    let dst_ptrs_device = match pool.alloc_async(ptr_bytes, stream) {
        Ok(ptr) => ptr,
        Err(e) => {
            if let Err(free_err) = pool.free_async(src_ptrs_device, stream) {
                tracing::warn!(
                    "Failed to release pointer table after allocation error: {:?}",
                    free_err
                );
            }
            return Err(e.into());
        }
    };

    // Everything that can fail between allocation and release lives in this
    // closure so the frees below run no matter how it exits.
    let upload_and_launch = || -> Result<()> {
        // Upload pointer arrays to device
        // SAFETY: each Vec holds `total_chunks` initialised `usize` values, so
        // viewing it as `ptr_bytes` bytes stays in bounds; each device buffer
        // was allocated with `ptr_bytes` bytes above.
        unsafe {
            cuda_result::memcpy_htod_async(
                src_ptrs_device,
                std::slice::from_raw_parts(src_ptrs.as_ptr() as *const u8, ptr_bytes),
                stream.cu_stream(),
            )?;
            cuda_result::memcpy_htod_async(
                dst_ptrs_device,
                std::slice::from_raw_parts(dst_ptrs.as_ptr() as *const u8, ptr_bytes),
                stream.cu_stream(),
            )?;
        }
        let pointers_transferred_event = stream.record_event(None)?;

        // Call vectorized_copy kernel
        // SAFETY: both device tables contain `kernel_chunks` pointers once the
        // uploads queued ahead of the kernel on this stream have run. Assumes
        // each pointer addresses `chunk_size` valid bytes (layout contract).
        let status = unsafe {
            kvbm_kernels::vectorized_copy(
                src_ptrs_device as *mut *mut c_void,
                dst_ptrs_device as *mut *mut c_void,
                chunk_size,
                kernel_chunks,
                stream.cu_stream() as cudaStream_t,
            )
        };

        // Wait for the uploads before the host arrays can be dropped — also
        // when the launch failed, since the copies are already queued.
        let upload_done = pointers_transferred_event.synchronize();
        if status != cudarc::runtime::sys::cudaError::cudaSuccess {
            return Err(anyhow!("vectorized_copy failed: {:?}", status));
        }
        upload_done?;
        Ok(())
    };
    let outcome = upload_and_launch();

    // Free device allocations back to the pool (attempt both before reporting).
    let free_src = pool.free_async(src_ptrs_device, stream);
    let free_dst = pool.free_async(dst_ptrs_device, stream);
    outcome?;
    free_src?;
    free_dst?;

    tracing::debug!(
        total_chunks,
        chunk_size,
        "FC↔LW vectorized_copy transfer completed"
    );
    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Memcpy executor for host-to-host transfers.
use crate::BlockId;
use crate::transfer::PhysicalLayout;
use crate::transfer::TransferContext;
use crate::transfer::context::TransferCompleteNotification;
use crate::transfer::{can_use_whole_block_transfer, validate_layout_compatibility};
use anyhow::Result;
use std::ops::Range;
/// Copy blocks between two host-resident layouts with plain CPU memcpy.
///
/// Handles transfers between System and Pinned memory. The copy is synchronous:
/// by the time this function returns the data has been written, so the returned
/// notification is already completed.
///
/// When `can_use_whole_block_transfer` approves the pair (FC→FC, compatible
/// layouts, full blocks) each block is moved with a single memcpy; otherwise
/// the copy proceeds region by region over the selected layers.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs to transfer
/// * `layer_range` - Optional range of layers to transfer (None = all layers)
/// * `_ctx` - Transfer context (unused for memcpy, kept for API consistency)
///
/// # Errors
/// Fails when the ID slices differ in length, when the layouts disagree on
/// layer count or outer dimension, when the compatibility check rejects the
/// pair, or when a memory region cannot be resolved.
pub fn execute_memcpy_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    layer_range: Option<Range<usize>>,
    _ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
    anyhow::ensure!(
        src_block_ids.len() == dst_block_ids.len(),
        "Block ID slice length mismatch: src={}, dst={}",
        src_block_ids.len(),
        dst_block_ids.len()
    );

    // Structural agreement between the two layouts
    let (src_layout, dst_layout) = (src.layout(), dst.layout());
    anyhow::ensure!(
        src_layout.num_layers() == dst_layout.num_layers(),
        "Layouts have incompatible layer counts: src={}, dst={}",
        src_layout.num_layers(),
        dst_layout.num_layers()
    );
    anyhow::ensure!(
        src_layout.outer_dim() == dst_layout.outer_dim(),
        "Layouts have incompatible outer dimensions: src={}, dst={}",
        src_layout.outer_dim(),
        dst_layout.outer_dim()
    );

    // Errors out if the pair would need a layout transform
    validate_layout_compatibility(src, dst)?;

    let layers = match layer_range.clone() {
        Some(range) => range,
        None => 0..src_layout.num_layers(),
    };

    let whole_block = can_use_whole_block_transfer(src, dst, layer_range.as_ref());
    if whole_block {
        tracing::debug!(
            num_blocks = src_block_ids.len(),
            bytes_per_block = src_layout.config().bytes_per_block(),
            "Using whole-block memcpy path"
        );
        execute_whole_block_memcpy(src, dst, src_block_ids, dst_block_ids)?;
    } else {
        tracing::debug!(
            num_blocks = src_block_ids.len(),
            layer_range = ?layers,
            src_fc = src_layout.is_fully_contiguous(),
            dst_fc = dst_layout.is_fully_contiguous(),
            "Using layer-wise memcpy path"
        );
        execute_layer_wise_memcpy(src, dst, src_block_ids, dst_block_ids, layers)?;
    }

    // Nothing is left in flight, so hand back a finished notification
    Ok(TransferCompleteNotification::completed())
}
/// Copy whole blocks between two fully contiguous, compatible layouts.
///
/// Source and destination IDs are paired positionally; for each pair one
/// memcpy of `bytes_per_block` (taken from the source layout) is issued,
/// starting at the address of the block's (layer 0, outer 0) region.
fn execute_whole_block_memcpy(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
) -> Result<()> {
    let block_bytes = src.layout().config().bytes_per_block();
    let id_pairs = src_block_ids
        .iter()
        .copied()
        .zip(dst_block_ids.iter().copied());

    for (from_id, to_id) in id_pairs {
        // In an FC layout the (0, 0) region sits at the start of the block
        let from = src.memory_region(from_id, 0, 0)?;
        let to = dst.memory_region(to_id, 0, 0)?;
        let from_ptr = from.addr() as *const u8;
        let to_ptr = to.addr() as *mut u8;
        // SAFETY: assumes both addresses cover `block_bytes` valid bytes and
        // do not overlap — the caller's layout checks are relied on for this.
        unsafe { std::ptr::copy_nonoverlapping(from_ptr, to_ptr, block_bytes) };
    }
    Ok(())
}
/// Copy blocks one memory region at a time.
///
/// For every positional (src, dst) block pair, every layer in `layers` and
/// every outer index of the source layout, the matching regions are resolved
/// and copied. Serves FC→LW, LW→FC, LW→LW and partial-layer transfers.
///
/// # Errors
/// Fails when a region cannot be resolved or when the source and destination
/// regions of a chunk differ in size.
fn execute_layer_wise_memcpy(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    layers: Range<usize>,
) -> Result<()> {
    let layout = src.layout();
    let id_pairs = src_block_ids
        .iter()
        .copied()
        .zip(dst_block_ids.iter().copied());

    for (from_id, to_id) in id_pairs {
        for layer in layers.clone() {
            for outer in 0..layout.outer_dim() {
                let from = src.memory_region(from_id, layer, outer)?;
                let to = dst.memory_region(to_id, layer, outer)?;

                // Both sides must describe the same number of bytes
                anyhow::ensure!(
                    from.size() == to.size(),
                    "Memory region size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
                    from_id,
                    to_id,
                    layer,
                    outer,
                    from.size(),
                    to.size()
                );

                // SAFETY: assumes the regions returned by the layouts are
                // valid for `size()` bytes and do not overlap.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        from.addr() as *const u8,
                        to.addr() as *mut u8,
                        from.size(),
                    );
                }
            }
        }
    }
    Ok(())
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer executors for different copy strategies.
pub(super) mod cuda;
mod memcpy;
mod nixl;
use super::strategy::select_strategy;
use super::strategy::{TransferPlan, TransferStrategy};
use super::validation::validate_block_transfer;
use super::{PhysicalLayout, TransferContext};
use crate::BlockId;
use crate::layout::KvBlockLayout;
use crate::transfer::BounceBufferInternal;
use crate::transfer::{StorageKind, context::TransferCompleteNotification};
use anyhow::Result;
use cudarc::driver::CudaStream;
use std::ops::Range;
use std::sync::Arc;
use tokio::sync::Mutex;
// Re-export the NIXL transfer builder for public use
pub use nixl::NixlTransferBuilder;
/// Transformation kernel types for converting between different block layouts.
///
/// Returned by `select_transform_kernel` to describe which conversion, if any,
/// a (source, destination) block-layout pair calls for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TransformKernel {
/// No transformation needed - layouts are compatible, use copy
None,
/// Transform from operational (NHD/HND) to universal format;
/// `src_layout` carries the operational layout of the source
BlockToUniversal { src_layout: KvBlockLayout },
/// Transform from universal to operational (NHD/HND) format;
/// `dst_layout` carries the operational layout of the destination
UniversalToBlock { dst_layout: KvBlockLayout },
/// Transpose between operational formats (NHD <-> HND)
OperationalTranspose,
/// Layouts are incompatible and no kernel is available
Unsupported,
}
/// Select the appropriate transformation kernel based on source and destination layouts.
///
/// Returns `TransformKernel::None` if the layouts are the same (copy is sufficient).
/// Returns `TransformKernel::Unsupported` if the layout combination is not supported.
#[allow(dead_code)]
pub(crate) fn select_transform_kernel(
src_layout: KvBlockLayout,
dst_layout: KvBlockLayout,
) -> TransformKernel {
// Same layout - no transformation needed
if !src_layout.requires_transform(&dst_layout) {
return TransformKernel::None;
}
// Unknown layouts cannot be transformed
if matches!(src_layout, KvBlockLayout::Unknown) || matches!(dst_layout, KvBlockLayout::Unknown)
{
return TransformKernel::Unsupported;
}
match (src_layout, dst_layout) {
// Operational to Universal
(KvBlockLayout::OperationalNHD, KvBlockLayout::UniversalTP)
| (KvBlockLayout::OperationalNHD, KvBlockLayout::UniversalPP)
| (KvBlockLayout::OperationalHND, KvBlockLayout::UniversalTP)
| (KvBlockLayout::OperationalHND, KvBlockLayout::UniversalPP) => {
TransformKernel::BlockToUniversal { src_layout }
}
// Universal to Operational
(KvBlockLayout::UniversalTP, KvBlockLayout::OperationalNHD)
| (KvBlockLayout::UniversalTP, KvBlockLayout::OperationalHND)
| (KvBlockLayout::UniversalPP, KvBlockLayout::OperationalNHD)
| (KvBlockLayout::UniversalPP, KvBlockLayout::OperationalHND) => {
TransformKernel::UniversalToBlock { dst_layout }
}
// Operational NHD <-> HND transpose
(KvBlockLayout::OperationalNHD, KvBlockLayout::OperationalHND)
| (KvBlockLayout::OperationalHND, KvBlockLayout::OperationalNHD) => {
TransformKernel::OperationalTranspose
}
// Custom layouts need explicit handling
(KvBlockLayout::Custom(_), _) | (_, KvBlockLayout::Custom(_)) => {
TransformKernel::Unsupported
}
// Universal to Universal (different variants)
(KvBlockLayout::UniversalTP, KvBlockLayout::UniversalPP)
| (KvBlockLayout::UniversalPP, KvBlockLayout::UniversalTP) => {
// TODO: Add direct universal-to-universal kernel
TransformKernel::Unsupported
}
// Fallback for any unhandled combinations
_ => TransformKernel::Unsupported,
}
}
/// Resolve the block layout to assume for the source.
///
/// An explicit override wins; without one the layout's own `block_layout()`
/// is queried.
#[expect(dead_code)]
pub(crate) fn effective_src_layout(
    src: &PhysicalLayout,
    override_layout: Option<KvBlockLayout>,
) -> KvBlockLayout {
    match override_layout {
        Some(forced) => forced,
        None => src.layout().block_layout(),
    }
}
/// Resolve the block layout to assume for the destination.
///
/// An explicit override wins; without one the layout's own `block_layout()`
/// is queried.
#[expect(dead_code)]
pub(crate) fn effective_dst_layout(
    dst: &PhysicalLayout,
    override_layout: Option<KvBlockLayout>,
) -> KvBlockLayout {
    match override_layout {
        Some(forced) => forced,
        None => dst.layout().block_layout(),
    }
}
/// Options accepted by `execute_transfer`; assembled through
/// `TransferOptionsInternalBuilder`. Every field defaults to `None`.
#[derive(Default)]
#[expect(dead_code)]
pub(crate) struct TransferOptionsInternal {
/// Layers to transfer. Forwarded to the executors, where `None` means all layers.
layer_range: Option<Range<usize>>,
/// NOTE(review): presumably a notification value attached to NIXL writes —
/// not read anywhere in the visible code; confirm against the NIXL executor.
nixl_write_notification: Option<u64>,
/// NOTE(review): presumably the staging buffer used by two-hop transfers —
/// confirm against `execute_two_hop_transfer`.
bounce_buffer: Option<BounceBufferInternal>,
/// If provided, use this stream instead of acquiring from pool.
/// Caller manages synchronization - no event is recorded by the executor.
pub(crate) cuda_stream: Option<Arc<CudaStream>>,
/// Override source block layout interpretation.
/// If None, uses the layout's block_layout() method.
pub(crate) src_kv_layout: Option<KvBlockLayout>,
/// Override destination block layout interpretation.
/// If None, uses the layout's block_layout() method.
pub(crate) dst_kv_layout: Option<KvBlockLayout>,
}
impl TransferOptionsInternal {
pub(crate) fn builder() -> TransferOptionsInternalBuilder {
TransferOptionsInternalBuilder::default()
}
}
/// Builder for `TransferOptionsInternal`.
///
/// Mirrors the option struct field for field; each field starts as `None` and
/// is filled in by the setter method of the same name.
#[derive(Default)]
pub(crate) struct TransferOptionsInternalBuilder {
// Pending value for `TransferOptionsInternal::layer_range`
layer_range: Option<Range<usize>>,
// Pending value for `TransferOptionsInternal::nixl_write_notification`
nixl_write_notification: Option<u64>,
// Pending value for `TransferOptionsInternal::bounce_buffer`
bounce_buffer: Option<BounceBufferInternal>,
// Pending value for `TransferOptionsInternal::cuda_stream`
cuda_stream: Option<Arc<CudaStream>>,
// Pending value for `TransferOptionsInternal::src_kv_layout`
src_kv_layout: Option<KvBlockLayout>,
// Pending value for `TransferOptionsInternal::dst_kv_layout`
dst_kv_layout: Option<KvBlockLayout>,
}
impl TransferOptionsInternalBuilder {
    /// Restrict the transfer to the given range of layers.
    pub(crate) fn layer_range(self, range: Range<usize>) -> Self {
        Self {
            layer_range: Some(range),
            ..self
        }
    }

    /// Record the NIXL write notification value to carry in the options.
    pub(crate) fn nixl_write_notification(self, notification: u64) -> Self {
        Self {
            nixl_write_notification: Some(notification),
            ..self
        }
    }

    /// Record the bounce buffer to carry in the options.
    pub(crate) fn bounce_buffer(self, bounce_buffer: BounceBufferInternal) -> Self {
        Self {
            bounce_buffer: Some(bounce_buffer),
            ..self
        }
    }

    /// Pin the transfer to a specific CUDA stream.
    ///
    /// With a stream supplied, the executor uses it instead of taking one from
    /// the pool, and synchronization becomes the caller's job - the executor
    /// records no event.
    ///
    /// Intended for layer-wise transfers whose layers must all run on the same
    /// stream so that events can be sequenced properly.
    pub(crate) fn cuda_stream(self, stream: Arc<CudaStream>) -> Self {
        Self {
            cuda_stream: Some(stream),
            ..self
        }
    }

    /// Override how source blocks are interpreted.
    ///
    /// When set, source blocks are treated as having this layout rather than
    /// the layout's default block_layout(), which allows blocks stored in one
    /// format to be read as another.
    pub(crate) fn src_kv_layout(self, layout: KvBlockLayout) -> Self {
        Self {
            src_kv_layout: Some(layout),
            ..self
        }
    }

    /// Override how destination blocks are interpreted.
    ///
    /// When set, destination blocks are treated as having this layout rather
    /// than the layout's default block_layout(), which allows writing blocks in
    /// a format other than the destination's native one.
    pub(crate) fn dst_kv_layout(self, layout: KvBlockLayout) -> Self {
        Self {
            dst_kv_layout: Some(layout),
            ..self
        }
    }

    /// Finish building. Always succeeds; the `Result` is part of the signature only.
    pub(crate) fn build(self) -> Result<TransferOptionsInternal> {
        let Self {
            layer_range,
            nixl_write_notification,
            bounce_buffer,
            cuda_stream,
            src_kv_layout,
            dst_kv_layout,
        } = self;
        Ok(TransferOptionsInternal {
            layer_range,
            nixl_write_notification,
            bounce_buffer,
            cuda_stream,
            src_kv_layout,
            dst_kv_layout,
        })
    }
}
/// Execute a transfer between two physical layouts.
///
/// This is an internal entry point for all transfer operations called by TransportManager.
/// It selects the appropriate strategy and dispatches to the corresponding executor.
///
/// # Arguments
/// * `src` - Source physical layout
/// * `dst` - Destination physical layout
/// * `src_block_ids` - Source block IDs to transfer
/// * `dst_block_ids` - Destination block IDs to transfer
/// * `options` - Transfer options
/// * `ctx` - Transfer context with CUDA stream and NIXL agent
pub(crate) fn execute_transfer(
src: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[BlockId],
dst_block_ids: &[BlockId],
options: TransferOptionsInternal,
ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
// Validate block IDs
validate_block_transfer(src_block_ids, dst_block_ids, None, src, dst, None)?;
// Select transfer plan based on locations and capabilities
let plan = select_strategy(src, dst, ctx)?;
// Dispatch based on plan type
match plan {
TransferPlan::Direct(strategy) => execute_direct_transfer(
src,
dst,
src_block_ids,
dst_block_ids,
options.layer_range,
strategy,
options.cuda_stream,
ctx,
),
TransferPlan::TwoHop {
first,
bounce_location,
second,
} => execute_two_hop_transfer(TwoHopTransferParams {
src,
dst,
src_block_ids,
dst_block_ids,
first_strategy: first,
bounce_location,
second_strategy: second,
options,
ctx,
}),
}
}
/// Run one hop of a transfer with an already-selected strategy.
///
/// Dispatches to the memcpy, CUDA or NIXL executor. A caller-provided
/// `cuda_stream` is only accepted by the CUDA strategies; combining it with
/// Memcpy or any NIXL strategy is an error, as is `TransferStrategy::Invalid`.
#[allow(clippy::too_many_arguments)]
fn execute_direct_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    src_block_ids: &[BlockId],
    dst_block_ids: &[BlockId],
    layer_range: Option<Range<usize>>,
    strategy: TransferStrategy,
    cuda_stream: Option<Arc<CudaStream>>,
    ctx: &TransferContext,
) -> Result<TransferCompleteNotification> {
    let has_caller_stream = cuda_stream.is_some();

    match strategy {
        TransferStrategy::Memcpy if has_caller_stream => Err(anyhow::anyhow!(
            "cuda_stream option is not supported for Memcpy strategy"
        )),
        TransferStrategy::Memcpy => memcpy::execute_memcpy_transfer(
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layer_range,
            ctx,
        ),

        TransferStrategy::CudaAsyncH2D
        | TransferStrategy::CudaAsyncD2H
        | TransferStrategy::CudaAsyncD2D => cuda::execute_cuda_transfer(
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layer_range,
            strategy,
            cuda_stream,
            ctx,
        ),

        TransferStrategy::NixlRead
        | TransferStrategy::NixlWrite
        | TransferStrategy::NixlReadFlipped
        | TransferStrategy::NixlWriteFlipped
            if has_caller_stream =>
        {
            Err(anyhow::anyhow!(
                "cuda_stream option is not supported for NIXL strategies"
            ))
        }
        TransferStrategy::NixlRead
        | TransferStrategy::NixlWrite
        | TransferStrategy::NixlReadFlipped
        | TransferStrategy::NixlWriteFlipped => {
            let base = NixlTransferBuilder::new()
                .src(src)
                .dst(dst)
                .src_blocks(src_block_ids)
                .dst_blocks(dst_block_ids)
                .strategy(strategy);
            // The layer range is optional on the builder as well
            let configured = match layer_range {
                Some(range) => base.layer_range(range),
                None => base,
            };
            configured.execute(ctx)
        }

        TransferStrategy::Invalid => Err(anyhow::anyhow!(
            "Invalid transfer strategy for src={:?}, dst={:?}",
            src.location(),
            dst.location()
        )),
    }
}
/// Work-stealing bounce buffer transfer using two parallel tasks.
///
/// This function implements a work-stealing approach where two tasks each take
/// batches from a shared iterator and execute complete two-hop transfers.
/// This is simpler to maintain than double-buffering while still providing
/// good throughput through task parallelism.
///
/// # Algorithm
/// 1. Split bounce buffer into two groups (group 0 and group 1)
/// 2. Create a shared iterator over (src_block_id, dst_block_id) pairs
/// 3. Two parallel tasks each:
/// - Lock the iterator, take a batch of pairs
/// - Execute the complete two-hop transfer for that batch
/// - Repeat until iterator is exhausted
#[allow(clippy::too_many_arguments)]
async fn handle_buffered_transfer(
src: &PhysicalLayout,
bounce_layout: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[BlockId],
bounce_block_ids: &[BlockId],
dst_block_ids: &[BlockId],
first_strategy: TransferStrategy,
second_strategy: TransferStrategy,
layer_range: &Option<Range<usize>>,
ctx: &TransferContext,
) -> Result<()> {
let bounce_groups =
&bounce_block_ids[0..std::cmp::min(src_block_ids.len(), bounce_block_ids.len())];
let (bounce_group_0, bounce_group_1) = bounce_groups.split_at(bounce_groups.len() / 2);
let bounce_group_0 = bounce_group_0.to_vec();
let bounce_group_1 = bounce_group_1.to_vec();
let src_dst_iter = Arc::new(Mutex::new(src_block_ids.iter().zip(dst_block_ids.iter())));
let transfer_task = async move |bounce_group: &[BlockId]| -> Result<()> {
loop {
let (src_ids, dst_ids): (Vec<BlockId>, Vec<BlockId>);
{
let mut x = src_dst_iter.lock().await;
(src_ids, dst_ids) = x
.by_ref()
.take(bounce_group.len())
.map(|(&s, &d)| (s, d))
.unzip();
if src_ids.is_empty() {
break;
}
}
execute_two_hop_transfer_chunk(
src,
bounce_layout,
dst,
&src_ids,
&bounce_group[0..src_ids.len()],
&dst_ids,
first_strategy,
second_strategy,
layer_range,
ctx,
)
.await?;
}
Ok(())
};
let transfer_0 = transfer_task(&bounce_group_0);
let transfer_1 = transfer_task(&bounce_group_1);
futures::future::try_join(transfer_0, transfer_1).await?;
Ok(())
}
/// Execute a single chunk of a two-hop transfer sequentially.
///
/// Used when bounce buffer has only a single block or as a fallback.
/// Performs src→bounce followed by bounce→dst sequentially.
#[allow(clippy::too_many_arguments)]
async fn execute_two_hop_transfer_chunk(
src: &PhysicalLayout,
bounce_layout: &PhysicalLayout,
dst: &PhysicalLayout,
src_block_ids: &[BlockId],
bounce_block_ids: &[BlockId],
dst_block_ids: &[BlockId],
first_strategy: TransferStrategy,
second_strategy: TransferStrategy,
layer_range: &Option<Range<usize>>,
ctx: &TransferContext,
) -> Result<()> {
let bounce_ids_to_use = &bounce_block_ids[..src_block_ids.len()];
execute_direct_transfer(
src,
bounce_layout,
src_block_ids,
bounce_ids_to_use,
layer_range.clone(),
first_strategy,
None, // Two-hop transfers don't support caller-provided streams
ctx,
)?
.await?;
execute_direct_transfer(
bounce_layout,
dst,
bounce_ids_to_use,
dst_block_ids,
layer_range.clone(),
second_strategy,
None, // Two-hop transfers don't support caller-provided streams
ctx,
)?
.await?;
Ok(())
}
/// Parameters for two-hop transfer execution
///
/// Bundles everything `execute_two_hop_transfer` needs so the function takes a
/// single argument instead of a long parameter list.
struct TwoHopTransferParams<'a> {
/// Layout the blocks are read from.
src: &'a PhysicalLayout,
/// Layout the blocks are finally written to.
dst: &'a PhysicalLayout,
/// Source block IDs; paired position-by-position with `dst_block_ids`.
src_block_ids: &'a [BlockId],
/// Destination block IDs; paired position-by-position with `src_block_ids`.
dst_block_ids: &'a [BlockId],
/// Strategy used for the first hop (src → bounce buffer).
first_strategy: TransferStrategy,
/// Storage kind the bounce buffer layout must report; a mismatch fails the transfer.
bounce_location: StorageKind,
/// Strategy used for the second hop (bounce buffer → dst).
second_strategy: TransferStrategy,
/// Transfer options; the `bounce_buffer` and `layer_range` fields are read from here.
options: TransferOptionsInternal,
/// Context providing the event system and the tokio runtime the work is spawned on.
ctx: &'a TransferContext,
}
/// Start a two-hop (src → bounce → dst) transfer in the background.
///
/// A completion event is created up front and the actual work is spawned onto
/// the context's tokio runtime. The returned notification resolves once the
/// spawned task triggers the event (success) or poisons it (failure).
///
/// With exactly one bounce block the pairs are processed one at a time through
/// that block; otherwise the work is delegated to [`handle_buffered_transfer`].
///
/// # Errors
/// Returns an error immediately only if the event or its awaiter cannot be
/// created. Failures of the transfer itself are reported through the poisoned
/// event.
fn execute_two_hop_transfer(params: TwoHopTransferParams) -> Result<TransferCompleteNotification> {
    let TwoHopTransferParams {
        src,
        dst,
        src_block_ids,
        dst_block_ids,
        first_strategy,
        bounce_location,
        second_strategy,
        options,
        ctx,
    } = params;

    let event = ctx.event_system().new_event()?;
    let handle = event.into_handle();
    let awaiter = ctx.event_system().awaiter(handle)?;
    let system = ctx.event_system().clone();

    // TODO: Cloning all this stuff is not ideal.
    // The spawned task must own its inputs because it outlives this call.
    let src_owned = src.clone();
    let dst_owned = dst.clone();
    let src_ids = src_block_ids.to_vec();
    let dst_ids = dst_block_ids.to_vec();
    let task_ctx = ctx.clone();

    ctx.tokio().spawn(async move {
        // Precondition 1: a bounce buffer must have been supplied.
        let bounce = match options.bounce_buffer {
            Some(ref spec) => spec,
            None => {
                let _ = system.poison(
                    handle,
                    "Two-hop transfers require a bounce buffer.".to_string(),
                );
                return;
            }
        };

        // Precondition 2: the bounce buffer must live where the strategy expects it.
        if bounce.layout.location() != bounce_location {
            let _ = system.poison(
                handle,
                "Bounce buffer layout does not match bounce location.".to_string(),
            );
            return;
        }

        let outcome = if bounce.block_ids.len() == 1 {
            // Single bounce block: move one (src, dst) pair at a time through it,
            // stopping at the first failure.
            let staging = [bounce.block_ids[0]];
            let mut status = Ok(());
            for (&src_id, &dst_id) in src_ids.iter().zip(dst_ids.iter()) {
                status = execute_two_hop_transfer_chunk(
                    &src_owned,
                    &bounce.layout,
                    &dst_owned,
                    &[src_id],
                    &staging,
                    &[dst_id],
                    first_strategy,
                    second_strategy,
                    &options.layer_range,
                    &task_ctx,
                )
                .await;
                if status.is_err() {
                    break;
                }
            }
            status
        } else {
            // Multiple bounce blocks: use work-stealing parallel transfer
            handle_buffered_transfer(
                &src_owned,
                &bounce.layout,
                &dst_owned,
                &src_ids,
                &bounce.block_ids,
                &dst_ids,
                first_strategy,
                second_strategy,
                &options.layer_range,
                &task_ctx,
            )
            .await
        };

        // Single completion point: success triggers the event, failure poisons it.
        match outcome {
            Ok(()) => {
                let _ = system.trigger(handle);
            }
            Err(e) => {
                let _ = system.poison(handle, e.to_string());
            }
        }
    });

    Ok(TransferCompleteNotification::from_awaiter(awaiter))
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
    use super::*;

    /// Identical source and destination layouts need no transform kernel.
    #[test]
    fn test_select_transform_kernel_same_layout() {
        let nhd_to_nhd =
            select_transform_kernel(KvBlockLayout::OperationalNHD, KvBlockLayout::OperationalNHD);
        assert_eq!(nhd_to_nhd, TransformKernel::None);

        let universal_to_universal =
            select_transform_kernel(KvBlockLayout::UniversalTP, KvBlockLayout::UniversalTP);
        assert_eq!(universal_to_universal, TransformKernel::None);
    }

    /// Operational → Universal selects `BlockToUniversal` carrying the source layout.
    #[test]
    fn test_select_transform_kernel_block_to_universal() {
        let from_nhd =
            select_transform_kernel(KvBlockLayout::OperationalNHD, KvBlockLayout::UniversalTP);
        assert!(matches!(
            from_nhd,
            TransformKernel::BlockToUniversal {
                src_layout: KvBlockLayout::OperationalNHD
            }
        ));

        let from_hnd =
            select_transform_kernel(KvBlockLayout::OperationalHND, KvBlockLayout::UniversalTP);
        assert!(matches!(
            from_hnd,
            TransformKernel::BlockToUniversal {
                src_layout: KvBlockLayout::OperationalHND
            }
        ));
    }

    /// Universal → Operational selects `UniversalToBlock` carrying the destination layout.
    #[test]
    fn test_select_transform_kernel_universal_to_block() {
        let to_nhd =
            select_transform_kernel(KvBlockLayout::UniversalTP, KvBlockLayout::OperationalNHD);
        assert!(matches!(
            to_nhd,
            TransformKernel::UniversalToBlock {
                dst_layout: KvBlockLayout::OperationalNHD
            }
        ));

        let to_hnd =
            select_transform_kernel(KvBlockLayout::UniversalTP, KvBlockLayout::OperationalHND);
        assert!(matches!(
            to_hnd,
            TransformKernel::UniversalToBlock {
                dst_layout: KvBlockLayout::OperationalHND
            }
        ));
    }

    /// NHD <-> HND in either direction selects the operational transpose kernel.
    #[test]
    fn test_select_transform_kernel_operational_transpose() {
        let nhd_to_hnd =
            select_transform_kernel(KvBlockLayout::OperationalNHD, KvBlockLayout::OperationalHND);
        assert_eq!(nhd_to_hnd, TransformKernel::OperationalTranspose);

        let hnd_to_nhd =
            select_transform_kernel(KvBlockLayout::OperationalHND, KvBlockLayout::OperationalNHD);
        assert_eq!(hnd_to_nhd, TransformKernel::OperationalTranspose);
    }

    /// An `Unknown` layout on either side is always unsupported.
    #[test]
    fn test_select_transform_kernel_unknown_unsupported() {
        let unknown_src =
            select_transform_kernel(KvBlockLayout::Unknown, KvBlockLayout::OperationalNHD);
        assert_eq!(unknown_src, TransformKernel::Unsupported);

        let unknown_dst =
            select_transform_kernel(KvBlockLayout::OperationalNHD, KvBlockLayout::Unknown);
        assert_eq!(unknown_dst, TransformKernel::Unsupported);
    }

    /// Custom layouts are unsupported (for now).
    #[test]
    fn test_select_transform_kernel_custom_unsupported() {
        let dims = [
            crate::layout::BlockDim::Head,
            crate::layout::BlockDim::Layer,
            crate::layout::BlockDim::Outer,
            crate::layout::BlockDim::Page,
        ];
        let custom = KvBlockLayout::Custom(dims);

        let kernel = select_transform_kernel(custom, KvBlockLayout::OperationalNHD);
        assert_eq!(kernel, TransformKernel::Unsupported);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Typestate builder for NIXL transfers.
//!
//! This module provides a compile-time safe builder for NIXL transfers that ensures
//! all required parameters are set before execution.
use super::{PhysicalLayout, TransferContext, TransferStrategy};
use crate::BlockId;
use crate::transfer::context::TransferCompleteNotification;
use crate::transfer::{can_use_whole_block_transfer, validate_layout_compatibility};
use anyhow::{Result, anyhow};
use dynamo_memory::nixl::{XferDescList, XferOp};
use std::marker::PhantomData;
use std::ops::Range;
/// Marker type for unset builder fields.
///
/// Used only as a type parameter of [`NixlTransferBuilder`] to record, at the
/// type level, that the corresponding required field has not been provided yet.
pub struct Unset;
/// Marker type for set builder fields.
///
/// Used only as a type parameter of [`NixlTransferBuilder`] to record, at the
/// type level, that the corresponding required field has been provided.
pub struct Set;
/// Typestate builder for NIXL transfers.
///
/// This builder uses the typestate pattern to ensure all required parameters are set
/// at compile time. The type parameters track which fields have been set:
/// - `TSrc`: Source layout state
/// - `TDst`: Destination layout state
/// - `TSrcBlocks`: Source block IDs state
/// - `TDstBlocks`: Destination block IDs state
/// - `TStrategy`: Transfer strategy state
///
/// Each parameter is either [`Unset`] or [`Set`]. The values themselves are
/// stored as `Option`s; `execute` is only implemented for the all-`Set` state,
/// where every required `Option` is `Some`.
pub struct NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy> {
/// Source layout (required).
src: Option<&'a PhysicalLayout>,
/// Destination layout (required).
dst: Option<&'a PhysicalLayout>,
/// Source block IDs (required).
src_block_ids: Option<&'a [BlockId]>,
/// Destination block IDs (required).
dst_block_ids: Option<&'a [BlockId]>,
/// NIXL transfer strategy (required).
strategy: Option<TransferStrategy>,
/// Optional subset of layers to transfer; `None` means all layers.
layer_range: Option<Range<usize>>,
/// Optional write notification id. Stored by the builder; `execute` does not use it yet.
write_notif: Option<uuid::Uuid>,
/// Zero-sized carrier for the typestate parameters.
_phantom: PhantomData<(TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy)>,
}
impl<'a> NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> {
/// Creates a new NIXL transfer builder with all fields unset.
///
/// Every required field starts as `None` and every typestate parameter as
/// [`Unset`]; the required setters must all be called before `execute`
/// becomes available.
pub fn new() -> Self {
Self {
src: None,
dst: None,
src_block_ids: None,
dst_block_ids: None,
strategy: None,
layer_range: None,
write_notif: None,
_phantom: PhantomData,
}
}
}
impl<'a> Default for NixlTransferBuilder<'a, Unset, Unset, Unset, Unset, Unset> {
/// Equivalent to [`NixlTransferBuilder::new`]: an empty builder with all fields unset.
fn default() -> Self {
Self::new()
}
}
// Required field setters - these consume self and return a new builder with the field marked as Set
impl<'a, TDst, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, Unset, TDst, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Sets the source physical layout.
    ///
    /// Only callable while the source is still [`Unset`]; returns a builder
    /// whose source state is [`Set`]. All other fields are carried over.
    pub fn src(
        self,
        src: &'a PhysicalLayout,
    ) -> NixlTransferBuilder<'a, Set, TDst, TSrcBlocks, TDstBlocks, TStrategy> {
        // The typestate parameter changes, so the builder is taken apart and
        // reassembled rather than mutated in place.
        let Self {
            dst,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;

        NixlTransferBuilder {
            src: Some(src),
            dst,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, Unset, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Sets the destination physical layout.
    ///
    /// Only callable while the destination is still [`Unset`]; returns a
    /// builder whose destination state is [`Set`]. All other fields are
    /// carried over.
    pub fn dst(
        self,
        dst: &'a PhysicalLayout,
    ) -> NixlTransferBuilder<'a, TSrc, Set, TSrcBlocks, TDstBlocks, TStrategy> {
        // The typestate parameter changes, so the builder is taken apart and
        // reassembled rather than mutated in place.
        let Self {
            src,
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;

        NixlTransferBuilder {
            src,
            dst: Some(dst),
            src_block_ids,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, Unset, TDstBlocks, TStrategy>
{
    /// Sets the source block IDs to transfer.
    ///
    /// Only callable while the source blocks are still [`Unset`]; returns a
    /// builder whose source-blocks state is [`Set`]. All other fields are
    /// carried over.
    pub fn src_blocks(
        self,
        src_block_ids: &'a [BlockId],
    ) -> NixlTransferBuilder<'a, TSrc, TDst, Set, TDstBlocks, TStrategy> {
        // The typestate parameter changes, so the builder is taken apart and
        // reassembled rather than mutated in place.
        let Self {
            src,
            dst,
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;

        NixlTransferBuilder {
            src,
            dst,
            src_block_ids: Some(src_block_ids),
            dst_block_ids,
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TSrcBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Unset, TStrategy>
{
    /// Sets the destination block IDs to transfer.
    ///
    /// Only callable while the destination blocks are still [`Unset`]; returns
    /// a builder whose destination-blocks state is [`Set`]. All other fields
    /// are carried over.
    pub fn dst_blocks(
        self,
        dst_block_ids: &'a [BlockId],
    ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, Set, TStrategy> {
        // The typestate parameter changes, so the builder is taken apart and
        // reassembled rather than mutated in place.
        let Self {
            src,
            dst,
            src_block_ids,
            strategy,
            layer_range,
            write_notif,
            ..
        } = self;

        NixlTransferBuilder {
            src,
            dst,
            src_block_ids,
            dst_block_ids: Some(dst_block_ids),
            strategy,
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Unset>
{
    /// Sets the NIXL transfer strategy (Read or Write).
    ///
    /// Only callable while the strategy is still [`Unset`]; returns a builder
    /// whose strategy state is [`Set`]. All other fields are carried over.
    pub fn strategy(
        self,
        strategy: TransferStrategy,
    ) -> NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, Set> {
        // The typestate parameter changes, so the builder is taken apart and
        // reassembled rather than mutated in place.
        let Self {
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            layer_range,
            write_notif,
            ..
        } = self;

        NixlTransferBuilder {
            src,
            dst,
            src_block_ids,
            dst_block_ids,
            strategy: Some(strategy),
            layer_range,
            write_notif,
            _phantom: PhantomData,
        }
    }
}
// Optional field setters - these can be called at any point in the builder chain
impl<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy>
    NixlTransferBuilder<'a, TSrc, TDst, TSrcBlocks, TDstBlocks, TStrategy>
{
    /// Restricts the transfer to the given range of layers.
    ///
    /// Optional: when never called, all layers are transferred. Calling it
    /// again replaces the previously stored range. The typestate is unchanged.
    pub fn layer_range(self, layer_range: Range<usize>) -> Self {
        Self {
            layer_range: Some(layer_range),
            ..self
        }
    }

    /// Stores an optional write notification UUID.
    ///
    /// Calling it again replaces the previously stored value. The typestate is
    /// unchanged.
    #[expect(dead_code)]
    pub fn write_notif(self, write_notif: uuid::Uuid) -> Self {
        Self {
            write_notif: Some(write_notif),
            ..self
        }
    }
}
// Execute method - only available when all required fields are Set
impl<'a> NixlTransferBuilder<'a, Set, Set, Set, Set, Set> {
/// Executes the NIXL transfer with the configured parameters.
///
/// This method is only available when all required fields have been set,
/// enforced at compile time by the typestate pattern.
///
/// Steps, in order: validate that the two layouts agree, map the strategy to a
/// NIXL read or write, build one descriptor per block (whole-block path) or per
/// (block, layer, outer) region (layer-wise path), create and post the transfer
/// request, then return either an already-completed notification or one backed
/// by status polling.
///
/// # Errors
/// Returns an error if:
/// - the layouts differ in layer count or outer dimension,
/// - `validate_layout_compatibility` rejects the pair,
/// - `strategy` is not one of the four NIXL read/write variants,
/// - a descriptor list cannot be created or a memory region lookup fails,
/// - a source and destination region differ in size (layer-wise path only),
/// - creating or posting the NIXL transfer request fails.
///
/// # Panics
/// Panics if the locality invariant is violated: the source must belong to the
/// local agent for a write, the destination for a read.
pub(crate) fn execute(self, ctx: &TransferContext) -> Result<TransferCompleteNotification> {
// Unwrap all required fields (safe because typestate guarantees they're set)
let src = self.src.unwrap();
let dst = self.dst.unwrap();
let src_block_ids = self.src_block_ids.unwrap();
let dst_block_ids = self.dst_block_ids.unwrap();
let strategy = self.strategy.unwrap();
let layer_range = self.layer_range;
// The write notification is carried by the builder but not used by this method yet.
let _write_notif = self.write_notif;
// Validate layouts
let src_layout = src.layout();
let dst_layout = dst.layout();
if src_layout.num_layers() != dst_layout.num_layers() {
return Err(anyhow!(
"Layouts have incompatible layer counts: src={}, dst={}",
src_layout.num_layers(),
dst_layout.num_layers()
));
}
if src_layout.outer_dim() != dst_layout.outer_dim() {
return Err(anyhow!(
"Layouts have incompatible outer dimensions: src={}, dst={}",
src_layout.outer_dim(),
dst_layout.outer_dim()
));
}
// Validate layout compatibility (errors if transform would be needed)
validate_layout_compatibility(src, dst)?;
// Get NIXL agent
let nixl_agent = ctx.nixl_agent();
// Determine layer range
let layers = layer_range.clone().unwrap_or(0..src_layout.num_layers());
// Check if we can use optimized whole-block transfer
let use_whole_block = can_use_whole_block_transfer(src, dst, layer_range.as_ref());
// Determine NIXL operation type
// The "Flipped" variants use the same operation as their plain counterparts;
// they differ only in the descriptor-list swap further below.
let xfer_op = match strategy {
TransferStrategy::NixlRead | TransferStrategy::NixlReadFlipped => XferOp::Read,
TransferStrategy::NixlWrite | TransferStrategy::NixlWriteFlipped => XferOp::Write,
_ => {
return Err(anyhow!("Invalid NIXL transfer strategy: {:?}", strategy));
}
};
// Validate locality constraints based on operation type:
// - For Write operations (push): source must be local, we're writing FROM local TO remote
// - For Read operations (pull): destination must be local, we're reading FROM remote INTO local
// A layout counts as local when its registered agent name equals this agent's name.
let src_is_local = nixl_agent.name() == src.nixl_metadata().agent_name();
let dst_is_local = nixl_agent.name() == dst.nixl_metadata().agent_name();
// These are invariant assertions — a violation means a bug in `select_strategy`,
// not a user error. The strategy selection guarantees locality constraints.
match xfer_op {
XferOp::Write => {
assert!(
src_is_local,
"For NIXL Write (push), the source must be local. src_agent='{}', local_agent='{}'",
src.nixl_metadata().agent_name(),
nixl_agent.name()
);
}
XferOp::Read => {
assert!(
dst_is_local,
"For NIXL Read (pull), the destination must be local. dst_agent='{}', local_agent='{}'",
dst.nixl_metadata().agent_name(),
nixl_agent.name()
);
}
}
// Capture NIXL metadata for both layouts
let src_metadata = src.nixl_metadata();
let dst_metadata = dst.nixl_metadata();
let src_mem_type = src_metadata.mem_type();
let dst_mem_type = dst_metadata.mem_type();
let src_device_id = src_metadata.device_id();
let dst_device_id = dst_metadata.device_id();
// Build XferDescLists for source and destination
let mut src_dl = XferDescList::new(src_mem_type)?;
let mut dst_dl = XferDescList::new(dst_mem_type)?;
// Build descriptor lists - use whole-block or layer-wise depending on layout
// NOTE(review): `zip` stops at the shorter list, so a length mismatch between
// `src_block_ids` and `dst_block_ids` is silently truncated here — confirm
// that callers validate the lengths.
if use_whole_block {
let bytes_per_block = src_layout.config().bytes_per_block();
tracing::debug!(
num_blocks = src_block_ids.len(),
bytes_per_block,
"Building whole-block NIXL descriptors"
);
for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
// One descriptor per block, starting at the (layer 0, outer 0) region and
// spanning `bytes_per_block` taken from the source layout.
// NOTE(review): assumes that region is the start of the contiguous block and
// that both layouts share the same block size — confirm.
let src_region = src.memory_region(src_block_id, 0, 0)?;
let dst_region = dst.memory_region(dst_block_id, 0, 0)?;
src_dl.add_desc(src_region.addr(), bytes_per_block, src_device_id);
dst_dl.add_desc(dst_region.addr(), bytes_per_block, dst_device_id);
}
} else {
tracing::debug!(
num_blocks = src_block_ids.len(),
layer_range = ?layers,
src_fc = src_layout.is_fully_contiguous(),
dst_fc = dst_layout.is_fully_contiguous(),
"Building layer-wise NIXL descriptors"
);
for (&src_block_id, &dst_block_id) in src_block_ids.iter().zip(dst_block_ids.iter()) {
for layer_id in layers.clone() {
for outer_id in 0..src_layout.outer_dim() {
// One descriptor per (block, layer, outer) region; both sides must agree on size.
let src_region = src.memory_region(src_block_id, layer_id, outer_id)?;
let dst_region = dst.memory_region(dst_block_id, layer_id, outer_id)?;
if src_region.size() != dst_region.size() {
return Err(anyhow!(
"Size mismatch at block=({},{}), layer={}, outer={}: src={}, dst={}",
src_block_id,
dst_block_id,
layer_id,
outer_id,
src_region.size(),
dst_region.size()
));
}
src_dl.add_desc(src_region.addr(), src_region.size(), src_device_id);
dst_dl.add_desc(dst_region.addr(), dst_region.size(), dst_device_id);
}
}
}
}
// Note: Overlap detection was removed from nixl-sys 0.6.1
// The NIXL library now handles overlap detection internally
// Flipped strategies pass the two descriptor lists to NIXL in the opposite order.
if matches!(
strategy,
TransferStrategy::NixlReadFlipped | TransferStrategy::NixlWriteFlipped
) {
std::mem::swap(&mut src_dl, &mut dst_dl);
}
// Create transfer request
// The remote agent depends on operation type:
// - For Write (push): remote is the destination
// - For Read (pull): remote is the source
let remote_agent = match xfer_op {
XferOp::Write => dst_metadata.agent_name(),
XferOp::Read => src_metadata.agent_name(),
};
let xfer_req = nixl_agent.create_xfer_req(
xfer_op,
&src_dl,
&dst_dl,
remote_agent,
None, // opt_args
)?;
// Post transfer request
// Note: Notification handling via OptArgs can be added later if needed
let still_pending = nixl_agent.post_xfer_req(&xfer_req, None)?;
if still_pending {
// Register for async completion via status polling
Ok(ctx.register_nixl_status(xfer_req))
} else {
// Transfer completed synchronously
Ok(TransferCompleteNotification::completed())
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Block filling operations for testing.
//!
//! This module provides utilities to populate blocks with specific patterns
//! for verification in round-trip tests.
use crate::BlockId;
use super::PhysicalLayout;
use aligned_vec::{AVec, avec};
use anyhow::{Result, anyhow};
use cudarc::runtime::sys::{cudaMemcpy, cudaMemcpyKind};
use dynamo_memory::StorageKind;
use std::{
fs::File,
io::{Seek, Write},
mem::ManuallyDrop,
ops::Range,
os::fd::FromRawFd,
};
/// Fill strategy for block memory.
#[derive(Debug, Clone, Copy)]
pub enum FillPattern {
/// Fill with a constant byte value
Constant(u8),
/// Fill with a sequential pattern: the byte at `offset` within a region is
/// `(block_id + layer_id + offset) % 256`.
///
/// The outer-dimension index is not part of the formula, so every outer
/// region of the same (block, layer) receives the same bytes.
Sequential,
}
/// Fill blocks in a physical layout with a specific pattern.
///
/// Every (layer, outer) region of each listed block is filled. How the bytes
/// reach the storage depends on the layout's location:
/// - `System` / `Pinned`: written directly through the region's address.
/// - `Device`: generated in a host staging buffer, then copied with a
///   synchronous host-to-device `cudaMemcpy`.
/// - `Disk`: generated in a 4096-byte-aligned staging buffer, then written to
///   the file descriptor at the region's offset and synced.
///
/// This operation writes to memory directly and must only be used on local
/// layouts. The function itself does not check locality; that is the caller's
/// responsibility.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - List of block IDs to fill
/// * `pattern` - Fill pattern to use
///
/// # Errors
/// Returns an error if:
/// - A block ID is out of range for the layout
/// - A memory region cannot be resolved
/// - The host-to-device copy fails
/// - Seeking, writing or syncing the backing file fails
pub fn fill_blocks(
    layout: &PhysicalLayout,
    block_ids: &[BlockId],
    pattern: FillPattern,
) -> Result<()> {
    let config = layout.layout().config();
    let num_layers = config.num_layers;
    let outer_dim = config.outer_dim;

    for &block_id in block_ids {
        if block_id >= config.num_blocks as BlockId {
            return Err(anyhow!("Block ID {} out of range", block_id));
        }

        // Fill all layers and outer dimensions for this block
        for layer_id in 0..num_layers {
            for outer_id in 0..outer_dim {
                let region = layout.memory_region(block_id, layer_id, outer_id)?;

                match layout.location() {
                    StorageKind::System | StorageKind::Pinned => {
                        // Host-accessible memory: write in place.
                        fill_memory_region(
                            region.addr(),
                            region.size(),
                            block_id,
                            layer_id,
                            pattern,
                        )?;
                    }
                    StorageKind::Device(_) => {
                        // Build the pattern on the host first. The buffer is
                        // mutable and the pointer comes from `as_mut_ptr`
                        // because `fill_memory_region` writes through it.
                        let mut staging: Vec<u8> = vec![0; region.size()];
                        fill_memory_region(
                            staging.as_mut_ptr() as usize,
                            staging.len(),
                            block_id,
                            layer_id,
                            pattern,
                        )?;

                        // SAFETY: `staging` holds exactly `region.size()`
                        // initialized bytes and outlives this synchronous call;
                        // `region.addr()` is the device address reported by the
                        // layout for a region of that same size.
                        let err = unsafe {
                            cudaMemcpy(
                                region.addr() as *mut std::ffi::c_void,
                                staging.as_ptr() as *const std::ffi::c_void,
                                region.size(),
                                cudaMemcpyKind::cudaMemcpyHostToDevice,
                            )
                        };
                        if err != cudarc::runtime::sys::cudaError::cudaSuccess {
                            return Err(anyhow!("cudaMemcpy H2D failed in fill_blocks: {:?}", err));
                        }
                    }
                    StorageKind::Disk(fd) => {
                        // 4096-byte aligned staging buffer (presumably to
                        // satisfy direct-I/O alignment — confirm against how
                        // the file is opened).
                        let mut staging: AVec<u8, _> = avec![[4096]| 0; region.size()];
                        fill_memory_region(
                            staging.as_mut_ptr() as usize,
                            staging.len(),
                            block_id,
                            layer_id,
                            pattern,
                        )?;

                        // `ManuallyDrop` keeps the borrowed descriptor open: the
                        // temporary `File` must not close an fd it does not own.
                        // SAFETY: `fd` is the descriptor stored in the layout's
                        // `StorageKind::Disk`, assumed open for the layout's
                        // lifetime.
                        let mut file = ManuallyDrop::new(unsafe { File::from_raw_fd(fd as i32) });
                        // For disk layouts the region address is used as the file offset.
                        file.seek(std::io::SeekFrom::Start(region.addr() as u64))?;
                        file.write_all(&staging)?;
                        // Flush userspace state first, then force it to stable storage.
                        file.flush()?;
                        file.sync_all()?;
                    }
                }
            }
        }
    }

    Ok(())
}
/// Fill a subset of layers in blocks with a specific pattern.
///
/// Only host-accessible storage (`System` / `Pinned`) is supported; every outer
/// region of each selected (block, layer) is written in place.
///
/// # Arguments
/// * `layout` - The physical layout containing the blocks
/// * `block_ids` - List of block IDs to fill
/// * `layer_range` - Range of layers to fill
/// * `pattern` - Fill pattern to use
///
/// # Errors
/// Returns an error if the layer range ends past the layout's layer count, a
/// block ID is out of range, a memory region cannot be resolved, or the layout
/// lives on device or disk storage.
pub fn fill_layers(
    layout: &PhysicalLayout,
    block_ids: &[usize],
    layer_range: Range<usize>,
    pattern: FillPattern,
) -> Result<()> {
    let config = layout.layout().config();
    let layer_count = config.num_layers;
    let outer_count = config.outer_dim;

    // Reject ranges that reach past the last layer before touching any memory.
    if layer_range.end > layer_count {
        return Err(anyhow!(
            "Layer range {:?} exceeds num_layers {}",
            layer_range,
            layer_count
        ));
    }

    for block_id in block_ids.iter().copied() {
        if block_id >= config.num_blocks {
            return Err(anyhow!("Block ID {} out of range", block_id));
        }

        // Fill specified layers and all outer dimensions
        for layer_id in layer_range.clone() {
            for outer_id in 0..outer_count {
                let region = layout.memory_region(block_id, layer_id, outer_id)?;

                let host_accessible = match layout.location() {
                    StorageKind::System | StorageKind::Pinned => true,
                    StorageKind::Device(_) | StorageKind::Disk(_) => false,
                };
                if !host_accessible {
                    return Err(anyhow!(
                        "fill_layers only supports host-accessible storage (System/Pinned)"
                    ));
                }

                fill_memory_region(region.addr(), region.size(), block_id, layer_id, pattern)?;
            }
        }
    }

    Ok(())
}
/// Fill a memory region with the specified pattern.
///
/// `addr` is interpreted as a raw byte pointer and `size` bytes starting there
/// are overwritten. Always returns `Ok(())`.
///
/// # Safety
/// This function performs unsafe memory writes. The caller must ensure:
/// - The memory region is valid and accessible
/// - No other references exist to this memory
fn fill_memory_region(
    addr: usize,
    size: usize,
    block_id: BlockId,
    layer_id: usize,
    pattern: FillPattern,
) -> Result<()> {
    let base = addr as *mut u8;

    match pattern {
        FillPattern::Constant(value) => {
            // SAFETY: the caller guarantees `size` writable bytes at `addr`.
            unsafe { std::ptr::write_bytes(base, value, size) };
        }
        FillPattern::Sequential => {
            for offset in 0..size {
                // Byte value depends on block, layer and position; wraps every 256 bytes.
                let byte = ((block_id + layer_id + offset) % 256) as u8;
                // SAFETY: `offset < size`, so the write stays inside the
                // region the caller vouched for.
                unsafe { base.add(offset).write(byte) };
            }
        }
    }

    Ok(())
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
    use super::super::tests::*;
    use super::*;

    /// A constant fill sets every byte of the checked region to the constant.
    #[test]
    fn test_fill_blocks_constant() {
        let physical = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();

        fill_blocks(&physical, &[0, 1], FillPattern::Constant(42)).unwrap();

        // Verify all bytes are set to 42
        let region = physical.memory_region(0, 0, 0).unwrap();
        let bytes = unsafe { region.as_slice().unwrap() };
        assert!(bytes.iter().all(|&b| b == 42));
    }

    /// A sequential fill starts each region at `block_id + layer_id` and counts up.
    #[test]
    fn test_fill_blocks_sequential() {
        let physical = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();

        fill_blocks(&physical, &[0, 1], FillPattern::Sequential).unwrap();

        // Spot check two regions: (block, layer, expected first byte).
        for (block, layer, expected_first) in [(0, 0, 0u8), (1, 1, 2u8)] {
            let region = physical.memory_region(block, layer, 0).unwrap();
            let bytes = unsafe { region.as_slice().unwrap() };
            let first_byte = bytes[0];
            let second_byte = bytes[1];
            assert_eq!(first_byte, expected_first);
            assert_eq!(second_byte, first_byte.wrapping_add(1));
        }
    }

    /// Filling individual layers leaves each (block, layer) with its own constant.
    #[test]
    fn test_fill_layers() {
        let physical = builder(2)
            .fully_contiguous()
            .allocate_system()
            .build()
            .unwrap();

        // (block, layer, constant): each layer of each block gets a distinct value.
        let cases = [(0usize, 0usize, 0u8), (0, 1, 1), (1, 0, 100), (1, 1, 101)];

        for &(block, layer, value) in &cases {
            fill_layers(
                &physical,
                &[block],
                layer..layer + 1,
                FillPattern::Constant(value),
            )
            .unwrap();
        }

        // Read back the first byte of every filled region once all fills are done.
        for &(block, layer, value) in &cases {
            let first = unsafe {
                physical
                    .memory_region(block, layer, 0)
                    .unwrap()
                    .as_slice()
                    .unwrap()[0]
            };
            assert_eq!(first, value);
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Transfer module for copying blocks between layouts with different storage locations.
//!
//! This module provides functionality for transferring KV cache blocks between layouts
//! that may be backed by different storage types (GPU memory, pinned host memory, disk, etc.)
//! and potentially across NIXL-connected remote nodes.
//!
//! # Core Concepts
//!
//! - [`PhysicalLayout`]: Wraps a layout with its physical storage location and NIXL metadata
//! - [`LayoutDescriptor`]: Serializable representation for cross-node communication
//! - Transfer strategies: memcpy, CUDA, NIXL based on source/destination locations
//! - Block-wise and layer-wise transfer operations
//!
//! # Usage
//!
//! ```rust,ignore
//! use dynamo_kvbm::v2::transfer::{PhysicalLayout, transfer_blocks};
//!
//! // Create local physical layout with NIXL registration
//! let src = PhysicalLayout::new_local(src_layout, StorageKind::Device(0))
//! .with_nixl_registration("local_agent".to_string())?;
//!
//! // Create remote physical layout
//! let dst = PhysicalLayout::new_remote(
//! dst_layout,
//! StorageKind::Pinned,
//! "remote_agent".to_string()
//! );
//!
//! // Transfer blocks from local to remote
//! let src_block_ids = [0, 1, 2];
//! let dst_block_ids = [0, 1, 2];
//! let future = transfer_blocks(&src, &dst, &src_block_ids, &dst_block_ids, &ctx)?;
//! future.await?;
//! ```
pub(crate) mod capabilities;
pub(crate) mod checksum;
pub mod context;
pub(crate) mod executor;
pub(crate) mod fill;
pub(crate) mod notifications;
pub(crate) mod options;
pub(crate) mod preferences;
pub(crate) mod strategy;
pub(crate) mod validation;
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests;
// Re-export StorageKind
pub use dynamo_memory::StorageKind;
pub use capabilities::TransferCapabilities;
pub use checksum::{BlockChecksum, compute_block_checksums, compute_layer_checksums};
pub use context::{TransferCompleteNotification, TransferConfig};
pub use dynamo_memory::nixl::NixlAgent;
pub use fill::{FillPattern, fill_blocks, fill_layers};
pub use options::{TransferOptions, TransferOptionsBuilder};
// TransferContext - managed by TransferManager
#[doc(hidden)]
pub use context::TransferContext;
use crate::BlockId;
pub use crate::layout::PhysicalLayout;
// Re-export manager types - TransferManager is the primary public API
pub use crate::manager::{LayoutHandle, SerializedLayout, TransferManager, WorkerAddress};
// #[cfg(test)]
// pub use testing::{RoundTripTest, RoundTripTestResult};
// /// Specification for bounce buffer in multi-hop transfers.
// ///
// /// This structure provides the layout and block IDs to use as an intermediate
// /// staging area when direct transfers are not allowed.
// #[deprecated(since = "2025-11-25", note = "use TransferOptions instead")]
// pub trait BounceBufferSpec: Send + Sync {
// fn layout(&self) -> &PhysicalLayout;
// fn block_ids(&self) -> &[BlockId];
// }
/// Public description of a bounce buffer: a layout, referenced by handle, plus
/// the block IDs within it that may be used as staging space.
#[derive(Clone)]
pub struct BounceBuffer {
// Handle identifying the layout that provides the staging blocks.
layout: LayoutHandle,
// Block IDs inside that layout available for staging.
block_ids: Vec<BlockId>,
}
/// Bounce buffer description holding the concrete [`PhysicalLayout`] rather
/// than a handle to it, together with the usable block IDs.
#[derive(Clone)]
pub struct BounceBufferInternal {
// Concrete layout that provides the staging blocks.
layout: PhysicalLayout,
// Block IDs inside that layout available for staging.
block_ids: Vec<BlockId>,
}
impl BounceBuffer {
pub fn from_handle(layout: LayoutHandle, block_ids: Vec<BlockId>) -> Self {
Self { layout, block_ids }
}
#[doc(hidden)]
pub fn into_parts(self) -> (LayoutHandle, Vec<BlockId>) {
(self.layout, self.block_ids)
}
}
impl BounceBufferInternal {
pub fn from_layout(layout: PhysicalLayout, block_ids: Vec<BlockId>) -> Self {
Self { layout, block_ids }
}
}
// ============================================================================
// Layout Compatibility Helpers
// ============================================================================
use anyhow::anyhow;
use std::ops::Range;
/// Check that two layouts can be copied between without a layout transform.
///
/// Transfers that would need the block layout to be transformed are not
/// supported yet, so this is meant to be called early to fail fast.
///
/// # Errors
/// Returns an error naming both block layouts when the source block layout
/// reports that it requires a transform to reach the destination block layout.
pub(crate) fn validate_layout_compatibility(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
) -> anyhow::Result<()> {
    let src_layout = src.layout();
    let dst_layout = dst.layout();

    let needs_transform = src_layout
        .block_layout()
        .requires_transform(&dst_layout.block_layout());

    // Happy path first: matching layouts need no further work.
    if !needs_transform {
        return Ok(());
    }

    Err(anyhow!(
        "Layout transformation not supported: src={:?}, dst={:?}",
        src_layout.block_layout(),
        dst_layout.block_layout()
    ))
}
/// Decide whether a transfer can move each block as one contiguous piece.
///
/// Returns true only when both conditions hold:
/// - The transfer covers every layer: `layer_range` is `None`, or it starts at
///   0 and ends at the source layout's layer count
/// - Both src and dst layouts are fully contiguous
///
/// Note: Caller must have already validated layout compatibility via
/// [`validate_layout_compatibility`].
pub(crate) fn can_use_whole_block_transfer(
    src: &PhysicalLayout,
    dst: &PhysicalLayout,
    layer_range: Option<&Range<usize>>,
) -> bool {
    // A range that skips any layer rules out the whole-block path.
    if let Some(range) = layer_range {
        let is_partial = range.start != 0 || range.end != src.layout().num_layers();
        if is_partial {
            return false;
        }
    }

    // Full-block transfer: both sides must also be fully contiguous.
    src.layout().is_fully_contiguous() && dst.layout().is_fully_contiguous()
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! CUDA event polling-based completion checker.
use anyhow::Result;
use cudarc::driver::{CudaEvent, DriverError, result as cuda_result, sys::CUresult};
use super::CompletionChecker;
/// Completion checker that polls CUDA event status.
///
/// Owns the event it watches; each `is_complete` call queries that event once
/// without blocking on it.
pub struct CudaEventChecker {
// The recorded CUDA event whose completion is being polled.
event: CudaEvent,
}
impl CudaEventChecker {
/// Wraps `event` in a checker, taking ownership of it.
pub fn new(event: CudaEvent) -> Self {
Self { event }
}
}
impl CompletionChecker for CudaEventChecker {
    /// Polls the wrapped CUDA event once.
    ///
    /// Returns `Ok(true)` when the event has completed, `Ok(false)` when the
    /// driver reports it is not ready yet, and an error for any other driver
    /// failure.
    fn is_complete(&self) -> Result<bool> {
        // Driver-API event query: success means complete, CUDA_ERROR_NOT_READY
        // means still pending.
        // SAFETY: the raw handle comes from the `CudaEvent` owned by `self`,
        // which stays alive for the duration of the call.
        let status = unsafe { cuda_result::event::query(self.event.cu_event()) };

        match status {
            Ok(()) => Ok(true), // Event is complete
            Err(DriverError(CUresult::CUDA_ERROR_NOT_READY)) => Ok(false),
            Err(e) => Err(anyhow::anyhow!("CUDA event query failed: {:?}", e)),
        }
    }
}
#[cfg(all(test, feature = "testing-kvbm"))]
mod tests {
use crate::manager::TransferManager;
use crate::transfer::tests::CudaSleep;
use dynamo_memory::nixl::NixlAgent;
use std::time::{Duration, Instant};
/// Checks that a notification registered for a CUDA event resolves only after
/// the GPU work queued before the event has finished.
///
/// A 600ms sleep kernel is launched on the manager's H2D stream, an event is
/// recorded behind it, and the test awaits the resulting notification. The
/// launch must return quickly while the wait must take most of the sleep time.
/// Uses CUDA device 0.
#[tokio::test]
async fn test_cuda_event_delayed_notification() {
let agent = NixlAgent::new("test_agent").unwrap();
let manager = TransferManager::builder()
.cuda_device_id(0)
.nixl_agent(agent)
.build()
.unwrap();
let stream = manager.h2d_stream();
let cuda_ctx = manager.cuda_context();
// Get or create the CudaSleep utility (compiles kernel and calibrates on first use)
let cuda_sleep = CudaSleep::for_context(cuda_ctx).unwrap();
// Test 1: Launch sleep and wait via async notification
// The timer starts before the launch so queue time and wait time can be separated.
let t0_queue_start = Instant::now();
cuda_sleep
.launch(Duration::from_millis(600), stream)
.unwrap();
let queue_time = t0_queue_start.elapsed();
// Record the event after the sleep kernel on the same stream.
let event = stream.record_event(None).unwrap();
let notification = manager.register_cuda_event(event);
notification.await.unwrap();
// Time spent waiting = total elapsed minus the time spent queueing the kernel.
let wait_time = t0_queue_start.elapsed() - queue_time;
println!(
"GPU sleep test: queue {:?}, wait {:?}",
queue_time, wait_time
);
// The launch is asynchronous, so queueing must not block for the sleep duration.
assert!(
queue_time < Duration::from_millis(10),
"launching the sleep kernel should be fast: {:?}",
queue_time
);
// 500ms threshold leaves slack below the requested 600ms sleep.
assert!(
wait_time >= Duration::from_millis(500),
"wait time should reflect >=500ms of GPU work: {:?}",
wait_time
);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment