// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use super::*; use crate::object::ObjectBlockOps; use anyhow::Result; // velo event types used via fully-qualified paths (::velo::Event, ::velo::EventManager) use futures::future::BoxFuture; use std::collections::HashMap; use std::sync::{Arc, RwLock}; /// SPMD (Single Program, Multiple Data) parallel worker group. /// /// Wraps a set of rank-indexed [`Worker`]s and executes every operation on /// all of them in parallel. Each worker has its own rank, physical layout /// handles, and `TransferManager`, but they all receive the same logical /// commands (transfer, connect, import/export metadata). /// /// Transfer completion notifications from individual workers are aggregated /// into a single notification via the event system, so callers see one /// completion event per logical operation regardless of worker count. /// /// Remote handle mappings are stored per `(InstanceId, worker_idx, /// LogicalLayoutHandle)` so that each rank resolves to its own peer handle /// during RDMA transfers. pub struct SpmdParallelWorkers { workers: Vec>, events: Arc<::velo::EventManager>, runtime: tokio::runtime::Handle, /// Remote handle mappings: (InstanceId, worker_idx, LogicalLayoutHandle) -> remote LayoutHandle. /// Populated by `connect_remote` for later use by `execute_remote_onboard_for_instance`. remote_handles: RwLock>, } impl SpmdParallelWorkers { /// Create a new SpmdParallelWorkers. /// /// # Arguments /// * `workers` - The underlying workers (one per rank) /// * `events` - The event system for aggregating completion notifications /// * `runtime` - The tokio runtime handle for spawning aggregation tasks pub fn new( workers: Vec>, events: Arc<::velo::EventManager>, runtime: tokio::runtime::Handle, ) -> Self { Self { workers, events, runtime, remote_handles: RwLock::new(HashMap::new()), } } /// Get the number of workers. pub fn worker_count(&self) -> usize { self.workers.len() } } impl WorkerTransfers for SpmdParallelWorkers { fn execute_local_transfer( &self, src: LogicalLayoutHandle, dst: LogicalLayoutHandle, src_block_ids: Arc<[BlockId]>, dst_block_ids: Arc<[BlockId]>, options: kvbm_physical::transfer::TransferOptions, ) -> Result { let notifications = self .workers .iter() .map(|worker| { worker.execute_local_transfer( src, dst, src_block_ids.clone(), dst_block_ids.clone(), options.clone(), ) }) .collect::>>()?; TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime) } fn execute_remote_onboard( &self, src: RemoteDescriptor, dst: LogicalLayoutHandle, dst_block_ids: Arc<[BlockId]>, options: kvbm_physical::transfer::TransferOptions, ) -> Result { let notifications = self .workers .iter() .map(|worker| { worker.execute_remote_onboard( src.clone(), dst, dst_block_ids.clone(), options.clone(), ) }) .collect::>>()?; TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime) } fn execute_remote_offload( &self, src: LogicalLayoutHandle, src_block_ids: Arc<[BlockId]>, dst: RemoteDescriptor, options: kvbm_physical::transfer::TransferOptions, ) -> Result { let notifications = self .workers .iter() .map(|worker| { worker.execute_remote_offload( src, src_block_ids.clone(), dst.clone(), options.clone(), ) }) .collect::>>()?; TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime) } fn connect_remote( &self, instance_id: InstanceId, metadata: Vec, ) -> Result { // Validate metadata count matches worker count if metadata.len() != self.workers.len() { anyhow::bail!( "Metadata count ({}) doesn't match worker count ({})", metadata.len(), self.workers.len() ); } // Collect handles to store and responses to await let mut new_handles = Vec::new(); let mut import_responses = Vec::new(); for (worker_idx, (worker, meta)) in self.workers.iter().zip(metadata.into_iter()).enumerate() { // Unpack to extract logical type info let unpacked = meta.unpack()?; // Collect handle mappings for descriptor in &unpacked.layouts { new_handles.push(( (instance_id, worker_idx, descriptor.logical_type), descriptor.handle, )); } // Repack for the underlying worker's import_metadata let repacked = SerializedLayout::pack( unpacked.worker_address, unpacked.nixl_metadata, unpacked.layouts, )?; // Call underlying worker's import_metadata import_responses.push(worker.import_metadata(repacked)?); } // Store all handle mappings { let mut handles = self.remote_handles.write().unwrap(); for (key, value) in new_handles { handles.insert(key, value); } } // If all responses are ready (synchronous), return immediately if import_responses.iter().all(|r| !r.could_yield()) { return Ok(ConnectRemoteResponse::ready()); } // Create an event to aggregate all import completions let event = self.events.new_event()?; let awaiter = self.events.awaiter(event.handle())?; // Spawn task to await all import responses and signal completion self.runtime .spawn(await_import_responses(import_responses, event)); Ok(ConnectRemoteResponse::from_awaiter(awaiter)) } fn has_remote_metadata(&self, instance_id: InstanceId) -> bool { let handles = self.remote_handles.read().unwrap(); handles.keys().any(|(id, _, _)| *id == instance_id) } fn execute_remote_onboard_for_instance( &self, instance_id: InstanceId, remote_logical_type: LogicalLayoutHandle, src_block_ids: Vec, dst: LogicalLayoutHandle, dst_block_ids: Arc<[BlockId]>, options: kvbm_physical::transfer::TransferOptions, ) -> Result { let handles = self.remote_handles.read().unwrap(); let mut notifications = Vec::with_capacity(self.workers.len()); // SPMD: Execute SAME transfer on EVERY worker, each with its own remote handle for (worker_idx, worker) in self.workers.iter().enumerate() { let remote_handle = handles .get(&(instance_id, worker_idx, remote_logical_type)) .ok_or_else(|| { anyhow::anyhow!( "No remote {:?} handle for instance {} worker {}", remote_logical_type, instance_id, worker_idx ) })?; let descriptor = RemoteDescriptor::Layout { handle: *remote_handle, block_ids: src_block_ids.clone(), }; notifications.push(worker.execute_remote_onboard( descriptor, dst, dst_block_ids.clone(), options.clone(), )?); } TransferCompleteNotification::aggregate(notifications, &self.events, &self.runtime) } } /// Helper to await all import metadata responses and signal completion via an event. /// Helper to await all import metadata responses and signal completion via an event. async fn await_import_responses(responses: Vec, event: ::velo::Event) { let results: Vec>> = futures::future::join_all(responses.into_iter().map(|r| r.into_future())).await; // Check for any failures let errors: Vec<_> = results.into_iter().filter_map(|r| r.err()).collect(); if errors.is_empty() { let _ = event.trigger(); } else { let error_msg = errors .iter() .map(|e| e.to_string()) .collect::>() .join("; "); let _ = event.poison(error_msg); } } impl ParallelWorkers for SpmdParallelWorkers { fn export_metadata(&self) -> Result> { let metadata = self .workers .iter() .map(|worker| worker.export_metadata()) .collect::>>()?; Ok(metadata) } fn import_metadata( &self, metadata: Vec, ) -> Result> { // validate the size of the metadata is the same as the number of workers if metadata.len() != self.workers.len() { return Err(anyhow::anyhow!( "Metadata size does not match number of workers" )); } let results = self .workers .iter() .zip(metadata.iter()) .map(|(worker, metadata)| worker.import_metadata(metadata.clone())) .collect::>>()?; Ok(results) } fn worker_count(&self) -> usize { self.workers.len() } fn workers(&self) -> &[Arc] { &self.workers } } impl ObjectBlockOps for SpmdParallelWorkers { fn has_blocks( &self, keys: Vec, ) -> BoxFuture<'static, Vec<(SequenceHash, Option)>> { // For has_blocks, we query all workers and verify consistency. // All workers should agree on block presence for SPMD semantics. // We return the results from worker 0 but verify all workers agree. let workers = self.workers.clone(); let _runtime = self.runtime.clone(); Box::pin(async move { if workers.is_empty() { return keys.into_iter().map(|k| (k, None)).collect(); } // Query all workers in parallel let futures: Vec<_> = workers .iter() .map(|worker| worker.has_blocks(keys.clone())) .collect(); let results: Vec)>> = futures::future::join_all(futures).await; // Return results from first worker (all should agree in SPMD) // In debug mode, we could verify consistency across workers results.into_iter().next().unwrap_or_default() }) } fn put_blocks( &self, keys: Vec, src_layout: LogicalLayoutHandle, block_ids: Vec, ) -> BoxFuture<'static, Vec>> { // For put_blocks, each worker writes with its own rank-prefixed key. // Each worker resolves the logical handle to its own physical layout. // All workers must succeed for the operation to be considered successful. let workers = self.workers.clone(); Box::pin(async move { if workers.is_empty() { return keys.into_iter().map(Err).collect(); } // Execute put on all workers in parallel // Each worker resolves src_layout to its own physical layout let futures: Vec<_> = workers .iter() .map(|worker| worker.put_blocks(keys.clone(), src_layout, block_ids.clone())) .collect(); let results: Vec>> = futures::future::join_all(futures).await; // Aggregate: a key succeeded only if ALL workers succeeded let num_keys = keys.len(); let mut aggregated = Vec::with_capacity(num_keys); for (key_idx, key) in keys.iter().enumerate() { let all_succeeded = results.iter().all(|worker_results| { worker_results .get(key_idx) .map(|r| r.is_ok()) .unwrap_or(false) }); if all_succeeded { aggregated.push(Ok(*key)); } else { aggregated.push(Err(*key)); } } aggregated }) } fn get_blocks( &self, keys: Vec, dst_layout: LogicalLayoutHandle, block_ids: Vec, ) -> BoxFuture<'static, Vec>> { // For get_blocks, each worker reads from its own rank-prefixed key. // Each worker resolves the logical handle to its own physical layout. // All workers must succeed for the operation to be considered successful. let workers = self.workers.clone(); Box::pin(async move { if workers.is_empty() { return keys.into_iter().map(Err).collect(); } // Execute get on all workers in parallel // Each worker resolves dst_layout to its own physical layout let futures: Vec<_> = workers .iter() .map(|worker| worker.get_blocks(keys.clone(), dst_layout, block_ids.clone())) .collect(); let results: Vec>> = futures::future::join_all(futures).await; // Aggregate: a key succeeded only if ALL workers succeeded let num_keys = keys.len(); let mut aggregated = Vec::with_capacity(num_keys); for (key_idx, key) in keys.iter().enumerate() { let all_succeeded = results.iter().all(|worker_results| { worker_results .get(key_idx) .map(|r| r.is_ok()) .unwrap_or(false) }); if all_succeeded { aggregated.push(Ok(*key)); } else { aggregated.push(Err(*key)); } } aggregated }) } }