offload.rs 89.4 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
8
//! # Offload Manager
//! The offload manager is responsible for handling all block transfers between different cache levels.
//!
//! ## Offloading
//! Offloading is the process of moving blocks to a cache level further away from the device.
//! When blocks are registered (via [`ManagedBlockPool::register_blocks`]), they are automatically sent to the offload manager.
//! Due to limited bandwidth, the offload manager must prioritize which offloads to perform.
//! This is indicated by the `priority` parameter to [`OffloadManager::offload`].
//! When an offload request is received, the offload manager will enqueue it into a priority queue.
//! This priority queue is keyed by the `priority` parameter, where blocks with lower priority values are processed first.
//! Within the same priority, blocks that were sent to the offload manager earlier are processed first.
//! When `bypass_cpu_mem` is set in [`OffloadManagerConfig`] and a disk pool exists, device blocks are
//! offloaded directly to disk instead of to host memory.
//!
//! ## Onboarding
//! Onboarding is the process of moving blocks to a cache level closer to the device.
//! All onboardings are manually triggered through the [`OffloadManager::onboard`] method.
//!
//! ## Transfer Managers
//! Each transfer path (Device -> Host, Host -> Disk, Host -> Device, Disk -> Device, and the optional
//! direct Device -> Disk path) is driven by its own [`LocalTransferManager`] wrapped in a [`TransferBatcher`].
//!
//! The Device -> Host offload path uses a dedicated CUDA stream so that device offloads can run in
//! parallel with host onboards; the remaining paths share a second stream.
//!
//! ## Worker Tasks
//! The offload manager uses two kinds of asynchronous worker tasks to handle the offloading and onboarding of blocks.
//!
//! The [`OffloadManager::offload_worker`] is responsible for offloading blocks.
//! The [`OffloadManager::onboard_worker`] is responsible for onboarding blocks.
//!
//! The kind of offloads/onboards they perform is dictated by the source and target arguments
//! of the [`OffloadManager::offload_worker`] and [`OffloadManager::onboard_worker`] methods.

Ryan Olson's avatar
Ryan Olson committed
35
use super::block::{
36
    BlockError, BlockMetadata, BlockState, ImmutableBlock, MutableBlock,
37
38
    locality::LocalityProvider,
    transfer::{PoolConfig, TransferContext},
Ryan Olson's avatar
Ryan Olson committed
39
40
};
use super::pool::{BlockPool, BlockPoolError};
41
use super::storage::{Cuda, Storage};
42
use super::{DeviceStorage, DiskStorage, KvManagerModelConfig, PinnedStorage};
43
use nixl_sys::Agent as NixlAgent;
44
45
46
47
use std::sync::{
    Arc,
    atomic::{AtomicU64, Ordering},
};
48
49
50
use tokio::runtime::Handle;
use tokio::sync::{
    mpsc::{self, error::TryRecvError},
51
    oneshot,
52
};
53
use tokio_util::sync::CancellationToken;
54
55
56
57
58
59

use anyhow::Result;
use std::any::Any;

use std::collections::BTreeSet;

60
pub mod filter;
61
mod pending;
62
pub mod request;
63

64
use filter::OffloadFilter;
Ryan Olson's avatar
Ryan Olson committed
65
use pending::{LocalTransferManager, PendingTransfer, TransferBatcher, TransferManager};
66
use request::{BlockResult, OffloadRequest, OffloadRequestKey, OnboardRequest};
67

68
69
use derive_builder::Builder;
use derive_getters::Getters;
70
71
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;

72
73
74
75
76
77
78
79
80
/// Concurrency limit passed to every `LocalTransferManager` created by
/// `OffloadManager::new`, and used as `PoolConfig::max_concurrent_transfers`
/// for the Device -> Host transfer context.
pub const MAX_CONCURRENT_TRANSFERS: usize = 4;
/// Batch size passed to every `TransferBatcher` created by `OffloadManager::new`,
/// and used as `PoolConfig::max_transfer_batch_size` for the Device -> Host
/// transfer context.
pub const MAX_TRANSFER_BATCH_SIZE: usize = 16;

/// Configuration for creating an OffloadManager
pub struct OffloadManagerConfig {
    pub nixl_agent: Arc<Option<NixlAgent>>,
    pub async_rt_handle: Handle,
    pub cancellation_token: CancellationToken,
    pub model_config: KvManagerModelConfig,
81
82
    /// Optional KVBM-level metrics for tracking offload/onboard operations
    pub kvbm_metrics: Option<crate::block_manager::metrics_kvbm::KvbmMetrics>,
83
84
    /// If true, offload directly from device (G1) to disk (G3), bypassing host (G2)
    pub bypass_cpu_mem: bool,
85
}
86
87

/// The offload manager handles all block transfers between different cache levels.
Ryan Olson's avatar
Ryan Olson committed
88
pub struct OffloadManager<Locality: LocalityProvider, Metadata: BlockMetadata> {
89
    // Handles to the device, host, and disk pools.
Ryan Olson's avatar
Ryan Olson committed
90
91
92
    disk: Option<Arc<dyn BlockPool<DiskStorage, Locality, Metadata>>>,
    host: Option<Arc<dyn BlockPool<PinnedStorage, Locality, Metadata>>>,
    device: Option<Arc<dyn BlockPool<DeviceStorage, Locality, Metadata>>>,
93

94
    /// Queue of offloading requests.
Ryan Olson's avatar
Ryan Olson committed
95
96
    device_offload_tx: mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Locality, Metadata>>,
    host_offload_tx: mpsc::UnboundedSender<OffloadRequest<PinnedStorage, Locality, Metadata>>,
97

98
99
100
101
    /// Queue of device-to-disk direct offloading requests (bypass CPU memory)
    device_to_disk_offload_tx:
        mpsc::UnboundedSender<OffloadRequest<DeviceStorage, Locality, Metadata>>,

102
    /// Queue of pending onboarding requests.
Ryan Olson's avatar
Ryan Olson committed
103
104
105
106
    host_onboard_tx:
        mpsc::UnboundedSender<OnboardRequest<PinnedStorage, DeviceStorage, Locality, Metadata>>,
    disk_onboard_tx:
        mpsc::UnboundedSender<OnboardRequest<DiskStorage, DeviceStorage, Locality, Metadata>>,
107
108

    /// An incrementing counter for offloaded blocks. Within the same priority, blocks with lower tick values are processed first.
109
    tick: Arc<AtomicU64>,
110
111
112

    /// If true, offload directly from device (G1) to disk (G3), bypassing host (G2)
    bypass_cpu_mem: bool,
113
114
}

