pending.rs 15.7 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
//! # Transfer Managers
//!
//! Transfer managers are responsible for multiple things:
//! - Before the transfer:
//!     - Rate-limiting the number of transfers that can be initiated concurrently. This is implemented through bounded channels.
//!         - Due to the nature of the [`super::OffloadManager`], we only apply this rate-limiting to offloads.
//! - During the transfer:
//!     - Initiating the transfer.
//!     - Holding strong references to the blocks being transferred.
//! - After the transfer:
//!     - Dropping these references once the transfer is complete.
//!     - Registering the blocks with the target pool.
//!     - Returning the registered blocks to the caller.
//!
//! This is implemented through the [`TransferManager`] trait, which takes a single [`PendingTransfer`]
//! and initiates the transfer.
//!
//! Since CUDA and NIXL transfers use completely different semantics, we implement two separate transfer managers.
//!
//! ## Workflow
//!
//! 1. A transfer request is made by calling [`TransferManager::enqueue_transfer`].
//! 2. [`TransferManager::enqueue_transfer`] initiates the transfer and enqueues a future that resolves on its completion into a bounded channel.
//! 3. A worker task (consuming this bounded channel and enforcing rate limiting) awaits the incoming transfers.
//! 4. After a transfer is complete, the worker task registers the blocks with the target pool and returns the registered blocks to the caller.

Ryan Olson's avatar
Ryan Olson committed
29
use nixl_sys::NixlDescriptor;
30
use std::marker::PhantomData;
31
use std::pin::Pin;
32
use std::sync::Arc;
33
use tokio::runtime::Handle;
Ryan Olson's avatar
Ryan Olson committed
34
use tokio::sync::{mpsc, oneshot};
35
use tokio_util::sync::CancellationToken;
36

37
use crate::block_manager::block::{
Ryan Olson's avatar
Ryan Olson committed
38
39
    BlockDataProvider, BlockDataProviderMut, BlockError, BlockMetadata, BlockState, ImmutableBlock,
    MutableBlock, ReadableBlock, WritableBlock,
40
41
    locality::LocalityProvider,
    transfer::{TransferContext, WriteTo, WriteToStrategy},
42
};
Ryan Olson's avatar
Ryan Olson committed
43
use crate::block_manager::pool::{BlockPool, BlockPoolError};
44
45
use crate::block_manager::storage::{Local, Storage};

46
use anyhow::Result;
47
use async_trait::async_trait;
48
use futures::{StreamExt, stream::FuturesUnordered};
49

50
use super::BlockResult;
51

52
53
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;

54
/// Manage a set of pending transfers.
Ryan Olson's avatar
Ryan Olson committed
55
56
57
58
59
60
pub struct PendingTransfer<
    Source: Storage,
    Target: Storage,
    Locality: LocalityProvider,
    Metadata: BlockMetadata,
> {
61
    /// The block being copied from.
Ryan Olson's avatar
Ryan Olson committed
62
    sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
63
    /// The block being copied to.
Ryan Olson's avatar
Ryan Olson committed
64
    targets: Vec<MutableBlock<Target, Locality, Metadata>>,
65
    /// The oneshot sender that optionally returns the registered blocks once the transfer is complete.
Ryan Olson's avatar
Ryan Olson committed
66
    completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
67
    /// The target pool that will receive the registered block.
Ryan Olson's avatar
Ryan Olson committed
68
    target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
69
70
}

Ryan Olson's avatar
Ryan Olson committed
71
72
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
    PendingTransfer<Source, Target, Locality, Metadata>
