// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use super::*; use nixl_sys::NixlDescriptor; use utils::*; use zmq::*; use BlockTransferPool::*; use crate::block_manager::{ BasicMetadata, Storage, block::{ Block, BlockDataProvider, BlockDataProviderMut, ReadableBlock, WritableBlock, data::local::LocalBlockData, locality, transfer::{TransferContext, WriteTo, WriteToStrategy}, }, connector::scheduler::{SchedulingDecision, TransferSchedulerClient}, storage::{DeviceStorage, DiskStorage, Local, PinnedStorage}, }; use anyhow::Result; use async_trait::async_trait; use std::{any::Any, sync::Arc}; type LocalBlock = Block; type LocalBlockDataList = Vec>; /// A handler for all block transfers. Wraps a group of [`BlockTransferPoolManager`]s. #[derive(Clone)] pub struct BlockTransferHandler { device: Option>, host: Option>, disk: Option>, context: Arc, scheduler_client: Option, // add worker-connector scheduler client here } impl BlockTransferHandler { pub fn new( device_blocks: Option>>, host_blocks: Option>>, disk_blocks: Option>>, context: Arc, scheduler_client: Option, // add worker-connector scheduler client here ) -> Result { Ok(Self { device: Self::get_local_data(device_blocks), host: Self::get_local_data(host_blocks), disk: Self::get_local_data(disk_blocks), context, scheduler_client, }) } fn get_local_data( blocks: Option>>, ) -> Option> { blocks.map(|blocks| { blocks .into_iter() .map(|b| { let block_data = b.block_data() as &dyn Any; block_data .downcast_ref::>() .unwrap() .clone() }) .collect() }) } /// Initiate a transfer between two pools. async fn begin_transfer( &self, source_pool_list: &Option>, target_pool_list: &Option>, request: BlockTransferRequest, ) -> Result> where Source: Storage + NixlDescriptor, Target: Storage + NixlDescriptor, // Check that the source block is readable, local, and writable to the target block. LocalBlockData: ReadableBlock + Local + WriteToStrategy>, // Check that the target block is writable. LocalBlockData: WritableBlock, LocalBlockData: BlockDataProvider, LocalBlockData: BlockDataProviderMut, { let Some(source_pool_list) = source_pool_list else { return Err(anyhow::anyhow!("Source pool manager not initialized")); }; let Some(target_pool_list) = target_pool_list else { return Err(anyhow::anyhow!("Target pool manager not initialized")); }; // Extract the `from` and `to` indices from the request. let source_idxs = request.blocks().iter().map(|(from, _)| *from); let target_idxs = request.blocks().iter().map(|(_, to)| *to); // Get the blocks corresponding to the indices. let sources: Vec> = source_idxs .map(|idx| source_pool_list[idx].clone()) .collect(); let mut targets: Vec> = target_idxs .map(|idx| target_pool_list[idx].clone()) .collect(); // Perform the transfer, and return the notifying channel. match sources.write_to(&mut targets, self.context.clone()) { Ok(channel) => Ok(channel), Err(e) => { tracing::error!("Failed to write to blocks: {:?}", e); Err(e.into()) } } } pub async fn execute_transfer(&self, request: BlockTransferRequest) -> Result<()> { tracing::debug!( "Performing transfer of {} blocks from {:?} to {:?}", request.blocks().len(), request.from_pool(), request.to_pool() ); tracing::debug!("request: {request:#?}"); let notify = match (request.from_pool(), request.to_pool()) { (Device, Host) => self.begin_transfer(&self.device, &self.host, request).await, (Host, Device) => self.begin_transfer(&self.host, &self.device, request).await, (Host, Disk) => self.begin_transfer(&self.host, &self.disk, request).await, (Disk, Device) => self.begin_transfer(&self.disk, &self.device, request).await, _ => { return Err(anyhow::anyhow!("Invalid transfer type.")); } }?; notify.await?; Ok(()) } } #[async_trait] impl Handler for BlockTransferHandler { async fn handle(&self, mut message: MessageHandle) -> Result<()> { if message.data.len() != 1 { return Err(anyhow::anyhow!( "Block transfer request must have exactly one data element" )); } let mut request: BlockTransferRequest = serde_json::from_slice(&message.data[0])?; let result = if let Some(req) = request.connector_req.take() { let operation_id = req.uuid; tracing::debug!( request_id = %req.request_id, operation_id = %operation_id, "scheduling transfer" ); let client = self .scheduler_client .as_ref() .expect("scheduler client is required") .clone(); let handle = client.schedule_transfer(req).await?; // we don't support cancellation yet assert_eq!(handle.scheduler_decision(), SchedulingDecision::Execute); match self.execute_transfer(request).await { Ok(_) => { handle.mark_complete(Ok(())).await; Ok(()) } Err(e) => { handle.mark_complete(Err(anyhow::anyhow!("{}", e))).await; Err(e) } } } else { self.execute_transfer(request).await }; // we always ack regardless of if we error or not message.ack().await?; // the error may trigger a cancellation result } }