Ryan Olson's avatar
Ryan Olson committed
115
116
117
impl<Locality: LocalityProvider + 'static, Metadata: BlockMetadata>
    OffloadManager<Locality, Metadata>
{
118
    #[allow(clippy::too_many_arguments)]
119
    pub fn new(
Ryan Olson's avatar
Ryan Olson committed
120
121
122
        disk: Option<Arc<dyn BlockPool<DiskStorage, Locality, Metadata>>>,
        host: Option<Arc<dyn BlockPool<PinnedStorage, Locality, Metadata>>>,
        device: Option<Arc<dyn BlockPool<DeviceStorage, Locality, Metadata>>>,
123
        filters: OffloadFilters,
124
        config: OffloadManagerConfig,
125
    ) -> Result<Arc<Self>> {
126
127
        let (device_offload_tx, device_offload_rx) = mpsc::unbounded_channel();
        let (host_offload_tx, host_offload_rx) = mpsc::unbounded_channel();
128
        let (device_to_disk_offload_tx, device_to_disk_offload_rx) = mpsc::unbounded_channel();
129
130
131

        let (host_onboard_tx, host_onboard_rx) = mpsc::unbounded_channel();
        let (disk_onboard_tx, disk_onboard_rx) = mpsc::unbounded_channel();
132
133

        let this = Arc::new(Self {
134
            disk,
135
            host,
136
137
138
            device,
            device_offload_tx,
            host_offload_tx,
139
            device_to_disk_offload_tx,
140
141
            host_onboard_tx,
            disk_onboard_tx,
142
            tick: Arc::new(AtomicU64::new(0)),
143
            bypass_cpu_mem: config.bypass_cpu_mem,
144
145
        });

146
        let cuda_ctx = Cuda::device_or_create(0)?;
147

148
149
150
151
152
153
154
155
        let pool_config = PoolConfig {
            enable_pool: true,
            max_concurrent_transfers: MAX_CONCURRENT_TRANSFERS,
            max_transfer_batch_size: MAX_TRANSFER_BATCH_SIZE,
            num_outer_components: config.model_config.outer_dim,
            num_layers: config.model_config.num_layers,
        };

156
157
        // We want cuda offloads to happen in parallel with host onboards, so we need to use a different stream.
        let device_offload_transfer_ctx = Arc::new(TransferContext::new(
158
            config.nixl_agent.clone(),
159
            cuda_ctx.new_stream()?,
160
161
            config.async_rt_handle.clone(),
            Some(pool_config),
162
        ));
163

164
        // Device -> Host offload
165
166
167
168
169
        let device_to_host_task = OffloadManager::offload_worker(
            this.device.clone(),
            this.host.clone(),
            device_offload_rx,
            Arc::new(TransferBatcher::new(
Ryan Olson's avatar
Ryan Olson committed
170
                LocalTransferManager::new(
171
172
                    device_offload_transfer_ctx,
                    MAX_CONCURRENT_TRANSFERS,
173
174
                    &config.async_rt_handle,
                    config.cancellation_token.clone(),
175
                )?,
176
                MAX_TRANSFER_BATCH_SIZE,
177
178
                &config.async_rt_handle,
                config.cancellation_token.clone(),
179
            )),
180
            filters.device.clone(),
181
182
183
184
            config
                .kvbm_metrics
                .as_ref()
                .map(|m| m.offload_blocks_d2h.clone()),
185
            config.cancellation_token.clone(),
186
187
188
        );
        CriticalTaskExecutionHandle::new_with_runtime(
            |_| device_to_host_task,
189
            config.cancellation_token.clone(),
190
            "Device -> Host offload worker",
191
            &config.async_rt_handle,
192
193
        )?
        .detach();
194

195
        let transfer_ctx = Arc::new(TransferContext::new(
196
            config.nixl_agent.clone(),
197
            cuda_ctx.new_stream()?,
198
199
            config.async_rt_handle.clone(),
            None,
200
        ));
201

202
        // Host -> Disk offload
203
204
205
206
207
        let host_to_disk_task = OffloadManager::offload_worker(
            this.host.clone(),
            this.disk.clone(),
            host_offload_rx,
            Arc::new(TransferBatcher::new(
Ryan Olson's avatar
Ryan Olson committed
208
                LocalTransferManager::new(
209
210
                    transfer_ctx.clone(),
                    MAX_CONCURRENT_TRANSFERS,
211
212
                    &config.async_rt_handle,
                    config.cancellation_token.clone(),
213
                )?,
214
                MAX_TRANSFER_BATCH_SIZE,
215
216
                &config.async_rt_handle,
                config.cancellation_token.clone(),
217
            )),
218
            filters.host.clone(),
219
220
221
222
            config
                .kvbm_metrics
                .as_ref()
                .map(|m| m.offload_blocks_h2d.clone()),
223
            config.cancellation_token.clone(),
224
225
226
        );
        CriticalTaskExecutionHandle::new_with_runtime(
            |_| host_to_disk_task,
227
            config.cancellation_token.clone(),
228
            "Host -> Disk offload worker",
229
            &config.async_rt_handle,
230
231
        )?
        .detach();
232

233
        // Host -> Device onboarding
234
235
236
237
238
        let host_to_device_task = OffloadManager::onboard_worker(
            this.host.clone(),
            this.device.clone(),
            host_onboard_rx,
            Arc::new(TransferBatcher::new(
Ryan Olson's avatar
Ryan Olson committed
239
                LocalTransferManager::new(
240
241
                    transfer_ctx.clone(),
                    MAX_CONCURRENT_TRANSFERS,
242
243
                    &config.async_rt_handle,
                    config.cancellation_token.clone(),
244
                )?,
245
                MAX_TRANSFER_BATCH_SIZE,
246
247
                &config.async_rt_handle,
                config.cancellation_token.clone(),
248
            )),
249
            config.cancellation_token.clone(),
250
251
252
        );
        CriticalTaskExecutionHandle::new_with_runtime(
            |_| host_to_device_task,
253
            config.cancellation_token.clone(),
254
            "Host -> Device onboarding worker",
255
            &config.async_rt_handle,
256
257
        )?
        .detach();
258

259
        // Disk -> Device onboarding
260
261
262
263
264
        let disk_to_device_task = OffloadManager::onboard_worker(
            this.disk.clone(),
            this.device.clone(),
            disk_onboard_rx,
            Arc::new(TransferBatcher::new(
Ryan Olson's avatar
Ryan Olson committed
265
                LocalTransferManager::new(
266
267
                    transfer_ctx.clone(),
                    MAX_CONCURRENT_TRANSFERS,
268
269
                    &config.async_rt_handle,
                    config.cancellation_token.clone(),
270
                )?,
271
                MAX_TRANSFER_BATCH_SIZE,
272
273
                &config.async_rt_handle,
                config.cancellation_token.clone(),
274
            )),
275
            config.cancellation_token.clone(),
276
277
278
        );
        CriticalTaskExecutionHandle::new_with_runtime(
            |_| disk_to_device_task,
279
            config.cancellation_token.clone(),
280
            "Disk -> Device onboarding worker",
281
            &config.async_rt_handle,
282
283
        )?
        .detach();
284

285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
        // Device -> Disk direct offload (bypass CPU memory)
        if config.bypass_cpu_mem {
            tracing::info!(
                "G1->G3 direct offload enabled: Device will offload directly to Disk, bypassing Host memory (CPU cache disabled)"
            );

            let device_to_disk_task = OffloadManager::offload_worker(
                this.device.clone(),
                this.disk.clone(),
                device_to_disk_offload_rx,
                Arc::new(TransferBatcher::new(
                    LocalTransferManager::new(
                        transfer_ctx.clone(),
                        MAX_CONCURRENT_TRANSFERS,
                        &config.async_rt_handle,
                        config.cancellation_token.clone(),
                    )?,
                    MAX_TRANSFER_BATCH_SIZE,
                    &config.async_rt_handle,
                    config.cancellation_token.clone(),
                )),
                filters.device.clone(),
                config
                    .kvbm_metrics
                    .as_ref()
                    .map(|m| m.offload_blocks_d2d.clone()),
                config.cancellation_token.clone(),
            );
            CriticalTaskExecutionHandle::new_with_runtime(
                |_| device_to_disk_task,
                config.cancellation_token.clone(),
                "Device -> Disk direct offload worker (bypass CPU)",
                &config.async_rt_handle,
            )?
            .detach();
        }

322
        Ok(this)
323
    }
324

325
    async fn offload_worker<Source: Storage, Target: Storage>(
Ryan Olson's avatar
Ryan Olson committed
326
327
328
329
        source_pool: Option<Arc<dyn BlockPool<Source, Locality, Metadata>>>,
        target_pool: Option<Arc<dyn BlockPool<Target, Locality, Metadata>>>,
        mut offload_rx: mpsc::UnboundedReceiver<OffloadRequest<Source, Locality, Metadata>>,
        transfer_manager: Arc<dyn TransferManager<Source, Target, Locality, Metadata>>,
330
        offload_filter: Option<Arc<dyn OffloadFilter>>,
331
        offload_metric: Option<prometheus::IntCounter>,
332
        cancellation_token: CancellationToken,
333
    ) -> Result<()> {
334
        if source_pool.is_none() || target_pool.is_none() {
335
336
337
            return Ok(());
        }

338
339
        let source_pool = source_pool.as_ref().unwrap();
        let target_pool = target_pool.as_ref().unwrap();
340

341
        let mut queue = BTreeSet::new();
342
343

        loop {
344
345
346
347
            if cancellation_token.is_cancelled() {
                return Ok(());
            }

348
            // Try to check the offload queue.
349
350
351
352
353
354
355
356
            loop {
                match offload_rx.try_recv() {
                    Ok(request) => {
                        queue.insert(request);
                    }
                    Err(TryRecvError::Empty) => {
                        break;
                    }
357
                    Err(e) => return Err(e.into()),
358
359
                }
            }
360
361

            // If there is a request, process it.
362
            if let Some(request) = queue.pop_first() {
363
364
                // Try to upgrade the block to a strong reference.
                let block = match request.block.upgrade() {
Ryan Olson's avatar
Ryan Olson committed
365
                    Some(block) => Some(ImmutableBlock::new(block)),
366
                    // If unable to upgrade, the block may have been moved to the inactive pool.
367
                    None => source_pool
368
369
                        .match_sequence_hashes(vec![request.sequence_hash].as_slice())
                        .await?
Ryan Olson's avatar
Ryan Olson committed
370
                        .pop(),
371
372
                };

373
                // If we've found the block, offload it.
374
                if let Some(block) = block {
375
376
                    // If the block is already in the target, don't offload it.
                    if let Ok(blocks) = target_pool
Ryan Olson's avatar
Ryan Olson committed
377
378
                        .match_sequence_hashes(vec![request.sequence_hash].as_slice())
                        .await
379
                        && !blocks.is_empty()
380
                    {
381
                        continue;
382
383
                    }

384
385
386
387
388
389
                    if let Some(offload_filter) = offload_filter.as_ref()
                        && !offload_filter.should_offload(request.sequence_hash)
                    {
                        continue;
                    }

390
                    let target_block = 'target_block: {
391
392
393
394
                        if let Ok(blocks) = target_pool.allocate_blocks(1).await
                            && let Some(block) = blocks.into_iter().next()
                        {
                            break 'target_block Some(block);
395
                        }
396

397
398
399
                        tracing::warn!(
                            "Target pool full. Skipping offload. This should only ever happen with very small pool sizes."
                        );
400
                        None
401
402
                    };

403
                    if let Some(target_block) = target_block {
Ryan Olson's avatar
Ryan Olson committed
404
405
406
407
                        tracing::debug!(
                            "Offloading block with sequence hash {} to target pool.",
                            request.sequence_hash
                        );
408
409
410
411
412
413

                        // Track the offload metric if available
                        if let Some(ref metric) = offload_metric {
                            metric.inc();
                        }

414
                        transfer_manager
415
                            .enqueue_transfer(PendingTransfer::new(
416
                                vec![block],
417
                                vec![target_block],
418
                                None,
419
                                target_pool.clone(),
420
421
422
423
424
                            ))
                            .await?;
                    }
                }
            } else {
425
                // Await the next request.
426
427
428
429
430
                tokio::select! {
                    _ = cancellation_token.cancelled() => return Ok(()),
                    Some(request) = offload_rx.recv() => {
                        queue.insert(request);
                    }
431
                }
432
433
434
435
            }
        }
    }

436
    async fn onboard_worker<Source: Storage, Target: Storage>(
Ryan Olson's avatar
Ryan Olson committed
437
438
439
440
        source_pool: Option<Arc<dyn BlockPool<Source, Locality, Metadata>>>,
        target_pool: Option<Arc<dyn BlockPool<Target, Locality, Metadata>>>,
        mut onboard_rx: mpsc::UnboundedReceiver<OnboardRequest<Source, Target, Locality, Metadata>>,
        transfer_manager: Arc<dyn TransferManager<Source, Target, Locality, Metadata>>,
441
        cancellation_token: CancellationToken,
442
    ) -> Result<()> {
443
        if source_pool.is_none() || target_pool.is_none() {
444
445
446
            return Ok(());
        }

447
        let target_pool = target_pool.as_ref().unwrap();
448
449
450
451
        loop {
            tokio::select! {
                _ = cancellation_token.cancelled() => return Ok::<(), anyhow::Error>(()),
                Some(request) = onboard_rx.recv() => {
452

453
                    // Try to allocate blocks on the device.
Ryan Olson's avatar
Ryan Olson committed
454
455
456
457
458
459
460
461
462
                    let target_blocks = if let Some(targets) = request.targets {
                        targets
                    } else {
                            match target_pool.allocate_blocks(request.blocks.len()).await {
                            Ok(blocks) => blocks,
                            Err(err) => {
                                let _ = request.response_tx.send(Err(err));
                                continue;
                            }
463
464
                        }
                    };
465

Ryan Olson's avatar
Ryan Olson committed
466
                    tracing::debug!("Onboarding {} blocks to target pool.", request.blocks.len());
467
468
469

                    transfer_manager
                        .enqueue_transfer(PendingTransfer::new(
Ryan Olson's avatar
Ryan Olson committed
470
                            request.blocks,
471
472
473
474
475
476
477
                            target_blocks,
                            Some(request.response_tx),
                            target_pool.clone(),
                        ))
                        .await?;

                    Ok::<(), anyhow::Error>(())
478
                }
479
            }?;
480
481
482
483
484
        }
    }

    pub async fn offload<S: Storage>(
        &self,
Ryan Olson's avatar
Ryan Olson committed
485
        block: &ImmutableBlock<S, Locality, Metadata>,
486
487
488
        priority: u64,
    ) -> core::result::Result<(), BlockPoolError> {
        match block.state() {
489
            BlockState::Registered(_, _) => {}
490
491
492
493
494
495
            _ => {
                return Err(BlockPoolError::BlockError(BlockError::InvalidState(
                    "Block is not registered.".to_string(),
                )));
            }
        }
496

497
        let tick = self.tick.fetch_add(1, Ordering::Relaxed);
498
499
        let key = OffloadRequestKey {
            priority,
500
            timestamp: tick,
501
502
        };

503
504
505
506
507
508
        // This can get called by all pools, regardless of whether or not they have a place to offload to.
        // Because of this, we need to check the block type here.
        let any_block = block as &dyn Any;

        // TODO: What's the performance penalty of this runtime type-checking?
        if let Some(device_block) =
Ryan Olson's avatar
Ryan Olson committed
509
            any_block.downcast_ref::<ImmutableBlock<DeviceStorage, Locality, Metadata>>()
510
        {
511
512
513
514
515
516
            // Check if we should bypass CPU memory and go directly to disk
            if self.bypass_cpu_mem && self.disk.is_some() {
                // Offload directly from Device (G1) to Disk (G3), bypassing Host (G2)
                if self.device_to_disk_offload_tx.is_closed() {
                    return Ok(());
                }
517

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
                let request = OffloadRequest {
                    block: Arc::downgrade(device_block.mutable_block()),
                    sequence_hash: device_block.sequence_hash(),
                    key,
                };

                tracing::debug!(
                    "Offloading device block {} directly to disk (bypassing host memory)",
                    device_block.sequence_hash()
                );
                self.device_to_disk_offload_tx.send(request).unwrap();
            } else {
                // Standard path: Device (G1) -> Host (G2)
                if self.device_offload_tx.is_closed() {
                    return Ok(());
                }
534

535
536
537
538
539
540
541
542
                let request = OffloadRequest {
                    block: Arc::downgrade(device_block.mutable_block()),
                    sequence_hash: device_block.sequence_hash(),
                    key,
                };

                self.device_offload_tx.send(request).unwrap();
            }
543
        } else if let Some(host_block) =
Ryan Olson's avatar
Ryan Olson committed
544
            any_block.downcast_ref::<ImmutableBlock<PinnedStorage, Locality, Metadata>>()
545
        {
546
            // Host (G2) -> Disk (G3) offload
547
548
549
550
551
552
            if self.host_offload_tx.is_closed() {
                return Ok(());
            }

            let request = OffloadRequest {
                block: Arc::downgrade(host_block.mutable_block()),
Ryan Olson's avatar
Ryan Olson committed
553
                sequence_hash: host_block.sequence_hash(),
554
555
556
557
                key,
            };

            self.host_offload_tx.send(request).unwrap();
558
559
560
561
562
        }

        Ok(())
    }