73
74
{
    pub fn new(
Ryan Olson's avatar
Ryan Olson committed
75
76
77
78
        sources: Vec<ImmutableBlock<Source, Locality, Metadata>>,
        targets: Vec<MutableBlock<Target, Locality, Metadata>>,
        completion_indicator: Option<oneshot::Sender<BlockResult<Target, Locality, Metadata>>>,
        target_pool: Arc<dyn BlockPool<Target, Locality, Metadata>>,
79
    ) -> Self {
80
        assert_eq!(sources.len(), targets.len());
81
        Self {
82
83
84
            sources,
            targets,
            completion_indicator,
85
            target_pool,
86
87
88
        }
    }

Ryan Olson's avatar
Ryan Olson committed
89
    async fn handle_complete(self) -> Result<()> {
90
        let Self {
91
92
93
            sources,
            mut targets,
            target_pool,
94
            completion_indicator,
95
96
97
            ..
        } = self;

98
99
100
101
        for (source, target) in sources.iter().zip(targets.iter_mut()) {
            transfer_metadata(source, target)?;
        }

Ryan Olson's avatar
Ryan Olson committed
102
103
104
        let blocks = target_pool.register_blocks(targets).await?;

        tracing::debug!("Transfer complete. Registered {} blocks.", blocks.len());
105

106
        if let Some(completion_indicator) = completion_indicator {
107
108
109
            completion_indicator
                .send(Ok(blocks))
                .map_err(|_| BlockPoolError::ProgressEngineShutdown)?;
110
        }
111
112
113
114
115

        Ok(())
    }
}

Ryan Olson's avatar
Ryan Olson committed
116
117
118
119
120
121
122
123
/// Copies block-level state from a registered `source` block onto `target`.
///
/// The target is first reset. It then receives a clone of the source's metadata and a clone of
/// the token block held by the source's registration handle.
///
/// # Errors
///
/// Returns `BlockPoolError::BlockError(BlockError::InvalidState(..))` if `source` is not in the
/// `Registered` state. Any error from `apply_token_block` is propagated.
fn transfer_metadata<Source, Target, Locality, Metadata>(
    source: &ImmutableBlock<Source, Locality, Metadata>,
    target: &mut MutableBlock<Target, Locality, Metadata>,
) -> Result<()>
where
    Source: Storage,
    Target: Storage,
    Locality: LocalityProvider,
    Metadata: BlockMetadata,
{
    // Only registered blocks can be transferred. Upstream checks enforce this, so the error
    // branch should not be hit in practice.
    let BlockState::Registered(reg_handle, _) = source.state() else {
        return Err(BlockPoolError::BlockError(BlockError::InvalidState(
            "Block is not registered.".to_string(),
        ))
        .into());
    };

    target.reset();
    target.update_metadata(source.metadata().clone());
    target.apply_token_block(reg_handle.token_block().clone())?;

    Ok(())
}

#[async_trait]
Ryan Olson's avatar
Ryan Olson committed
143
144
145
146
147
148
pub trait TransferManager<
    Source: Storage,
    Target: Storage,
    Locality: LocalityProvider,
    Metadata: BlockMetadata,
>: Send + Sync
149
150
{
    /// Begin a transfer. Blocks if the pending queue is full.
151
    async fn enqueue_transfer(
152
        &self,
Ryan Olson's avatar
Ryan Olson committed
153
        pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
154
    ) -> Result<()>;
155
156
}

Ryan Olson's avatar
Ryan Olson committed
157
158
159
160
161
162
163
164
/// Finalizes completed transfers for the local transfer worker and counts the blocks it has
/// been handed.
struct TransferCompletionManager<Source, Target, Locality, Metadata>
where
    Source: Storage,
    Target: Storage,
    Locality: LocalityProvider,
    Metadata: BlockMetadata,
{
    /// Running total of source blocks across all transfers passed to `handle_complete`.
    num_blocks_transferred: usize,
    /// Binds the generic parameters to the type without storing values of them.
    _phantom: PhantomData<(Source, Target, Locality, Metadata)>,
}

Ryan Olson's avatar
Ryan Olson committed
167
168
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
    TransferCompletionManager<Source, Target, Locality, Metadata>
169
{
170
    pub fn new() -> Self {
Ryan Olson's avatar
Ryan Olson committed
171
172
173
174
175
        Self {
            num_blocks_transferred: 0,
            _phantom: PhantomData,
        }
    }
176

Ryan Olson's avatar
Ryan Olson committed
177
178
179
180
181
    pub async fn handle_complete(
        &mut self,
        pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
    ) -> Result<()> {
        self.num_blocks_transferred += pending_transfer.sources.len();
182

Ryan Olson's avatar
Ryan Olson committed
183
184
185
186
187
188
189
190
        match pending_transfer.handle_complete().await {
            Ok(_) => {}
            Err(e) => {
                // The only case where this can fail is if the progress engine is being shutdown.
                // This is not a problem, so we can just ignore it.
                tracing::warn!("Error handling transfer completion: {:?}", e);
            }
        }
191
192
193
194
195

        Ok(())
    }
}

Ryan Olson's avatar
Ryan Olson committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
/// A boxed, pinned future that resolves to its [`PendingTransfer`] once the underlying copy has
/// been awaited, so the worker can finalize it.
///
/// These futures are sent over the local manager's `mpsc` channel and held in a
/// `FuturesUnordered` by its worker. The trait object is `Send + Sync`.
type TransferFuture<Source, Target, Locality, Metadata> = Pin<
    Box<
        dyn std::future::Future<Output = PendingTransfer<Source, Target, Locality, Metadata>>
            + Send
            + Sync,
    >,
>;

pub struct LocalTransferManager<
    Source: Storage,
    Target: Storage,
    Locality: LocalityProvider,
    Metadata: BlockMetadata,
> {
    futures_tx: mpsc::Sender<TransferFuture<Source, Target, Locality, Metadata>>,
211
212
213
    transfer_ctx: Arc<TransferContext>,
}

Ryan Olson's avatar
Ryan Olson committed
214
215
216
impl<Source: Storage, Target: Storage, Locality: LocalityProvider, Metadata: BlockMetadata>
    LocalTransferManager<Source, Target, Locality, Metadata>
{
217
218
219
220
221
    pub fn new(
        transfer_ctx: Arc<TransferContext>,
        max_concurrent_transfers: usize,
        runtime: &Handle,
        cancellation_token: CancellationToken,
222
    ) -> Result<Self> {
223
224
        let (futures_tx, mut futures_rx) = mpsc::channel(1);

225
        let mut completion_manager = TransferCompletionManager::new();
Ryan Olson's avatar
Ryan Olson committed
226

227
228
        CriticalTaskExecutionHandle::new_with_runtime(
            move |cancel_token| async move {
Ryan Olson's avatar
Ryan Olson committed
229
                let mut pending_transfers: FuturesUnordered<TransferFuture<Source, Target, Locality, Metadata>> = FuturesUnordered::new();
230
231
                loop {
                    tokio::select! {
232

233
234
235
                        _ = cancel_token.cancelled() => {
                            return Ok(());
                        }
236

237
238
239
                        Some(future) = futures_rx.recv() => {
                            // If we're at max size, block the worker thread on the next() call until we have capacity.
                            while pending_transfers.len() >= max_concurrent_transfers {
Ryan Olson's avatar
Ryan Olson committed
240
241
242
243
244
                                if let Some(pending_transfer) = pending_transfers.next().await {
                                    completion_manager.handle_complete(pending_transfer).await?;
                                } else {
                                    break;
                                }
245
                            }
Ryan Olson's avatar
Ryan Olson committed
246

247
248
                            pending_transfers.push(future);
                        }
Ryan Olson's avatar
Ryan Olson committed
249
250
                        Some(pending_transfer) = pending_transfers.next(), if !pending_transfers.is_empty() => {
                            completion_manager.handle_complete(pending_transfer).await?;
251
252
253
                        }
                    }
                }
254
255
            },
            cancellation_token.clone(),
Ryan Olson's avatar
Ryan Olson committed
256
            "Local Transfer Manager",
257
258
259
260
261
            runtime,
        )?
        .detach();

        Ok(Self {
262
263
            futures_tx,
            transfer_ctx,
264
        })
265
266
267
268
    }
}