Ryan Olson's avatar
Ryan Olson committed
563
    pub fn onboard<S: Storage>(
564
        &self,
Ryan Olson's avatar
Ryan Olson committed
565
566
567
568
        blocks: Vec<ImmutableBlock<S, Locality, Metadata>>,
        targets: Option<Vec<MutableBlock<DeviceStorage, Locality, Metadata>>>,
    ) -> oneshot::Receiver<BlockResult<DeviceStorage, Locality, Metadata>> {
        let (tx, rx) = oneshot::channel();
569
570
        for block in &blocks {
            match block.state() {
571
                BlockState::Registered(_, _) => {}
572
                _ => {
Ryan Olson's avatar
Ryan Olson committed
573
                    tx.send(Err(BlockPoolError::BlockError(BlockError::InvalidState(
574
                        "Block is not registered.".to_string(),
Ryan Olson's avatar
Ryan Olson committed
575
576
577
                    ))))
                    .unwrap();
                    return rx;
578
579
580
581
                }
            }
        }

582
583
584
585
586
587
588
589
        if let Some(targets) = targets.as_ref()
            && targets.len() != blocks.len()
        {
            tx.send(Err(BlockPoolError::BlockError(BlockError::Other(
                anyhow::anyhow!("Number of targets does not match number of blocks."),
            ))))
            .unwrap();
            return rx;
590
591
        }

Ryan Olson's avatar
Ryan Olson committed
592
593
594
595
        if blocks.is_empty() {
            tx.send(Ok(vec![])).unwrap();
            return rx;
        }
596

597
598
599
600
        let any_block = blocks.first().unwrap() as &dyn Any;

        // TODO: This is really ugly.
        if any_block
Ryan Olson's avatar
Ryan Olson committed
601
            .downcast_ref::<ImmutableBlock<PinnedStorage, Locality, Metadata>>()
602
603
604
605
606
607
            .is_some()
        {
            let host_blocks = blocks
                .iter()
                .map(|b| {
                    (b as &dyn Any)
Ryan Olson's avatar
Ryan Olson committed
608
                        .downcast_ref::<ImmutableBlock<PinnedStorage, Locality, Metadata>>()
609
610
611
612
613
                        .unwrap()
                        .clone()
                })
                .collect();

Ryan Olson's avatar
Ryan Olson committed
614
615
616
617
618
619
620
621
            if let Err(e) = self
                .host_onboard_tx
                .send(OnboardRequest::new(host_blocks, tx, targets))
            {
                e.0.response_tx
                    .send(Err(BlockPoolError::ProgressEngineShutdown))
                    .unwrap();
            }
622
        } else if any_block
Ryan Olson's avatar
Ryan Olson committed
623
            .downcast_ref::<ImmutableBlock<DiskStorage, Locality, Metadata>>()
624
625
626
627
628
629
            .is_some()
        {
            let disk_blocks = blocks
                .iter()
                .map(|b| {
                    (b as &dyn Any)
Ryan Olson's avatar
Ryan Olson committed
630
                        .downcast_ref::<ImmutableBlock<DiskStorage, Locality, Metadata>>()
631
632
633
634
635
                        .unwrap()
                        .clone()
                })
                .collect();

Ryan Olson's avatar
Ryan Olson committed
636
637
638
639
640
641
642
643
            if let Err(e) = self
                .disk_onboard_tx
                .send(OnboardRequest::new(disk_blocks, tx, targets))
            {
                e.0.response_tx
                    .send(Err(BlockPoolError::ProgressEngineShutdown))
                    .unwrap();
            }
644
        } else {
Ryan Olson's avatar
Ryan Olson committed
645
            tx.send(Err(BlockPoolError::BlockError(BlockError::Other(
646
                anyhow::anyhow!("Block type not supported for onboarding."),
Ryan Olson's avatar
Ryan Olson committed
647
648
            ))))
            .unwrap();
649
650
        }

Ryan Olson's avatar
Ryan Olson committed
651
        rx
652
653
654
    }
}

655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
/// Optional per-tier filters that decide which blocks get offloaded from a tier.
///
/// Construct via [`OffloadFilters::builder`]; every field defaults to `None`.
/// The builder's validation rejects a `disk` filter.
#[derive(Debug, Clone, Getters, Builder)]
#[builder(pattern = "owned", build_fn(validate = "Self::validate"))]
pub struct OffloadFilters {
    /// Filter consulted for offloads out of the device pool.
    #[builder(default)]
    device: Option<Arc<dyn OffloadFilter>>,
    /// Filter consulted for Host -> Disk offloads.
    #[builder(default)]
    host: Option<Arc<dyn OffloadFilter>>,
    /// Disk-tier filter; setting it makes the builder's validation fail.
    #[builder(default)]
    disk: Option<Arc<dyn OffloadFilter>>,
}

impl OffloadFilters {
    /// Returns a fresh builder with no filters set.
    pub fn builder() -> OffloadFiltersBuilder {
        <OffloadFiltersBuilder as Default>::default()
    }
}

impl OffloadFiltersBuilder {
    /// Validation hook run by the derived `build()`.
    ///
    /// # Errors
    /// Returns an error string if a disk offload filter was supplied (not supported).
    ///
    /// Logs a warning (but still succeeds) when no host-to-disk filter was supplied.
    pub fn validate(&self) -> Result<(), String> {
        // Builder fields are `Option<Option<_>>`: the outer `Some` means "setter called", the
        // inner `Some` means "a filter was actually provided".
        if matches!(self.disk, Some(Some(_))) {
            return Err("Disk offload filter is not supported.".to_string());
        }

        let has_host_filter = matches!(self.host, Some(Some(_)));
        if !has_host_filter {
            tracing::warn!(
                "Host to Disk offload filter is not provided. All blocks in host will be offloaded to disk. This may result in excessive disk offloading and accelerated SSD degradation."
            );
        }

        Ok(())
    }
}

696
#[cfg(all(test, feature = "testing-cuda", feature = "testing-nixl"))]
Ryan Olson's avatar
Ryan Olson committed
697
mod tests {
698
699
700
    use super::*;

    use crate::block_manager::{
701
        LayoutConfig, NixlRegisterableStorage,
702
        block::{
703
            BasicMetadata, BlockDataExt, BlockDataProvider, Blocks, MutableBlock, locality::Local,
704
        },
705
        layout::{FullyContiguous, LayerSeparate, LayoutType, nixl::NixlLayout},
Ryan Olson's avatar
Ryan Olson committed
706
        pool::{BlockRegistrationDuplicationSetting, ManagedBlockPool},
707
        storage::{
708
            DeviceAllocator, DeviceStorage, DiskAllocator, DiskStorage, PinnedAllocator,
Ryan Olson's avatar
Ryan Olson committed
709
            PinnedStorage, StorageAllocator, StorageType,
710
711
        },
    };
712
    use crate::tokens::{TokenBlockSequence, Tokens};
713
    use nixl_sys::{MemoryRegion, NixlDescriptor};
714

715
    use aligned_vec::avec;
716
    use cudarc::runtime::sys::{cudaDeviceSynchronize, cudaMemcpy, cudaMemcpyKind, cudaMemset};
Ryan Olson's avatar
Ryan Olson committed
717
    use rstest::*;
718
    use std::fs::File;
719
    use std::io::{Read, Seek, SeekFrom, Write};
720
721
    use std::mem::ManuallyDrop;
    use std::os::unix::io::FromRawFd;
722
723

    /// Tokens per block; also used as the layout page size for every test pool.
    const BLOCK_SIZE: usize = 4;
    /// Number of layers in every test layout.
    const NUM_LAYERS: usize = 8;
725

Ryan Olson's avatar
Ryan Olson committed
726
727
728
    /// Device-tier pool handed back by the pool builders (always `Some` there; `Option` keeps it
    /// symmetric with the host/disk aliases).
    type DevicePool = Option<Arc<dyn BlockPool<DeviceStorage, Local, BasicMetadata>>>;
    /// Pinned host-tier pool; `None` when the test did not request a host tier.
    type HostPool = Option<Arc<dyn BlockPool<PinnedStorage, Local, BasicMetadata>>>;
    /// Disk-tier pool; `None` when the test did not request a disk tier.
    type DiskPool = Option<Arc<dyn BlockPool<DiskStorage, Local, BasicMetadata>>>;
729
730
731
732
733

    lazy_static::lazy_static! {
        static ref NIXL_AGENT: Arc<Option<NixlAgent>> = {
            let agent = NixlAgent::new("offload-manager").unwrap();
            let (_, ucx_params) = agent.get_plugin_params("UCX").unwrap();
Ryan Olson's avatar
Ryan Olson committed
734
            let (_, gds_mt_params) = agent.get_plugin_params("GDS_MT").unwrap();
735
            let (_, posix_params) = agent.get_plugin_params("POSIX").unwrap();
736
            agent.create_backend("UCX", &ucx_params).unwrap();
Ryan Olson's avatar
Ryan Olson committed
737
            agent.create_backend("GDS_MT", &gds_mt_params).unwrap();
738
            agent.create_backend("POSIX", &posix_params).unwrap();
739
740
741
            Arc::new(Some(agent))
        };
    }
742

Ryan Olson's avatar
Ryan Olson committed
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
    /// Allocates a `layout_type` layout described by `config` with `allocator`, registers it with
    /// the NIXL `agent`, and wraps its blocks in a [`ManagedBlockPool`] whose default registration
    /// duplication policy is `duplication_setting`.
    ///
    /// NOTE(review): the `42, 0` passed to `Blocks::new` are fixed test identifiers; their exact
    /// meaning is defined by `Blocks::new` — confirm there.
    ///
    /// # Errors
    /// Propagates allocation, NIXL registration, block creation and pool-builder failures.
    fn build_layout<S: Storage + NixlRegisterableStorage>(
        config: LayoutConfig,
        layout_type: LayoutType,
        agent: &NixlAgent,
        allocator: &dyn StorageAllocator<S>,
        duplication_setting: BlockRegistrationDuplicationSetting,
    ) -> Result<Arc<dyn BlockPool<S, Local, BasicMetadata>>> {
        match layout_type {
            LayoutType::LayerSeparate { outer_contiguous } => {
                let mut layout = LayerSeparate::allocate(config, allocator, outer_contiguous)?;
                layout.nixl_register(agent, None)?;
                let pool_blocks = Blocks::new(layout, 42, 0)?.into_blocks()?;
                let pool = ManagedBlockPool::builder()
                    .blocks(pool_blocks)
                    .default_duplication_setting(duplication_setting)
                    .build()?;
                Ok(Arc::new(pool))
            }
            LayoutType::FullyContiguous => {
                let mut layout = FullyContiguous::allocate(config, allocator)?;
                layout.nixl_register(agent, None)?;
                let pool_blocks = Blocks::new(layout, 42, 0)?.into_blocks()?;
                let pool = ManagedBlockPool::builder()
                    .blocks(pool_blocks)
                    .default_duplication_setting(duplication_setting)
                    .build()?;
                Ok(Arc::new(pool))
            }
        }
    }