#[async_trait]
Ryan Olson's avatar
Ryan Olson committed
269
270
impl<Source, Target, Locality, Metadata> TransferManager<Source, Target, Locality, Metadata>
    for LocalTransferManager<Source, Target, Locality, Metadata>
271
where
Ryan Olson's avatar
Ryan Olson committed
272
273
274
    Source: Storage + NixlDescriptor,
    Target: Storage + NixlDescriptor,
    Locality: LocalityProvider,
275
276
    Metadata: BlockMetadata,
    // Check that the source block is readable, local, and writable to the target block.
Ryan Olson's avatar
Ryan Olson committed
277
    ImmutableBlock<Source, Locality, Metadata>: ReadableBlock<StorageType = Source>
278
        + Local
Ryan Olson's avatar
Ryan Olson committed
279
        + WriteToStrategy<MutableBlock<Target, Locality, Metadata>>,
280
    // Check that the target block is writable.
Ryan Olson's avatar
Ryan Olson committed
281
282
283
284
    MutableBlock<Target, Locality, Metadata>: WritableBlock<StorageType = Target>,
    // Check that the source and target blocks have the same locality.
    ImmutableBlock<Source, Locality, Metadata>: BlockDataProvider<Locality = Locality>,
    MutableBlock<Target, Locality, Metadata>: BlockDataProviderMut<Locality = Locality>,
285
{
286
    async fn enqueue_transfer(
287
        &self,
Ryan Olson's avatar
Ryan Olson committed
288
        mut pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
289
    ) -> Result<()> {
290
291
        let notify = pending_transfer
            .sources
Ryan Olson's avatar
Ryan Olson committed
292
            .write_to(&mut pending_transfer.targets, self.transfer_ctx.clone())?;
293
294

        let completion_future = async move {
295
            let _ = notify.await;
Ryan Olson's avatar
Ryan Olson committed
296
            pending_transfer
297
298
299
300
301
        };

        // Futures_(tx/rx) has a capacity of 1. If the queue worker has received another future and is awaiting next() due to a full `FuturesUnordered`,
        // this call will block until the worker has processed the prior future.
        self.futures_tx.send(Box::pin(completion_future)).await?;
302
303
304
305

        Ok(())
    }
}
306
307

/// A transfer manager that enforces a max batch size for transfers.
Ryan Olson's avatar
Ryan Olson committed
308
pub struct TransferBatcher<Source, Target, Locality, Metadata, Manager>
309
310
311
where
    Source: Storage,
    Target: Storage,
Ryan Olson's avatar
Ryan Olson committed
312
    Locality: LocalityProvider,
313
    Metadata: BlockMetadata,
Ryan Olson's avatar
Ryan Olson committed
314
    Manager: TransferManager<Source, Target, Locality, Metadata>,
315
316
317
{
    transfer_manager: Manager,
    max_transfer_batch_size: usize,
318
319
    runtime: Handle,
    cancellation_token: CancellationToken,
Ryan Olson's avatar
Ryan Olson committed
320
    _phantom: PhantomData<(Source, Target, Locality, Metadata)>,
321
322
}

Ryan Olson's avatar
Ryan Olson committed
323
324
impl<Source, Target, Locality, Metadata, Manager>
    TransferBatcher<Source, Target, Locality, Metadata, Manager>
325
326
327
where
    Source: Storage,
    Target: Storage,
Ryan Olson's avatar
Ryan Olson committed
328
329
330
    Locality: LocalityProvider + 'static,
    Metadata: BlockMetadata + 'static,
    Manager: TransferManager<Source, Target, Locality, Metadata> + 'static,
331
{
332
333
334
335
336
337
    pub fn new(
        transfer_manager: Manager,
        max_transfer_batch_size: usize,
        runtime: &Handle,
        cancellation_token: CancellationToken,
    ) -> Self {
338
339
340
        Self {
            transfer_manager,
            max_transfer_batch_size,
341
342
            runtime: runtime.clone(),
            cancellation_token,
343
344
345
346
347
348
            _phantom: PhantomData,
        }
    }
}

#[async_trait]
Ryan Olson's avatar
Ryan Olson committed
349
350
351
impl<Source, Target, Locality, Metadata, Manager>
    TransferManager<Source, Target, Locality, Metadata>
    for TransferBatcher<Source, Target, Locality, Metadata, Manager>
352
where
Ryan Olson's avatar
Ryan Olson committed
353
354
355
    Source: Storage + 'static,
    Target: Storage + 'static,
    Locality: LocalityProvider + 'static,
356
    Metadata: BlockMetadata,
Ryan Olson's avatar
Ryan Olson committed
357
    Manager: TransferManager<Source, Target, Locality, Metadata>,
358
359
360
{
    async fn enqueue_transfer(
        &self,
Ryan Olson's avatar
Ryan Olson committed
361
        pending_transfer: PendingTransfer<Source, Target, Locality, Metadata>,
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
    ) -> Result<()> {
        // If it's smaller than the max batch size, just enqueue it.
        if pending_transfer.sources.len() < self.max_transfer_batch_size {
            return self
                .transfer_manager
                .enqueue_transfer(pending_transfer)
                .await;
        }

        // Otherwise, we need to split the transfer into multiple smaller transfers.

        let PendingTransfer {
            mut sources,
            mut targets,
            completion_indicator,
            target_pool,
        } = pending_transfer;

        let mut indicators = Vec::new();

        while !sources.is_empty() {
            let sources = sources
                .drain(..std::cmp::min(self.max_transfer_batch_size, sources.len()))
                .collect();
            let targets = targets
                .drain(..std::cmp::min(self.max_transfer_batch_size, targets.len()))
                .collect();

            // If we have a completion indicator, we need to create a new one for each sub-transfer.
            let indicator = if completion_indicator.is_some() {
                let (batch_tx, batch_rx) = oneshot::channel();
                indicators.push(batch_rx);
                Some(batch_tx)
            } else {
                None
            };

            let request = PendingTransfer::new(sources, targets, indicator, target_pool.clone());
            // Enqueue our reduced transfer. This may block if the queue is full.
            self.transfer_manager.enqueue_transfer(request).await?;
        }

        if let Some(completion_indicator) = completion_indicator {
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
            CriticalTaskExecutionHandle::new_with_runtime(
                move |cancel_token| async move {
                    let mut results = Vec::new();

                    for indicator in indicators.into_iter() {
                        // Await each sub-transfer, and append the results to our final results.
                        tokio::select! {
                            _ = cancel_token.cancelled() => {
                                return Ok(());
                            }

                            Ok(indicator) = indicator => {
                                let result = match indicator {
                                    Ok(result) => result,
                                    Err(e) => {
                                        tracing::error!("Error receiving transfer results: {:?}", e);
Ryan Olson's avatar
Ryan Olson committed
421
                                        let _ = completion_indicator.send(Err(e));
422
423
424
425
426
                                        return Ok(());
                                    }
                                };
                                results.extend(result);
                            }
427
                        }
428
429
430
                    }

                    // Send the final results to the top-level completion indicator.
Ryan Olson's avatar
Ryan Olson committed
431
                    let _ = completion_indicator.send(Ok(results));
432

433
434
435
436
437
438
                    Ok(())
                },
                self.cancellation_token.clone(),
                "Transfer Batcher",
                &self.runtime,
            )?.detach();
439
440
441
442
443
        }

        Ok(())
    }
}