    #[allow(clippy::type_complexity)]
    fn build_pools(
779
780
        device_blocks: usize,
        host_blocks: Option<usize>,
781
        disk_blocks: Option<usize>,
782
        inner_dim: Option<usize>,
783
    ) -> Result<(
Ryan Olson's avatar
Ryan Olson committed
784
785
786
787
788
789
790
791
792
793
794
795
        Arc<OffloadManager<Local, BasicMetadata>>,
        DevicePool,
        HostPool,
        DiskPool,
    )> {
        build_pools_with_layout(
            device_blocks,
            host_blocks,
            disk_blocks,
            inner_dim,
            LayoutType::FullyContiguous,
            BlockRegistrationDuplicationSetting::Disabled,
796
            false,
Ryan Olson's avatar
Ryan Olson committed
797
798
799
800
801
802
803
804
805
806
807
        )
    }

    #[allow(clippy::type_complexity)]
    pub fn build_pools_with_layout(
        device_blocks: usize,
        host_blocks: Option<usize>,
        disk_blocks: Option<usize>,
        inner_dim: Option<usize>,
        layout_type: LayoutType,
        duplication_setting: BlockRegistrationDuplicationSetting,
808
        bypass_cpu_mem: bool,
Ryan Olson's avatar
Ryan Olson committed
809
810
    ) -> Result<(
        Arc<OffloadManager<Local, BasicMetadata>>,
811
812
813
814
        DevicePool,
        HostPool,
        DiskPool,
    )> {
815
816
        let mut config = LayoutConfig {
            num_blocks: device_blocks,
817
            num_layers: NUM_LAYERS,
818
            outer_dim: 1,
819
            page_size: BLOCK_SIZE,
820
            inner_dim: inner_dim.unwrap_or(1024),
821
            alignment: 1,
Ryan Olson's avatar
Ryan Olson committed
822
            dtype_width_bytes: 2,
823
824
        };

825
826
827
        let agent_arc = NIXL_AGENT.clone();
        let agent = agent_arc.as_ref().as_ref().unwrap();

Ryan Olson's avatar
Ryan Olson committed
828
829
830
831
832
833
834
        let device_pool = Some(build_layout(
            config.clone(),
            layout_type,
            agent,
            &DeviceAllocator::default(),
            duplication_setting,
        )?);
835
836
837

        let host_pool = if let Some(host_blocks) = host_blocks {
            config.num_blocks = host_blocks;
Ryan Olson's avatar
Ryan Olson committed
838
839
840
841
842
843
844
            Some(build_layout(
                config.clone(),
                layout_type,
                agent,
                &PinnedAllocator::default(),
                duplication_setting,
            )?)
845
        } else {
846
            None
847
848
        };

849
850
        let disk_pool = if let Some(disk_blocks) = disk_blocks {
            config.num_blocks = disk_blocks;
Ryan Olson's avatar
Ryan Olson committed
851
            Some(build_layout(
852
                config.clone(),
Ryan Olson's avatar
Ryan Olson committed
853
854
855
856
857
                layout_type,
                agent,
                &DiskAllocator,
                duplication_setting,
            )?)
858
        } else {
859
            None
860
        };
861

862
863
        let async_rt_handle = Handle::current();

864
865
866
867
868
869
870
871
872
873
874
875
876
        let minimal_config = KvManagerModelConfig::builder()
            .num_layers(config.num_layers)
            .outer_dim(config.outer_dim) // K and V
            .page_size(config.page_size) // Minimal page size
            .inner_dim(config.inner_dim) // Small inner dim
            .build()
            .expect("Failed to build minimal config");

        let config = OffloadManagerConfig {
            nixl_agent: agent_arc,
            async_rt_handle,
            cancellation_token: CancellationToken::new(),
            model_config: minimal_config,
877
            kvbm_metrics: None,
878
            bypass_cpu_mem,
879
880
        };

881
882
883
884
        let manager = OffloadManager::new(
            disk_pool.clone(),
            host_pool.clone(),
            device_pool.clone(),
885
            OffloadFilters::builder().build()?,
886
            config,
887
888
889
        )?;

        Ok((manager, device_pool, host_pool, disk_pool))
890
891
892
    }

    /// Create a block in the 'RESET' state.
Ryan Olson's avatar
Ryan Olson committed
893
    #[expect(dead_code)]
894
    async fn get_block<S: Storage, Metadata: BlockMetadata>(
Ryan Olson's avatar
Ryan Olson committed
895
896
897
898
        pool: &Arc<dyn BlockPool<S, Local, Metadata>>,
    ) -> Result<MutableBlock<S, Local, Metadata>> {
        let mut blocks = pool.allocate_blocks(1).await?;
        Ok(blocks.pop().unwrap())
899
900
901
902
    }

    /// Create a block in the 'COMPLETED' state.
    async fn completed_block<S: Storage, Metadata: BlockMetadata>(
Ryan Olson's avatar
Ryan Olson committed
903
        pool: &Arc<dyn BlockPool<S, Local, Metadata>>,
904
        tokens: [u32; BLOCK_SIZE],
Ryan Olson's avatar
Ryan Olson committed
905
906
907
908
909
910
911
912
    ) -> Result<MutableBlock<S, Local, Metadata>> {
        let mut block = pool
            .allocate_blocks(1)
            .await?
            .into_iter()
            .next()
            .ok_or(anyhow::anyhow!("Failed to allocate block"))?;

913
914
915
916
917
918
919
920
        block.init_sequence(42)?;
        for token in tokens {
            block.add_token(token)?;
        }
        block.commit()?;
        Ok(block)
    }

921
    fn populate_block<S: Storage + NixlDescriptor>(
922
        block: &impl BlockDataProvider<StorageType = S>,
Ryan Olson's avatar
Ryan Olson committed
923
        start_value: u8,
924
    ) -> Result<()> {
Ryan Olson's avatar
Ryan Olson committed
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
        let block_data = block.block_data();

        let mut value = start_value;

        for layer_idx in 0..block_data.num_layers() {
            for outer_idx in 0..block_data.num_outer_dims() {
                let layer_view = block_data.layer_view(layer_idx, outer_idx)?;
                match block_data.storage_type() {
                    StorageType::Device(_) | StorageType::Pinned => unsafe {
                        cudaMemset(
                            layer_view.as_ptr() as *mut std::ffi::c_void,
                            value as i32,
                            layer_view.size(),
                        )
                        .result()?;
                    },
                    StorageType::Disk(_) => {
                        let nixl_desc = layer_view.as_nixl_descriptor();
                        let mut file: ManuallyDrop<File>;
                        let data = avec![[4096] | value; layer_view.size()];

                        unsafe {
                            file =
                                ManuallyDrop::new(File::from_raw_fd(nixl_desc.device_id() as i32));
                            file.seek(SeekFrom::Start(nixl_desc.as_ptr() as u64))?;
                        }
                        file.write_all(&data)?;
                        file.sync_all()?;
                        file.flush()?;
                    }
                    _ => panic!(),
956
957
                }
            }
Ryan Olson's avatar
Ryan Olson committed
958
959

            value += 1;
960
        }
961

962
963
964
        Ok(())
    }

965
966
    /// Reads every (layer, outer-dim) region of `block` back into host memory.
    ///
    /// The result has one `Vec<u8>` per region, in layer-major order:
    /// * device memory is copied with `cudaMemcpy` (device-to-host);
    /// * pinned memory is read directly through its pointer;
    /// * disk storage is read through the file descriptor and offset carried by the NIXL
    ///   descriptor.
    ///
    /// # Errors
    /// Fails on CUDA or I/O errors, and for unsupported storage types.
    fn get_block_contents<S: Storage + NixlDescriptor>(
        block: &impl BlockDataProvider<StorageType = S>,
    ) -> Result<Vec<Vec<u8>>> {
        let block_data = block.block_data();

        let mut contents: Vec<Vec<u8>> = Vec::new();

        for layer_idx in 0..block_data.num_layers() {
            for outer_idx in 0..block_data.num_outer_dims() {
                let layer_view = block_data.layer_view(layer_idx, outer_idx)?;
                let region = match block_data.storage_type() {
                    StorageType::Device(_) => {
                        let mut buffer = vec![0_u8; layer_view.size()];
                        // SAFETY: `buffer` holds exactly `layer_view.size()` bytes, and the source
                        // view spans that many bytes of device memory owned by `block`.
                        unsafe {
                            cudaMemcpy(
                                buffer.as_mut_ptr() as *mut std::ffi::c_void,
                                layer_view.as_ptr() as *const std::ffi::c_void,
                                layer_view.size(),
                                cudaMemcpyKind::cudaMemcpyDeviceToHost,
                            )
                            .result()?;
                        }
                        buffer
                    }
                    StorageType::Pinned => {
                        // SAFETY: pinned storage is host-addressable, and the view spans
                        // `layer_view.size()` bytes owned by `block`.
                        unsafe { std::slice::from_raw_parts(layer_view.as_ptr(), layer_view.size()) }
                            .to_vec()
                    }
                    StorageType::Disk(_) => {
                        let nixl_desc = layer_view.as_nixl_descriptor();
                        // SAFETY: this helper assumes that, for disk storage, the descriptor's
                        // device id is the backing file's open descriptor and its address is the
                        // byte offset. `ManuallyDrop` keeps this temporary `File` from closing it.
                        let mut file = ManuallyDrop::new(unsafe {
                            File::from_raw_fd(nixl_desc.device_id() as i32)
                        });
                        file.seek(SeekFrom::Start(nixl_desc.as_ptr() as u64))?;

                        let mut aligned = avec![[4096] | 0; layer_view.size()];
                        file.read_exact(&mut aligned)?;
                        aligned.to_vec()
                    }
                    _ => anyhow::bail!("Unsupported storage type."),
                };
                contents.push(region);
            }
        }

        Ok(contents)
    }

1016
    fn check_block_contents(
1017
1018
        block1: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
        block2: &impl BlockDataProvider<StorageType = impl Storage + NixlDescriptor>,
Ryan Olson's avatar
Ryan Olson committed
1019
        start_value: u8,
1020
    ) -> Result<()> {
1021
1022
        let contents1 = get_block_contents(block1)?;
        let contents2 = get_block_contents(block2)?;
1023

Ryan Olson's avatar
Ryan Olson committed
1024
1025
1026
1027
1028
1029
1030
1031
1032
        assert_eq!(contents1.len(), contents2.len());

        let mut value = start_value;

        for (layer1_vec, layer2_vec) in contents1.iter().zip(contents2.iter()) {
            for (c1_value, c2_value) in layer1_vec.iter().zip(layer2_vec.iter()) {
                if c1_value != c2_value || c1_value != &value {
                    panic!("{} != {} != {}", c1_value, c2_value, value);
                }
1033
            }
Ryan Olson's avatar
Ryan Olson committed
1034
            value += 1;
1035
        }
1036
1037
1038
1039
1040
        Ok(())
    }

    /// Offloading a block that is committed but was never registered must be rejected with
    /// `BlockError::InvalidState`.
    #[tokio::test]
    async fn test_offload_invalid_blocks() -> Result<()> {
        let (offload_manager, device_pool, _, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();

        // Check blocks in the 'COMPLETED' state (committed, never registered).
        let immutable_block = ImmutableBlock::new(Arc::new(
            completed_block(device_pool, [0; BLOCK_SIZE]).await?,
        ));
        assert!(matches!(
            offload_manager.offload(&immutable_block, 0).await,
            Err(BlockPoolError::BlockError(BlockError::InvalidState(_)))
        ));

        Ok(())
    }

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_offload_registered_blocks(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools_with_layout(
            4,
            Some(4),
            None,
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1070
            false,
Ryan Olson's avatar
Ryan Olson committed
1071
        )?;
1072

1073
1074
        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085

        // Create a block and register it with the offload manager
        let block = completed_block(device_pool, [0, 1, 2, 3]).await?;

        let immutable_device_block = device_pool
            .register_blocks(vec![block])
            .await?
            .into_iter()
            .next()
            .ok_or(anyhow::anyhow!("Failed to register block"))?;

1086
        populate_block(&immutable_device_block, 42)?;
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097

        // Offloads should only go to G2 (for now)
        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for it to be processed.
        // TODO: This is a bit of a hack, and may lead to non-deterministic behavior.
        // In theory, the offload + memcpy should take much less time than this.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Check that the block exists in the host pool
        let host_blocks = host_pool
Ryan Olson's avatar
Ryan Olson committed
1098
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
1099
1100
1101
1102
            .await?;

        assert_eq!(host_blocks.len(), 1);
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1103
1104
            host_blocks[0].sequence_hash(),
            immutable_device_block.sequence_hash()
1105
1106
        );

1107
        check_block_contents(&immutable_device_block, &host_blocks[0], 42)?;
1108
1109
1110
1111

        Ok(())
    }

1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
    /// With `bypass_cpu_mem` enabled, offloading a registered device block must write it directly
    /// to the disk pool (G3) with identical contents, and leave nothing in the host pool.
    #[tokio::test]
    async fn test_offload_device_to_disk_bypass_cpu() -> Result<()> {
        let (offload_manager, device_pool, host_pool, disk_pool) = build_pools_with_layout(
            4,
            Some(4),
            Some(4),
            None,
            LayoutType::FullyContiguous,
            BlockRegistrationDuplicationSetting::Disabled,
            true,
        )?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
        let disk_pool = disk_pool.as_ref().unwrap();

        // Create a block and register it with the offload manager
        let block = completed_block(device_pool, [0, 1, 2, 3]).await?;

        let immutable_device_block = device_pool
            .register_blocks(vec![block])
            .await?
            .into_iter()
            .next()
            .ok_or(anyhow::anyhow!("Failed to register block"))?;

        populate_block(&immutable_device_block, 42)?;

        // Synchronize ALL CUDA streams to ensure populate_block completes before offload starts
        // This is critical because cudaMemset uses the default stream, but GDS transfer uses a different stream
        unsafe {
            cudaDeviceSynchronize().result()?;
        }

        // Offloads should only go to G3 directly since bypass_cpu_mem is true in offload_manager config
        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for it to be processed.
        // TODO: This is a bit of a hack, and may lead to non-deterministic behavior.
        // In theory, the offload + memcpy should take much less time than this.
        tokio::time::sleep(std::time::Duration::from_millis(1000)).await;

        // Check that the block exists in the disk pool (G3)
        let disk_blocks = disk_pool
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
            .await?;

        assert_eq!(disk_blocks.len(), 1);
        assert_eq!(
            disk_blocks[0].sequence_hash(),
            immutable_device_block.sequence_hash()
        );

        check_block_contents(&immutable_device_block, &disk_blocks[0], 42)?;

        // The bypass path must not have staged the block in the host pool
        let host_blocks = host_pool
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
            .await?;

        // since host is bypassed, there should be no host blocks
        assert_eq!(host_blocks.len(), 0);

        Ok(())
    }

1177
1178
    /// While every host block is held elsewhere, an offload must fail gracefully, and nothing may
    /// land in the host pool. Once the host blocks are released, retrying the same offload must
    /// succeed.
    #[tokio::test]
    async fn test_no_host_blocks_available() -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();

        // Hold every host block so the first offload has nowhere to go.
        let host_blocks = host_pool.allocate_blocks(4).await?;
        assert_eq!(host_blocks.len(), 4);

        let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for offload to be processed.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // The offload should fail gracefully due to a lack of host blocks.
        let matched_host_blocks = host_pool
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
            .await?;
        assert_eq!(matched_host_blocks.len(), 0);

        // Release the host blocks and wait for them to be returned to the pool.
        drop(host_blocks);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Try the offload again.
        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for offload to be processed.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // This time, the offload should succeed.
        let matched_host_blocks = host_pool
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
            .await?;
        assert_eq!(matched_host_blocks.len(), 1);

        Ok(())
    }

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_onboard(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools_with_layout(
            4,
            Some(4),
            None,
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1238
            false,
Ryan Olson's avatar
Ryan Olson committed
1239
        )?;
1240

1241
1242
        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252

        // Allocate and fill a block on the host.
        let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
        let immutable_host_block = host_pool
            .register_blocks(vec![host_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

1253
        populate_block(&immutable_host_block, 42)?;
1254
1255
1256

        // Onboard the block.
        let onboarded_blocks = offload_manager
Ryan Olson's avatar
Ryan Olson committed
1257
1258
            .onboard(vec![immutable_host_block.clone()], None)
            .await??;
1259
1260
1261
1262

        assert_eq!(onboarded_blocks.len(), 1);
        // Check that the sequence hash is the same.
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1263
1264
            onboarded_blocks[0].sequence_hash(),
            immutable_host_block.sequence_hash()
1265
1266
1267
1268
        );
        // Check that the block is registered.
        assert!(matches!(
            onboarded_blocks[0].state(),
1269
            BlockState::Registered(_, _)
1270
1271
        ));

1272
        check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
1273
1274
1275
1276

        // Wait for the new value to show up in the device pool.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let device_blocks = device_pool
Ryan Olson's avatar
Ryan Olson committed
1277
            .match_sequence_hashes(vec![onboarded_blocks[0].sequence_hash()].as_slice())
1278
1279
1280
            .await?;
        assert_eq!(device_blocks.len(), 1);
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1281
1282
            device_blocks[0].sequence_hash(),
            onboarded_blocks[0].sequence_hash()
1283
1284
1285
        );

        // Check that this is the same block.
1286
        check_block_contents(&immutable_host_block, &device_blocks[0], 42)?;
1287

1288
1289
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

1290
1291
1292
1293
        Ok(())
    }

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_offload_onboard(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools_with_layout(
            4,
            Some(4),
            None,
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1306
            false,
Ryan Olson's avatar
Ryan Olson committed
1307
        )?;
1308

1309
1310
        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
1311
1312
1313
1314
1315
1316
1317
1318
1319

        let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

1320
        populate_block(&immutable_device_block, 42)?;
1321
1322
1323
1324
1325
1326
1327
1328
        // Offload the block to the host.
        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for the offload to be processed.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Check that the block exists in the host pool.
        let immutable_host_block = host_pool
Ryan Olson's avatar
Ryan Olson committed
1329
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
1330
1331
1332
1333
1334
            .await?
            .into_iter()
            .next()
            .unwrap();

1335
        check_block_contents(&immutable_device_block, &immutable_host_block, 42)?;
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350

        // Remove the device block from the pool by dropping it and allocating more blocks.
        drop(immutable_device_block);

        // Wait for the block to be returned to the pool.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let device_blocks = device_pool.allocate_blocks(4).await?;
        assert_eq!(device_blocks.len(), 4);

        drop(device_blocks);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Check that the block is not in the device pool.
        let device_blocks = device_pool
Ryan Olson's avatar
Ryan Olson committed
1351
            .match_sequence_hashes(vec![immutable_host_block.sequence_hash()].as_slice())
1352
1353
1354
1355
1356
            .await?;
        assert_eq!(device_blocks.len(), 0);

        // Onboard the block back to the device pool.
        let onboarded_blocks = offload_manager
Ryan Olson's avatar
Ryan Olson committed
1357
1358
            .onboard(vec![immutable_host_block.clone()], None)
            .await??;
1359
1360
        assert_eq!(onboarded_blocks.len(), 1);
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1361
1362
            onboarded_blocks[0].sequence_hash(),
            immutable_host_block.sequence_hash()
1363
1364
1365
        );
        assert!(matches!(
            onboarded_blocks[0].state(),
1366
            BlockState::Registered(_, _)
1367
1368
        ));

1369
        check_block_contents(&immutable_host_block, &onboarded_blocks[0], 42)?;
1370
1371
1372
1373
1374
1375

        Ok(())
    }

    #[tokio::test]
    async fn test_onboard_err_handling() -> Result<()> {
1376
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;
1377

1378
1379
        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392

        let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
        let immutable_host_block = host_pool
            .register_blocks(vec![host_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

        let device_blocks = device_pool.allocate_blocks(4).await?;
        assert_eq!(device_blocks.len(), 4);

        let res = offload_manager
Ryan Olson's avatar
Ryan Olson committed
1393
1394
            .onboard(vec![immutable_host_block.clone()], None)
            .await?;
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
        assert!(matches!(
            res.err().unwrap(),
            BlockPoolError::NotEnoughBlocksAvailable(_, _)
        ));

        Ok(())
    }

    /// With no host tier configured, `offload` on a registered device block must still return `Ok`.
    #[tokio::test]
    async fn test_offload_onboard_no_host_blocks() -> Result<()> {
        let (offload_manager, device_pool, _, _) = build_pools(4, None, None, None)?;

        let device_pool = device_pool.as_ref().unwrap();

        let device_block = completed_block(device_pool, [0, 1, 2, 3]).await?;
        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

        offload_manager.offload(&immutable_device_block, 0).await?;

        Ok(())
    }
1421
1422

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_offload_disk(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, _, host_pool, disk_pool) = build_pools_with_layout(
            4,
            Some(4),
            Some(4),
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1435
            false,
Ryan Olson's avatar
Ryan Olson committed
1436
        )?;
1437

1438
1439
        let host_pool = host_pool.as_ref().unwrap();
        let disk_pool = disk_pool.as_ref().unwrap();
1440
1441
1442
1443
1444
1445
1446
1447
1448

        let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
        let immutable_host_block = host_pool
            .register_blocks(vec![host_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

1449
        populate_block(&immutable_host_block, 42)?;
1450
1451
1452
1453
1454
1455

        offload_manager.offload(&immutable_host_block, 0).await?;

        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        let disk_blocks = disk_pool
Ryan Olson's avatar
Ryan Olson committed
1456
            .match_sequence_hashes(vec![immutable_host_block.sequence_hash()].as_slice())
1457
1458
1459
            .await?;
        assert_eq!(disk_blocks.len(), 1);
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1460
1461
            disk_blocks[0].sequence_hash(),
            immutable_host_block.sequence_hash()
1462
1463
        );

1464
        check_block_contents(&immutable_host_block, &disk_blocks[0], 42)?;
1465
1466
1467
1468
1469

        Ok(())
    }

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_onboard_disk(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, device_pool, _, disk_pool) = build_pools_with_layout(
            4,
            None,
            Some(4),
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1482
            false,
Ryan Olson's avatar
Ryan Olson committed
1483
        )?;
1484

1485
1486
        let device_pool = device_pool.as_ref().unwrap();
        let disk_pool = disk_pool.as_ref().unwrap();
1487
1488
1489
1490
1491
1492
1493
1494
1495

        let disk_block = completed_block(disk_pool, [0, 1, 2, 3]).await?;
        let immutable_disk_block = disk_pool
            .register_blocks(vec![disk_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

1496
1497
        populate_block(&immutable_disk_block, 42)?;

1498
        let device_block = offload_manager
Ryan Olson's avatar
Ryan Olson committed
1499
1500
            .onboard(vec![immutable_disk_block.clone()], None)
            .await??;
1501

1502
1503
        check_block_contents(&immutable_disk_block, &device_block[0], 42)?;

1504
1505
        assert_eq!(device_block.len(), 1);
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
1506
1507
            device_block[0].sequence_hash(),
            immutable_disk_block.sequence_hash()
1508
1509
1510
        );
        assert_eq!(
            device_pool
Ryan Olson's avatar
Ryan Olson committed
1511
                .match_sequence_hashes(vec![immutable_disk_block.sequence_hash()].as_slice())
1512
1513
1514
1515
1516
1517
1518
1519
1520
                .await?
                .len(),
            1
        );

        Ok(())
    }

    #[tokio::test]
Ryan Olson's avatar
Ryan Olson committed
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
    #[rstest]
    #[case(LayoutType::FullyContiguous)]
    #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
    #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
    async fn test_bulk_transfer_disk(#[case] layout_type: LayoutType) -> Result<()> {
        let (offload_manager, device_pool, host_pool, disk_pool) = build_pools_with_layout(
            8,
            Some(8),
            Some(8),
            None,
            layout_type,
            BlockRegistrationDuplicationSetting::Disabled,
1533
            false,
Ryan Olson's avatar
Ryan Olson committed
1534
        )?;
1535

1536
1537
1538
        let disk_pool = disk_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
        let device_pool = device_pool.as_ref().unwrap();
1539
1540
1541
1542
1543

        let mut host_blocks = Vec::new();

        for i in 0..8 {
            let block = completed_block(host_pool, [i; 4]).await?;
1544
            populate_block(&block, i as u8)?;
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
            host_blocks.push(block);
        }

        let immutable_host_blocks = host_pool.register_blocks(host_blocks).await?;

        for block in &immutable_host_blocks {
            offload_manager.offload(block, 0).await?;
        }

        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        let mut disk_blocks = Vec::new();

1558
        for (i, host_block) in immutable_host_blocks.iter().enumerate() {
1559
            let blocks = disk_pool
Ryan Olson's avatar
Ryan Olson committed
1560
                .match_sequence_hashes(vec![host_block.sequence_hash()].as_slice())
1561
1562
                .await?;
            assert_eq!(blocks.len(), 1);
1563
            check_block_contents(host_block, &blocks[0], i as u8)?;
1564
1565
1566
            disk_blocks.push(blocks[0].clone());
        }

Ryan Olson's avatar
Ryan Olson committed
1567
        let device_blocks = offload_manager.onboard(disk_blocks.clone(), None).await??;
1568
1569
        assert_eq!(device_blocks.len(), disk_blocks.len());

1570
        for (i, disk_block) in disk_blocks.iter().enumerate() {
1571
            let blocks = device_pool
Ryan Olson's avatar
Ryan Olson committed
1572
                .match_sequence_hashes(vec![disk_block.sequence_hash()].as_slice())
1573
1574
                .await?;
            assert_eq!(blocks.len(), 1);
1575
            check_block_contents(disk_block, &blocks[0], i as u8)?;
1576
1577
1578
1579
        }

        Ok(())
    }
1580
1581
1582
1583
1584
1585
1586

    #[tokio::test]
    async fn test_transfer_batcher() -> Result<()> {
        let (offload_manager, device_pool, _, disk_pool) = build_pools(
            2 * MAX_TRANSFER_BATCH_SIZE + 1,
            None,
            Some(2 * MAX_TRANSFER_BATCH_SIZE + 1),
1587
            None,
1588
1589
1590
1591
1592
1593
1594
1595
        )?;

        let device_pool = device_pool.as_ref().unwrap();
        let disk_pool = disk_pool.as_ref().unwrap();

        let mut disk_blocks = Vec::new();

        for i in 0..2 * MAX_TRANSFER_BATCH_SIZE + 1 {
1596
1597
1598
            let disk_block = completed_block(disk_pool, [i as u32; 4]).await?;
            populate_block(&disk_block, i as u8)?;
            disk_blocks.push(disk_block);
1599
1600
1601
1602
1603
        }

        let immutable_disk_blocks = disk_pool.register_blocks(disk_blocks).await?;

        let device_blocks = offload_manager
Ryan Olson's avatar
Ryan Olson committed
1604
1605
            .onboard(immutable_disk_blocks.clone(), None)
            .await??;
1606
1607
        assert_eq!(device_blocks.len(), 2 * MAX_TRANSFER_BATCH_SIZE + 1);

1608
        for (i, device_block) in device_blocks.iter().enumerate() {
1609
            let blocks = device_pool
Ryan Olson's avatar
Ryan Olson committed
1610
                .match_sequence_hashes(vec![device_block.sequence_hash()].as_slice())
1611
                .await?;
1612
            check_block_contents(device_block, &blocks[0], i as u8)?;
1613
1614
1615
1616
1617
            assert_eq!(blocks.len(), 1);
        }

        Ok(())
    }
1618

1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
    // ============================================================================
    // IMPROVED DISK TESTS FOR GDS COMPATIBILITY
    // ============================================================================

    mod gds_compatible_disk_tests {
        use super::*;

        /// Test disk storage with proper GDS alignment requirements
        #[tokio::test]
        #[rstest]
        #[case(LayoutType::FullyContiguous)]
        #[case(LayoutType::LayerSeparate { outer_contiguous: true })]
        #[case(LayoutType::LayerSeparate { outer_contiguous: false })]
        async fn test_gds_aligned_disk_operations(#[case] layout_type: LayoutType) -> Result<()> {
            // GDS requires 4KB alignment for optimal performance
            const GDS_ALIGNMENT: usize = 4096;

            let (offload_manager, _, host_pool, disk_pool) = build_pools_with_layout(
                4,
                Some(4),
                Some(4),
                Some(GDS_ALIGNMENT), // Use GDS-friendly alignment
                layout_type,
                BlockRegistrationDuplicationSetting::Disabled,
1643
                false,
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
            )?;

            let host_pool = host_pool.as_ref().unwrap();
            let disk_pool = disk_pool.as_ref().unwrap();

            // Create and populate host block
            let host_block = completed_block(host_pool, [0, 1, 2, 3]).await?;
            let immutable_host_block = host_pool
                .register_blocks(vec![host_block])
                .await?
                .into_iter()
                .next()
                .unwrap();

            populate_block(&immutable_host_block, 0xAB)?;

            // Test Host -> Disk transfer with GDS alignment
            offload_manager.offload(&immutable_host_block, 0).await?;
            tokio::time::sleep(std::time::Duration::from_millis(500)).await;

            // Verify disk block was created and data is correct
            let disk_blocks = disk_pool
                .match_sequence_hashes(vec![immutable_host_block.sequence_hash()].as_slice())
                .await?;
            assert_eq!(disk_blocks.len(), 1);

            // Verify data integrity
            check_block_contents(&immutable_host_block, &disk_blocks[0], 0xAB)?;

            // Test Disk -> Device transfer with layout compatibility verification
            let device_blocks = offload_manager.onboard(disk_blocks.clone(), None).await??;
            assert_eq!(device_blocks.len(), 1);

            // Verify data integrity after onboarding
            check_block_contents(&disk_blocks[0], &device_blocks[0], 0xAB)?;

            Ok(())
        }

        /// Round-trip a block carrying a distinct fill pattern per layer/outer view, intended to
        /// catch layout-mapping errors between tiers (FC host, LS device, FC disk).
        ///
        /// NOTE(review): `build_pools_mixed_layouts` currently ignores the per-tier layouts and
        /// builds every pool with a single layout, so the mixed scenario is not really exercised.
        #[ignore] // Disabled - requires complex mixed-layout pool implementation
        #[tokio::test]
        async fn test_cross_layout_compatibility_verification() -> Result<()> {
            let host_tier = Some((4, LayoutType::FullyContiguous));
            let device_tier = Some((
                4,
                LayoutType::LayerSeparate {
                    outer_contiguous: true,
                },
            ));
            let disk_tier = Some((4, LayoutType::FullyContiguous));

            let (offload_manager, _, host_pool, disk_pool) =
                build_pools_mixed_layouts(4, host_tier, device_tier, disk_tier)?;
            let (host_pool, disk_pool) =
                (host_pool.as_ref().unwrap(), disk_pool.as_ref().unwrap());

            let registered = host_pool
                .register_blocks(vec![completed_block(host_pool, [0, 1, 2, 3]).await?])
                .await?;
            let host_block = registered.into_iter().next().unwrap();

            // Give every layer/outer view its own byte so a mis-mapped view is detectable.
            populate_block_with_layer_patterns(&host_block)?;

            // Host -> disk, then wait for the asynchronous offload to finish.
            offload_manager.offload(&host_block, 0).await?;
            tokio::time::sleep(std::time::Duration::from_millis(500)).await;

            let wanted = vec![host_block.sequence_hash()];
            let disk_blocks = disk_pool.match_sequence_hashes(wanted.as_slice()).await?;
            assert_eq!(disk_blocks.len(), 1);
            verify_layer_patterns(&host_block, &disk_blocks[0])?;

            // Disk -> device: the hop where a layout mismatch would scramble layers.
            let device_blocks = offload_manager.onboard(disk_blocks.clone(), None).await??;
            assert_eq!(device_blocks.len(), 1);
            verify_layer_patterns(&disk_blocks[0], &device_blocks[0])?;

            Ok(())
        }

        /// Test GDS file registration and unlinking behavior
        #[tokio::test]
        async fn test_gds_file_lifecycle() -> Result<()> {
            use std::fs;
            use std::path::Path;

            let (_, _, _, disk_pool) = build_pools_with_layout(
                2,
                None,
                Some(2), // disk_blocks - this was the bug!
                None,    // inner_dim
                LayoutType::FullyContiguous,
                BlockRegistrationDuplicationSetting::Disabled,
1750
                false,
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
            )?;

            let disk_pool = disk_pool
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("Disk pool was not created"))?;

            // Create a disk block
            let disk_block = completed_block(disk_pool, [1, 2, 3, 4]).await?;

            // Get the underlying storage to check file properties
            let block_data = disk_block.block_data();
            let storage_type = block_data.storage_type();

            if let StorageType::Disk(fd) = storage_type {
                // Verify file exists and has correct properties
                let file_path = format!("/proc/self/fd/{}", fd);

                // Check that the file is accessible (should be before unlinking)
                if Path::new(&file_path).exists() {
                    let metadata = fs::metadata(&file_path)?;

                    // Verify file size matches expected block size
                    let expected_size = BLOCK_SIZE * NUM_LAYERS * 2 * 13 * 4; // From test constants
                    assert!(
                        metadata.len() >= expected_size as u64,
                        "Disk file size {} is smaller than expected {}",
                        metadata.len(),
                        expected_size
                    );

                    // Verify file is properly aligned for GDS operations
                    assert_eq!(
                        metadata.len() % 4096,
                        0,
                        "Disk file size {} is not 4KB aligned for GDS",
                        metadata.len()
                    );
                }
            }

            // Register the block (this should trigger NIXL registration and unlinking)
            let immutable_disk_block = disk_pool
                .register_blocks(vec![disk_block])
                .await?
                .into_iter()
                .next()
                .unwrap();

            // After registration, the file should still be accessible through the fd
            // but unlinked from the filesystem
            populate_block(&immutable_disk_block, 0xCD)?;

            Ok(())
        }

        /// Debug test to understand disk pool creation failure
        #[tokio::test]
        async fn test_debug_disk_pool_creation() -> Result<()> {
            use dynamo_runtime::logging::init as init_logging;
            init_logging();

            println!("Testing disk pool creation...");

            let result = build_pools_with_layout(
                2,
                None,
                Some(2),
                None,
                LayoutType::FullyContiguous,
                BlockRegistrationDuplicationSetting::Disabled,
1821
                false,
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
            );

            match result {
                Ok((_, _, _, disk_pool)) => {
                    if disk_pool.is_some() {
                        println!("Disk pool created successfully");
                        Ok(())
                    } else {
                        println!("Disk pool is None even though creation succeeded");
                        Err(anyhow::anyhow!("Disk pool is None"))
                    }
                }
                Err(e) => {
                    println!("build_pools_with_layout failed: {:?}", e);
                    Err(e)
                }
            }
        }

        /// Test error handling for GDS-incompatible operations
        #[tokio::test]
        async fn test_gds_error_handling() -> Result<()> {
            // Test with very small alignment that might cause GDS issues
            let result = build_pools_with_layout(
                2,
                None,
                Some(2), // disk_blocks - fixed parameter order
                None,    // inner_dim
                LayoutType::FullyContiguous,
                BlockRegistrationDuplicationSetting::Disabled,
1852
                false,
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
            );

            // This should succeed, but we'll test behavior under constrained conditions
            let (_, _, _, disk_pool) = result?;
            let disk_pool = disk_pool
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("Disk pool was not created"))?;

            // Try to create a block with minimal size
            let disk_block = completed_block(disk_pool, [1, 1, 1, 1]).await?;
            let immutable_disk_block = disk_pool
                .register_blocks(vec![disk_block])
                .await?
                .into_iter()
                .next()
                .unwrap();

            // This should work even with small alignment
            populate_block(&immutable_disk_block, 0x42)?;

            Ok(())
        }

        /// Test disk operations under memory pressure (constrained host buffer scenario)
        #[ignore] // Disabled - helper functions have memory access issues in test environment
        #[tokio::test]
        async fn test_constrained_host_buffer_disk_operations() -> Result<()> {
            // Simulate constrained host buffer by using minimal host blocks
            let (offload_manager, _, host_pool, disk_pool) = build_pools_with_layout(
                8,          // More blocks than host buffer
                Some(2),    // Very limited host buffer
                Some(8),    // Plenty of disk space
                Some(4096), // GDS-friendly alignment
                LayoutType::FullyContiguous,
                BlockRegistrationDuplicationSetting::Disabled,
1888
                false,
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
            )?;

            let host_pool = host_pool.as_ref().unwrap();
            let disk_pool = disk_pool.as_ref().unwrap();

            // Create multiple blocks that exceed host capacity
            let mut host_blocks = Vec::new();
            for i in 0..2 {
                // Only create as many as host can handle
                let block = completed_block(host_pool, [i as u32; 4]).await?;
                populate_block(&block, i as u8)?;
                host_blocks.push(block);
            }

            let immutable_host_blocks = host_pool.register_blocks(host_blocks).await?;

            // Offload to disk
            for block in &immutable_host_blocks {
                offload_manager.offload(block, 0).await?;
            }

            tokio::time::sleep(std::time::Duration::from_millis(500)).await;

            // Verify all blocks are on disk
            let mut disk_blocks = Vec::new();
            for (i, host_block) in immutable_host_blocks.iter().enumerate() {
                let blocks = disk_pool
                    .match_sequence_hashes(vec![host_block.sequence_hash()].as_slice())
                    .await?;
                assert_eq!(blocks.len(), 1);
                verify_block_data_integrity(&blocks[0], i as u8)?;
                disk_blocks.push(blocks[0].clone());
            }

            // Now test onboarding under constrained conditions
            // This is where garbage data issues typically occur
            let device_blocks = offload_manager.onboard(disk_blocks.clone(), None).await??;

            // Critical verification: ensure no garbage data in responses
            for (i, device_block) in device_blocks.iter().enumerate() {
                verify_block_data_integrity(device_block, i as u8)?;

                // Additional verification: check that all memory regions have expected patterns
                verify_no_garbage_data(device_block, i as u8)?;
            }

            Ok(())
        }

        // Helper functions for improved disk testing

        /// Build pools with mixed layout types for testing compatibility
        fn build_pools_mixed_layouts(
            num_blocks: usize,
            host_config: Option<(usize, LayoutType)>,
            device_config: Option<(usize, LayoutType)>,
            disk_config: Option<(usize, LayoutType)>,
        ) -> Result<(
            Arc<OffloadManager<Local, BasicMetadata>>,
            DevicePool,
            HostPool,
            DiskPool,
        )> {
            // This would need to be implemented to support different layout types per pool
            // For now, fall back to standard build with the most complex layout
            build_pools_with_layout(
                num_blocks,
                host_config.map(|(n, _)| n),
                device_config.map(|(n, _)| n),
                disk_config.map(|(n, _)| n),
                LayoutType::LayerSeparate {
                    outer_contiguous: false,
                }, // Most complex
                BlockRegistrationDuplicationSetting::Disabled,
1963
                false,
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
            )
        }

        /// Populate block with layer-specific patterns to detect layout issues
        fn populate_block_with_layer_patterns<S, L, M>(
            block: &ImmutableBlock<S, L, M>,
        ) -> Result<()>
        where
            S: Storage,
            L: LocalityProvider,
            M: BlockMetadata,
            ImmutableBlock<S, L, M>: BlockDataProvider,
        {
            let block_data = block.block_data();

            for layer_idx in 0..block_data.num_layers() {
                for outer_idx in 0..2 {
                    // Assuming max 2 outer dimensions
                    if let Ok(layer_view) = block_data.layer_view(layer_idx, outer_idx) {
                        let pattern = 0x10 + layer_idx as u8 + outer_idx as u8; // Different pattern per layer/outer

                        unsafe {
                            let slice = std::slice::from_raw_parts_mut(
                                layer_view.as_ptr() as *mut u8,
                                layer_view.size(),
                            );
                            slice.fill(pattern);
                        }
                    }
                }
            }

            Ok(())
        }

        /// Check that every layer/outer view of both blocks is filled with the per-view byte
        /// `0x10 + layer + outer`, i.e. the pattern survived the transfer from `source_block`
        /// to `dest_block`.
        ///
        /// Outer indices 0 and 1 are probed.
        /// - A view that resolves on neither block is skipped.
        /// - A view that resolves on only one block is an error: the two layouts disagree about
        ///   the block's shape. Silently skipping that case would hide exactly the mismatches
        ///   this helper is meant to find.
        /// - Zero-sized views are skipped.
        ///
        /// # Errors
        /// Returns an error for a view present on only one side, or a null view pointer.
        /// Byte mismatches panic via `assert!`.
        fn verify_layer_patterns<S1, L1, M1, S2, L2, M2>(
            source_block: &ImmutableBlock<S1, L1, M1>,
            dest_block: &ImmutableBlock<S2, L2, M2>,
        ) -> Result<()>
        where
            S1: Storage,
            L1: LocalityProvider,
            M1: BlockMetadata,
            S2: Storage,
            L2: LocalityProvider,
            M2: BlockMetadata,
            ImmutableBlock<S1, L1, M1>: BlockDataProvider,
            ImmutableBlock<S2, L2, M2>: BlockDataProvider,
        {
            let src_data = source_block.block_data();
            let dst_data = dest_block.block_data();

            assert_eq!(src_data.num_layers(), dst_data.num_layers());

            for layer_idx in 0..src_data.num_layers() {
                // Assuming max 2 outer dimensions
                for outer_idx in 0..2 {
                    let (src_layer, dst_layer) = match (
                        src_data.layer_view(layer_idx, outer_idx),
                        dst_data.layer_view(layer_idx, outer_idx),
                    ) {
                        (Ok(src_layer), Ok(dst_layer)) => (src_layer, dst_layer),
                        // Outer index absent from both layouts: nothing to compare.
                        (Err(_), Err(_)) => continue,
                        // Present on one side only: the layouts disagree.
                        (src, dst) => {
                            return Err(anyhow::anyhow!(
                                "Layer {} outer {} is addressable on only one side (source: {}, destination: {})",
                                layer_idx,
                                outer_idx,
                                if src.is_ok() { "ok" } else { "err" },
                                if dst.is_ok() { "ok" } else { "err" }
                            ));
                        }
                    };

                    assert_eq!(src_layer.size(), dst_layer.size());

                    let expected_pattern = 0x10 + layer_idx as u8 + outer_idx as u8;

                    // SAFETY: each view is assumed to describe `size()` readable bytes that stay
                    // alive while the blocks are borrowed; null and empty views are rejected
                    // before building slices.
                    unsafe {
                        let src_ptr = src_layer.as_ptr();
                        let dst_ptr = dst_layer.as_ptr();
                        let src_size = src_layer.size();
                        let dst_size = dst_layer.size();

                        if src_ptr.is_null() || dst_ptr.is_null() {
                            return Err(anyhow::anyhow!("Layer view returned null pointer"));
                        }
                        if src_size == 0 || dst_size == 0 {
                            continue; // Skip empty layers
                        }

                        let src_slice = std::slice::from_raw_parts(src_ptr, src_size);
                        let dst_slice = std::slice::from_raw_parts(dst_ptr, dst_size);

                        assert!(
                            src_slice.iter().all(|&b| b == expected_pattern),
                            "Source layer {} outer {} has incorrect pattern",
                            layer_idx,
                            outer_idx
                        );

                        // Same size and same uniform pattern as the source, so this also
                        // implies the destination matches the source byte-for-byte.
                        assert!(
                            dst_slice.iter().all(|&b| b == expected_pattern),
                            "Destination layer {} outer {} has incorrect pattern",
                            layer_idx,
                            outer_idx
                        );
                    }
                }
            }

            Ok(())
        }

        /// Assert that every byte of the block's contiguous view equals `expected_value`.
        ///
        /// An empty view is accepted as-is.
        ///
        /// # Errors
        /// Propagates a failure from `block_view()`, and returns an error for a null view
        /// pointer. A byte mismatch panics via `assert!`, showing the first 16 bytes.
        fn verify_block_data_integrity<S, L, M>(
            block: &ImmutableBlock<S, L, M>,
            expected_value: u8,
        ) -> Result<()>
        where
            S: Storage,
            L: LocalityProvider,
            M: BlockMetadata,
            ImmutableBlock<S, L, M>: BlockDataProvider,
        {
            let data = block.block_data();
            let view = data.block_view()?;

            // SAFETY: the view is assumed to describe `size()` readable bytes kept alive by
            // `block`; the null and empty cases are handled before the slice is built.
            unsafe {
                let (base, len) = (view.as_ptr(), view.size());

                if base.is_null() {
                    return Err(anyhow::anyhow!("Block view returned null pointer"));
                }
                if len == 0 {
                    return Ok(()); // Empty block is valid
                }

                let bytes = std::slice::from_raw_parts(base, len);
                let preview = &bytes[..bytes.len().min(16)];
                let has_mismatch = bytes.iter().any(|&b| b != expected_value);

                assert!(
                    !has_mismatch,
                    "Block data integrity check failed: expected {}, got mixed values in first 16 bytes: {:?}",
                    expected_value,
                    preview
                );
            }

            Ok(())
        }

        /// Verify no garbage data in block (common issue with layout mismatches)
        fn verify_no_garbage_data<S, L, M>(
            block: &ImmutableBlock<S, L, M>,
            expected_value: u8,
        ) -> Result<()>
        where
            S: Storage,
            L: LocalityProvider,
            M: BlockMetadata,
            ImmutableBlock<S, L, M>: BlockDataProvider,
        {
            let block_data = block.block_data();

            // Check each layer separately for layout-specific issues
            for layer_idx in 0..block_data.num_layers() {
                for outer_idx in 0..2 {
                    // Assuming max 2 outer dimensions
                    if let Ok(layer_view) = block_data.layer_view(layer_idx, outer_idx) {
                        unsafe {
                            let slice =
                                std::slice::from_raw_parts(layer_view.as_ptr(), layer_view.size());

                            // In a properly functioning system, we should see mostly expected values
                            let expected_count =
                                slice.iter().filter(|&&b| b == expected_value).count();
                            let total_count = slice.len();
                            let expected_ratio = expected_count as f64 / total_count as f64;

                            assert!(
                                expected_ratio > 0.8,
                                "Layer {} has too much garbage data: only {:.1}% matches expected value {}. \
                         First 32 bytes: {:?}",
                                layer_idx,
                                expected_ratio * 100.0,
                                expected_value,
                                &slice[0..std::cmp::min(32, slice.len())]
                            );

                            // Additional check: no completely zero or completely max regions
                            // which often indicate uninitialized or corrupted memory
                            let zero_regions = count_consecutive_bytes(slice, 0x00);
                            let max_regions = count_consecutive_bytes(slice, 0xFF);

                            assert!(
                                zero_regions < slice.len() / 4,
                                "Layer {} outer {} has large zero regions, indicating potential garbage data",
                                layer_idx,
                                outer_idx
                            );
                            assert!(
                                max_regions < slice.len() / 4,
                                "Layer {} outer {} has large 0xFF regions, indicating potential garbage data",
                                layer_idx,
                                outer_idx
                            );
                        }
                    }
                }
            }

            Ok(())
        }

        /// Return the length of the longest run of consecutive bytes equal to `value` in `slice`.
        ///
        /// Returns `0` when `slice` is empty or contains no byte equal to `value`.
        fn count_consecutive_bytes(slice: &[u8], value: u8) -> usize {
            // Splitting on every byte that is *not* `value` leaves exactly the runs of `value`
            // (possibly empty) as the sub-slices; the longest one is the answer.
            slice
                .split(|&byte| byte != value)
                .map(|run| run.len())
                .max()
                .unwrap_or(0)
        }
    }

2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
    #[tokio::test]
    async fn test_onboard_unsupported_block_type() -> Result<()> {
        let (offload_manager, device_pool, _, _) = build_pools(1, None, None, None)?;

        let device_pool = device_pool.as_ref().unwrap();

        let block = completed_block(device_pool, [0; 4]).await?;

        let registered_block = device_pool
            .register_blocks(vec![block])
            .await?
            .into_iter()
            .next()
            .unwrap();

Ryan Olson's avatar
Ryan Olson committed
2207
2208
2209
        let onboarded_blocks = offload_manager
            .onboard(vec![registered_block], None)
            .await?;
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
        assert!(matches!(
            onboarded_blocks,
            Err(BlockPoolError::BlockError(BlockError::Other(_)))
        ));

        Ok(())
    }

    #[tokio::test]
    async fn test_offload_transfer_metadata() -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();

        let mut device_block = completed_block(device_pool, [0; 4]).await?;

        populate_block(&device_block, 42)?;

        let new_metadata = device_block.metadata().update_priority(1);
        device_block.update_metadata(new_metadata);

        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();
        offload_manager.offload(&immutable_device_block, 0).await?;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let host_blocks = host_pool
Ryan Olson's avatar
Ryan Olson committed
2243
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
            .await?;
        assert_eq!(host_blocks.len(), 1);
        check_block_contents(&immutable_device_block, &host_blocks[0], 42)?;
        assert_eq!(host_blocks[0].metadata().priority(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_onboard_duplicate() -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();

        let device_block = completed_block(device_pool, [0; 4]).await?;

        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

        populate_block(&immutable_device_block, 42)?;

        offload_manager.offload(&immutable_device_block, 0).await?;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let host_blocks = host_pool
Ryan Olson's avatar
Ryan Olson committed
2275
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
2276
2277
2278
2279
            .await?;
        assert_eq!(host_blocks.len(), 1);

        let onboarded_blocks = offload_manager
Ryan Olson's avatar
Ryan Olson committed
2280
2281
            .onboard(vec![host_blocks[0].clone()], None)
            .await??;
2282
2283
2284
2285
2286
2287
        assert_eq!(onboarded_blocks.len(), 1);
        check_block_contents(&host_blocks[0], &onboarded_blocks[0], 42)?;

        // This should be the same block that we put on the device.
        // The block that was copied should be discarded by the block pool.
        assert_eq!(
Ryan Olson's avatar
Ryan Olson committed
2288
2289
            onboarded_blocks[0].block_id(),
            immutable_device_block.block_id()
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_transfer_big_blocks() -> Result<()> {
        // Try a block size of 32 MB.
        let inner_dim = 2_usize.pow(20) * 32 / NUM_LAYERS / BLOCK_SIZE;
        let (offload_manager, device_pool, host_pool, disk_pool) =
            build_pools(2, Some(2), Some(2), Some(inner_dim))?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();
        let disk_pool = disk_pool.as_ref().unwrap();

        let device_block = completed_block(device_pool, [0; 4]).await?;

        populate_block(&device_block, 42)?;

        let immutable_device_block = device_pool
            .register_blocks(vec![device_block])
            .await?
            .into_iter()
            .next()
            .unwrap();

        // Offload to host.
        offload_manager.offload(&immutable_device_block, 0).await?;

        // Wait for the offload to be processed.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let host_blocks = host_pool
Ryan Olson's avatar
Ryan Olson committed
2324
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
            .await?;
        assert_eq!(host_blocks.len(), 1);
        check_block_contents(&immutable_device_block, &host_blocks[0], 42)?;

        // Offload to disk
        offload_manager.offload(&host_blocks[0], 0).await?;

        // Wait for the offload to be processed.
        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        let disk_blocks = disk_pool
Ryan Olson's avatar
Ryan Olson committed
2336
            .match_sequence_hashes(vec![immutable_device_block.sequence_hash()].as_slice())
2337
2338
2339
2340
2341
            .await?;
        assert_eq!(disk_blocks.len(), 1);
        check_block_contents(&host_blocks[0], &disk_blocks[0], 42)?;

        // Onboard to device.
Ryan Olson's avatar
Ryan Olson committed
2342
        let device_blocks = offload_manager.onboard(disk_blocks.clone(), None).await??;
2343
2344
2345
2346
2347
        assert_eq!(device_blocks.len(), 1);
        check_block_contents(&disk_blocks[0], &device_blocks[0], 42)?;

        Ok(())
    }
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384

    /// After four device blocks are offloaded to a four-block host pool, allocating two host
    /// blocks evicts the first two offloaded blocks and keeps the last two.
    #[tokio::test]
    async fn test_offload_evict_order() -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();

        let tokens = vec![0_u32; BLOCK_SIZE * 4];
        let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
        assert_eq!(token_blocks.blocks().len(), 4);

        let mut mutable_blocks = Vec::new();
        let mut sequence_hashes = Vec::new();
        for token_block in token_blocks.blocks() {
            let mut mutable_block = device_pool
                .allocate_blocks(1)
                .await?
                .into_iter()
                .next()
                .unwrap();
            mutable_block.apply_token_block(token_block.clone())?;
            sequence_hashes.push(mutable_block.sequence_hash()?);
            mutable_blocks.push(mutable_block);
        }

        let immutable_blocks = device_pool.register_blocks(mutable_blocks).await?;

        for block in &immutable_blocks {
            offload_manager.offload(block, 0).await?;
        }
        // Wait for offloads.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        // Allocate 2 blocks on the host; the guard keeps them allocated until the end.
        let _host_blocks = host_pool.allocate_blocks(2).await?;

        // The first two blocks should've been evicted.
        // The last two blocks should still be on the host.
        // NOTE(review): matching the full sequence yields 0 because `match_sequence_hashes`
        // appears to stop at the first missing hash; matching from the third block finds 2.
        assert_eq!(
            host_pool
                .match_sequence_hashes(sequence_hashes.as_slice())
                .await?
                .len(),
            0
        );

        assert_eq!(
            host_pool
                .match_sequence_hashes(&sequence_hashes[2..])
                .await?
                .len(),
            2
        );

        Ok(())
    }

    /// Blocks onboarded from the host to the device are evicted as new device blocks are
    /// allocated: after allocating 2 device blocks, 2 of the onboarded sequence still match;
    /// after allocating 1 more, only 1 does.
    #[tokio::test]
    async fn test_onboard_evict_order() -> Result<()> {
        let (offload_manager, device_pool, host_pool, _) = build_pools(4, Some(4), None, None)?;

        let device_pool = device_pool.as_ref().unwrap();
        let host_pool = host_pool.as_ref().unwrap();

        let tokens = vec![0_u32; BLOCK_SIZE * 4];
        let token_blocks = TokenBlockSequence::new(Tokens::from(tokens), 4, None);
        assert_eq!(token_blocks.blocks().len(), 4);

        let mut mutable_blocks = Vec::new();
        let mut sequence_hashes = Vec::new();
        for token_block in token_blocks.blocks() {
            let mut block = host_pool
                .allocate_blocks(1)
                .await?
                .into_iter()
                .next()
                .unwrap();
            block.apply_token_block(token_block.clone())?;

            sequence_hashes.push(block.sequence_hash()?);
            mutable_blocks.push(block);
        }

        let immutable_blocks = host_pool.register_blocks(mutable_blocks).await?;

        // `??` surfaces an onboarding failure instead of silently discarding it. The returned
        // device blocks are still dropped immediately (presumably making them eligible for
        // eviction, which the allocations below rely on).
        let _ = offload_manager.onboard(immutable_blocks, None).await??;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let _device_blocks = device_pool.allocate_blocks(2).await?;

        assert_eq!(
            device_pool
                .match_sequence_hashes(sequence_hashes.as_slice())
                .await?
                .len(),
            2
        );

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let _device_blocks2 = device_pool.allocate_blocks(1).await?;

        assert_eq!(
            device_pool
                .match_sequence_hashes(sequence_hashes.as_slice())
                .await?
                .len(),
            1
        );

        Ok(())
    }
2462
}