// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV Cache Sequence Management for LLM Inference
//!
//! This module tracks active token sequences and the KV cache blocks they occupy
//! for distributed LLM inference. Blocks are shared between requests with common token
//! prefixes, so each distinct block is counted only once per worker.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//!   token sequences, deduplicating the KV cache blocks they share.
//!
//! - [`ActiveSequencesMultiWorker`]: Multi-worker extension that stores per-worker
//!   `ActiveSequences` in a shared `DashMap` (sharded locks) for concurrent access.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).

24
use crate::kv_router::protocols::OverlapScores;
25
26
use anyhow::Result;
use dashmap::DashMap;
27
use derive_getters::Getters;
28
29
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
30
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
31
use dynamo_tokens::SequenceHash;
32
use std::collections::{HashMap, HashSet};
33
use std::sync::Arc;
34
35
use std::time::Duration;
use tokio::time::Instant;
36
37
use uuid::Uuid;

38
use super::metrics::WORKER_LOAD_METRICS;
39
40
41
42
use super::protocols::{
    ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank,
};
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
Yan Ru Pei's avatar
Yan Ru Pei committed
43
use crate::local_model::runtime_config::ModelRuntimeConfig;
44
use dynamo_runtime::CancellationToken;
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/// Failure modes reported by the sequence-tracking APIs in this module.
#[derive(Debug, thiserror::Error)]
pub enum SequenceError {
    /// The targeted worker (worker ID + DP rank) is not registered.
    #[error("Worker {worker:?} not found")]
    WorkerNotFound { worker: WorkerWithDpRank },

    /// A request with the same ID is already being tracked.
    #[error("Request {request_id} already exists (assigned to worker {worker:?})")]
    DuplicateRequest {
        /// ID that was submitted a second time.
        request_id: String,
        /// Worker that already owns the request.
        worker: WorkerWithDpRank,
    },

    /// No tracked request has this ID.
    #[error("Request {request_id} not found")]
    RequestNotFound { request_id: String },

    /// An `anyhow` error propagated via `?`, such as a failed event publish.
    #[error("Failed to publish event: {0}")]
    PublishFailed(#[from] anyhow::Error),
}

65
66
67
/// How long a request may stay tracked before it is treated as stale and
/// forcibly expired: five minutes.
const EXPIRY_DURATION: Duration = Duration::from_secs(5 * 60);

/// Identifier used to key all per-request state in this module.
// TODO: use the common request_id if it exists in the repo
pub type RequestId = String;

71
72
73
74
75
76
77
78
79
80
81
/// Everything needed to register a new request with
/// [`ActiveSequencesMultiWorker::add_request`].
pub struct SequenceRequest {
    /// Unique ID of the request.
    pub request_id: RequestId,
    /// Block hashes of the prompt, if known; `None` tracks the request without blocks.
    pub token_sequence: Option<Vec<SequenceHash>>,
    /// Input sequence length, in tokens.
    pub isl: usize,
    /// Number of blocks treated as already cached (multiplied by the block size
    /// to get the cached token count).
    pub overlap: u32,
    /// Caller's estimate of the number of output tokens, if any.
    pub expected_output_tokens: Option<u32>,
    /// Worker (ID + DP rank) the request is assigned to.
    pub worker: WorkerWithDpRank,
    /// Optional LoRA name associated with the request.
    pub lora_name: Option<String>,
}

82
83
84
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
85
    active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,
86

87
88
    prefill_tokens: HashMap<RequestId, usize>,

89
90
91
    /// Expected output tokens per request (used for resource estimation)
    expected_output_tokens: HashMap<RequestId, u32>,

92
    unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,
93

94
95
96
97
98
    /// Fractional block counts for blocks that are partially cached
    /// When a block is in both unique_blocks and fractional_blocks,
    /// it contributes the fractional value instead of 1 to active_blocks()
    fractional_blocks: HashMap<SequenceHash, f64>,

99
100
101
    #[getter(copy)]
    block_size: usize,

102
103
    #[getter(copy)]
    active_tokens: usize,
104
105
106
107
108
109

    /// Timer for when to force expiry of stale requests
    expiry_timer: Instant,

    /// Set of request IDs to check for expiry
    expiry_requests: HashSet<RequestId>,
110
111
112
113
114
115
116
117
118
119
}

impl ActiveSequences {
    /// Create a new SharedSequenceManager instance
    pub fn new(block_size: usize) -> Self {
        // TODO: make this not a hard req
        assert!(block_size > 1, "block_size must be greater than 1");

        Self {
            active_seqs: HashMap::new(),
120
            prefill_tokens: HashMap::new(),
121
            expected_output_tokens: HashMap::new(),
122
            unique_blocks: HashMap::new(),
123
            fractional_blocks: HashMap::new(),
124
            block_size,
125
            active_tokens: 0,
126
127
            expiry_timer: Instant::now() + EXPIRY_DURATION,
            expiry_requests: HashSet::new(),
128
129
130
        }
    }

131
    fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
132
133
134
135
        if let Some(weak) = self.unique_blocks.get(block)
            && let Some(rc) = weak.upgrade()
        {
            return rc;
136
137
        }

138
139
        let rc = Arc::new(());
        self.unique_blocks.insert(*block, Arc::downgrade(&rc));
140
141
        rc
    }
142

143
144
145
146
    fn try_remove_block(&mut self, block: &SequenceHash) {
        if let Some(weak) = self.unique_blocks.get(block)
            && weak.strong_count() == 0
        {
147
            self.unique_blocks.remove(block);
148
            self.fractional_blocks.remove(block);
149
150
151
        }
    }

152
    pub fn active_blocks(&self) -> usize {
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
        let mut count = self.unique_blocks.len() as f64;
        for (hash, frac) in &self.fractional_blocks {
            if self.unique_blocks.contains_key(hash) {
                // Subtract 1 (the full block) and add the fractional value
                count = count - 1.0 + frac;
            }
        }
        count.round() as usize
    }

    /// Find all blocks in a request that have only a single strong reference (only used by this request)
    /// and insert them into fractional_blocks with the given fraction value.
    pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
        let Some(blocks) = self.active_seqs.get(request_id) else {
            tracing::warn!(
                "Request {request_id} not found for set_single_ref_blocks_as_fractional"
            );
            return;
        };

        for (hash, rc) in blocks {
            // A block with strong_count == 1 means only this request holds a reference
175
            if Arc::strong_count(rc) == 1 {
176
177
178
                self.fractional_blocks.insert(*hash, fraction);
            }
        }
179
180
    }

181
    /// Add a new request with its initial tokens
182
    /// Returns the set of expired request IDs that were removed during cleanup
183
184
185
    pub fn add_request(
        &mut self,
        request_id: RequestId,
186
        token_sequence: Option<Vec<SequenceHash>>,
187
        isl: usize,
188
        overlap: u32,
189
        expected_output_tokens: Option<u32>,
190
    ) -> HashSet<RequestId> {
191
        // Check for double-add and log error, returning early
192
        if self.active_seqs.contains_key(&request_id) {
193
194
            tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
            return HashSet::new();
195
196
        }

197
198
199
        // Lazily check and clean up expired requests, capturing removed IDs
        let removed_requests = self.force_expiry();

200
        let prefill_tokens = self.new_tokens(isl, overlap);
201
202
203
204
        self.prefill_tokens
            .insert(request_id.clone(), prefill_tokens);
        self.active_tokens += prefill_tokens;

205
206
207
208
209
210
        // Store expected output tokens if provided
        if let Some(tokens) = expected_output_tokens {
            self.expected_output_tokens
                .insert(request_id.clone(), tokens);
        }

211
        if let Some(sequence) = token_sequence {
212
            let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
213
214
215
216
217
                .iter()
                .map(|block| (*block, self.touch_block(block)))
                .collect();
            self.active_seqs
                .insert(request_id.clone(), sequence_with_refs);
218
219
220
        } else {
            // dummy empty sequence
            self.active_seqs.insert(request_id.clone(), Vec::new());
221
222
        }

223
        removed_requests
224
225
    }

226
227
228
229
230
    /// Mark prefill as completed for a request, removing it from prefill_tokens tracking
    pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
        if let Some(tokens) = self.prefill_tokens.remove(request_id) {
            self.active_tokens = self
                .active_tokens
231
                .checked_sub(tokens)
232
233
234
235
236
                .expect("active_tokens underflow");
        }
    }

    pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
237
238
239
240
241
242
243
244
245
        let cached_tokens = (overlap as usize) * self.block_size;
        isl.checked_sub(cached_tokens)
            .unwrap_or_else(|| {
                tracing::error!(
                    "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
                    self.block_size
                );
                0
            })
246
247
248
249
    }

    pub fn potential_blocks_and_tokens(
        &self,
250
        token_sequence: Option<&[SequenceHash]>,
251
        isl: usize,
252
253
        overlap: u32,
    ) -> (usize, usize) {
254
        let potential_blocks = if let Some(token_seq) = token_sequence {
255
            self.new_blocks(token_seq) + self.active_blocks()
256
        } else {
257
            self.active_blocks()
258
        };
259
        let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens;
260
261
262
        (potential_blocks, potential_tokens)
    }

263
    /// Match a request against existing blocks and return the number of new blocks that would be added
264
265
    pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        token_sequence
266
267
268
269
270
271
272
            .iter()
            .filter(|block| !self.unique_blocks.contains_key(block))
            .count()
    }

    /// Return the total number of blocks that would be used if the token sequence was added
    /// This is the sum of new blocks that would be added plus the current active blocks
273
    pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
274
        self.new_blocks(token_sequence) + self.active_blocks()
275
276
277
278
    }

    /// Free all blocks associated with a request
    pub fn free(&mut self, request_id: &RequestId) -> usize {
279
        self.mark_prefill_completed(request_id);
280

281
282
        self.expiry_requests.remove(request_id);

283
284
285
        // Remove expected output tokens tracking
        self.expected_output_tokens.remove(request_id);

286
287
288
289
290
        // Remove from active_seqs and get the token sequence
        let token_seq = match self.active_seqs.remove(request_id) {
            Some(seq) => seq,
            None => {
                tracing::warn!("Trying to free non-existent request {request_id}");
291
                return self.active_blocks();
292
            }
293
294
        };

295
296
297
298
        // Drop each Rc reference, then clean up the corresponding weak reference
        for (block_hash, rc) in token_seq {
            drop(rc);
            self.try_remove_block(&block_hash);
299
300
        }

301
        self.active_blocks()
302
    }
303

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
    /// Add an output block with a random hash and optional fractional decay weight.
    ///
    /// This is used during generation to track output blocks as they are created.
    /// The decay_fraction (if provided) represents how "temporary" the block is:
    /// - 1.0 means fully counted (early in generation)
    /// - 0.0 means not counted (near end of expected output)
    /// - Computed as: 1 - (current_osl / expected_output_tokens)
    ///
    /// Returns true if the block was added, false if the request was not found.
    pub fn add_output_block(
        &mut self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
    ) -> bool {
        // Check if request exists first (immutable borrow)
        if !self.active_seqs.contains_key(request_id) {
            tracing::warn!("Request {request_id} not found for add_output_block");
            return false;
        }

        // Generate a random block hash using UUID
        let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;

        // Touch the block (adds to unique_blocks)
        let rc = self.touch_block(&random_hash);

        // Now we can safely get_mut and push
        self.active_seqs
            .get_mut(request_id)
            .unwrap()
            .push((random_hash, rc));

        // Apply fractional decay to all single-ref blocks in this request if provided
        if let Some(frac) = decay_fraction {
            self.set_single_ref_blocks_as_fractional(request_id, frac);
        }

        true
    }

344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
    /// Force expiry of stale requests if the timer has elapsed
    /// Returns the set of expired request IDs that were removed
    pub fn force_expiry(&mut self) -> HashSet<RequestId> {
        let now = Instant::now();

        // Early return if timer hasn't expired yet
        if now < self.expiry_timer {
            return HashSet::new();
        }

        // Process expired requests - drain to avoid clone
        let expired_requests: HashSet<RequestId> = self.expiry_requests.drain().collect();
        for request_id in &expired_requests {
            tracing::warn!("Force expiring stale request: {}", request_id);
            self.free(request_id);
        }

        self.expiry_timer = now + EXPIRY_DURATION;
        self.expiry_requests = self.active_seqs.keys().cloned().collect();

        expired_requests
    }
366
367
}

368
/// Multi-worker extension of ActiveSequences using shared DashMap for lock-free concurrent access
369
pub struct ActiveSequencesMultiWorker {
370
    workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
371
    request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
372
    request_to_lora: Arc<DashMap<RequestId, String>>,
373
    block_size: usize,
374
    router_id: u64,
375
    event_publisher: EventPublisher,
376
    metrics_publisher: Arc<EventPublisher>,
377
    replica_sync: bool,
378
    worker_type: &'static str,
379
380
381
}

impl ActiveSequencesMultiWorker {
382
    pub async fn new(
383
384
        component: Component,
        block_size: usize,
385
        workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
386
        replica_sync: bool,
387
        router_id: u64,
388
        worker_type: &'static str,
389
    ) -> Result<Self> {
390
391
        assert!(block_size > 1, "block_size must be greater than 1");

392
        let workers = Arc::new(DashMap::new());
393
        let request_to_worker = Arc::new(DashMap::new());
394
        let request_to_lora = Arc::new(DashMap::new());
395

Yan Ru Pei's avatar
Yan Ru Pei committed
396
        for (worker_id, config) in workers_with_configs {
397
            let dp_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
398
399
400

            for dp_rank in 0..dp_size {
                let worker = WorkerWithDpRank::new(worker_id, dp_rank);
401
                workers.insert(worker, ActiveSequences::new(block_size));
Yan Ru Pei's avatar
Yan Ru Pei committed
402
            }
403
404
        }

405
406
        let event_publisher =
            EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
407
408
409
        let metrics_publisher = Arc::new(
            EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?,
        );
410

411
        let multi_worker = Self {
412
            workers: workers.clone(),
413
            request_to_worker: request_to_worker.clone(),
414
            request_to_lora: request_to_lora.clone(),
415
            block_size,
416
417
            event_publisher,
            metrics_publisher,
418
419
            router_id,
            replica_sync,
420
            worker_type,
421
422
423
        };

        if replica_sync {
424
            let workers_clone = workers.clone();
425
            let request_to_worker_clone = request_to_worker.clone();
426
            let request_to_lora_clone = request_to_lora.clone();
427
428
            let component_clone = component.clone();
            let router_id_clone = router_id;
429
            let cancel_token = component.drt().runtime().child_token();
430

431
            tokio::spawn(async move {
432
                if let Err(e) = Self::subscribe_to_events(
433
                    workers_clone,
434
                    request_to_worker_clone,
435
                    request_to_lora_clone,
436
437
                    component_clone,
                    router_id_clone,
438
                    cancel_token,
439
440
441
442
443
444
                )
                .await
                {
                    tracing::error!("Error in active sequences events subscription: {}", e);
                }
            });
445
        }
446

447
        Ok(multi_worker)
448
449
    }

450
451
    /// Background task to subscribe to active sequence events and update all workers
    async fn subscribe_to_events(
452
        workers: Arc<DashMap<WorkerWithDpRank, ActiveSequences>>,
Yan Ru Pei's avatar
Yan Ru Pei committed
453
        request_to_worker: Arc<DashMap<RequestId, WorkerWithDpRank>>,
454
        request_to_lora: Arc<DashMap<RequestId, String>>,
455
        component: Component,
456
        router_id: u64,
457
        cancel_token: CancellationToken,
458
    ) -> Result<()> {
459
460
461
        let mut subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
            .await?
            .typed::<ActiveSequenceEvent>();
462

463
464
465
466
467
468
        loop {
            tokio::select! {
                result = subscriber.next() => {
                    let Some(result) = result else {
                        break;
                    };
469

470
                    let Ok((_envelope, event)) = result else {
471
472
473
                        tracing::error!(
                            "Error receiving active sequence event: {}",
                            result.unwrap_err()
474
                        );
475
476
477
478
479
                        continue;
                    };

                    if event.router_id == router_id {
                        continue;
480
                    }
481
482
483
484
485
486

                    match &event.data {
                        ActiveSequenceEventData::AddRequest {
                            token_sequence,
                            isl,
                            overlap,
487
                            expected_output_tokens,
488
                        } => {
Yan Ru Pei's avatar
Yan Ru Pei committed
489
                            request_to_worker.insert(event.request_id.clone(), event.worker);
490

491
492
493
494
                            if let Some(ref lora_name) = event.lora_name {
                                request_to_lora.insert(event.request_id.clone(), lora_name.clone());
                            }

495
496
497
498
499
500
501
502
                            if let Some(mut entry) = workers.get_mut(&event.worker) {
                                entry.add_request(
                                    event.request_id.clone(),
                                    token_sequence.clone(),
                                    *isl,
                                    *overlap,
                                    *expected_output_tokens,
                                );
503
504
                            } else {
                                tracing::warn!(
Yan Ru Pei's avatar
Yan Ru Pei committed
505
506
                                    "Worker {:?} not found, cannot process AddRequest",
                                    event.worker
507
508
509
510
                                );
                            }
                        }
                        ActiveSequenceEventData::Free => {
Yan Ru Pei's avatar
Yan Ru Pei committed
511
                            if let Some((_, worker)) = request_to_worker.remove(&event.request_id)
512
                                && let Some(mut entry) = workers.get_mut(&worker)
513
                            {
514
                                entry.free(&event.request_id);
515
                            }
516
                            request_to_lora.remove(&event.request_id);
517
518
                        }
                        ActiveSequenceEventData::MarkPrefillCompleted => {
Yan Ru Pei's avatar
Yan Ru Pei committed
519
                            if let Some(worker) = request_to_worker.get(&event.request_id)
520
                                && let Some(mut entry) = workers.get_mut(&*worker)
521
                            {
522
                                entry.mark_prefill_completed(&event.request_id);
523
524
                            }
                        }
525
                    }
526
                }
527
528
529
                _ = cancel_token.cancelled() => {
                    tracing::debug!("Subscription task cancelled");
                    break;
530
531
                }
            }
532
        }
533

534
        Ok(())
535
536
537
    }

    /// Update the set of workers, adding and removing as needed
538
    pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
Yan Ru Pei's avatar
Yan Ru Pei committed
539
        let current_workers: HashSet<WorkerWithDpRank> =
540
            self.workers.iter().map(|entry| *entry.key()).collect();
541

Yan Ru Pei's avatar
Yan Ru Pei committed
542
543
        let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
        for (worker_id, config) in &new_workers_with_configs {
544
            let dp_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
545
546
547
548
549
550
551

            for dp_rank in 0..dp_size {
                new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
            }
        }

        let workers_to_remove: Vec<WorkerWithDpRank> =
552
            current_workers.difference(&new_workers).copied().collect();
Yan Ru Pei's avatar
Yan Ru Pei committed
553
        let workers_to_add: Vec<WorkerWithDpRank> =
554
555
            new_workers.difference(&current_workers).copied().collect();

Yan Ru Pei's avatar
Yan Ru Pei committed
556
557
        for worker in &workers_to_remove {
            tracing::warn!("Removing worker {:?}", worker);
558

559
            self.workers.remove(worker);
560

561
562
563
564
565
566
567
            let requests_to_remove: Vec<RequestId> = self
                .request_to_worker
                .iter()
                .filter(|entry| entry.value() == worker)
                .map(|entry| entry.key().clone())
                .collect();

568
            self.request_to_worker
Yan Ru Pei's avatar
Yan Ru Pei committed
569
                .retain(|_request_id, mapped_worker| mapped_worker != worker);
570
571
572
573

            for request_id in requests_to_remove {
                self.request_to_lora.remove(&request_id);
            }
574
575
        }

Yan Ru Pei's avatar
Yan Ru Pei committed
576
577
        for worker in &workers_to_add {
            tracing::warn!("Adding worker {:?}", worker);
578
579
            self.workers
                .insert(*worker, ActiveSequences::new(self.block_size));
580
581
582
        }
    }

583
584
585
586
587
588
589
590
591
592
593
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        let SequenceRequest {
            request_id,
            token_sequence,
            isl,
            overlap,
            expected_output_tokens,
            worker,
            lora_name,
        } = req;

594
595
596
        if !self.workers.contains_key(&worker) {
            return Err(SequenceError::WorkerNotFound { worker });
        }
597
598
599
600
601
602

        if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
            return Err(SequenceError::DuplicateRequest {
                request_id,
                worker: *existing_worker,
            });
603
604
605
606
607
        }

        if self.replica_sync {
            let event = ActiveSequenceEvent {
                request_id: request_id.clone(),
Yan Ru Pei's avatar
Yan Ru Pei committed
608
                worker,
609
610
611
612
                data: ActiveSequenceEventData::AddRequest {
                    token_sequence: token_sequence.clone(),
                    isl,
                    overlap,
613
                    expected_output_tokens,
614
615
                },
                router_id: self.router_id,
616
                lora_name: lora_name.clone(),
617
            };
618
            self.event_publisher.publish(&event).await?;
619
620
        }

Yan Ru Pei's avatar
Yan Ru Pei committed
621
        self.request_to_worker.insert(request_id.clone(), worker);
622

623
624
625
626
        if let Some(lora) = lora_name {
            self.request_to_lora.insert(request_id.clone(), lora);
        }

627
628
629
630
631
632
        let removed_requests = {
            let mut entry = self
                .workers
                .get_mut(&worker)
                .ok_or(SequenceError::WorkerNotFound { worker })?;
            entry.add_request(
633
634
                request_id,
                token_sequence,
635
                isl,
636
                overlap,
637
                expected_output_tokens,
638
639
            )
        };
640
641
642

        for expired_id in &removed_requests {
            self.request_to_worker.remove(expired_id);
643
            self.request_to_lora.remove(expired_id);
644
645
        }

646
        self.publish_active_load_for_worker(worker);
647

648
        Ok(())
649
650
    }

651
    /// Send a mutation to the worker assigned to a request, optionally publishing
652
    /// a replica-sync event and cleaning up request mappings afterward.
653
    async fn mutate_request_worker(
654
655
656
        &self,
        request_id: &RequestId,
        event_data: ActiveSequenceEventData,
657
        mutate_fn: impl FnOnce(&mut ActiveSequences, &RequestId),
658
659
660
661
662
663
664
665
666
        remove_mapping: bool,
    ) -> Result<(), SequenceError> {
        let worker = self
            .request_to_worker
            .get(request_id)
            .map(|entry| *entry)
            .ok_or_else(|| SequenceError::RequestNotFound {
                request_id: request_id.clone(),
            })?;
667

668
        if self.replica_sync {
669
670
671
672
673
            let lora_name = self
                .request_to_lora
                .get(request_id)
                .map(|entry| entry.value().clone());

674
675
            let event = ActiveSequenceEvent {
                request_id: request_id.clone(),
Yan Ru Pei's avatar
Yan Ru Pei committed
676
                worker,
677
                data: event_data,
678
                router_id: self.router_id,
679
                lora_name,
680
            };
681
            self.event_publisher.publish(&event).await?;
682
683
        }

684
685
686
687
688
689
690
        {
            let mut entry = self
                .workers
                .get_mut(&worker)
                .ok_or(SequenceError::WorkerNotFound { worker })?;
            mutate_fn(&mut entry, request_id);
        }
691

692
693
694
695
        if remove_mapping {
            self.request_to_worker.remove(request_id);
            self.request_to_lora.remove(request_id);
        }
696

697
        self.publish_active_load_for_worker(worker);
698

699
        Ok(())
700
701
    }

702
703
704
705
706
707
708
709
710
711
    /// Free all blocks associated with a request.
    ///
    /// Idempotent: freeing a request that is not tracked (for example, one that
    /// was already freed) only emits a debug-level log and returns `Ok(())`. No
    /// replica event is published in that case.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying mutation, such as a failed replica
    /// publish or a missing worker.
    pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
        if !self.request_to_worker.contains_key(request_id) {
            tracing::debug!("Request {request_id} not found, already freed (idempotent)");
            return Ok(());
        }

        self.mutate_request_worker(
            request_id,
            ActiveSequenceEventData::Free,
            |seqs, rid| {
                seqs.free(rid);
            },
            true,
        )
        .await
    }

723
    /// Mark prefill as completed for a request
724
725
726
727
728
729
730
    ///
    /// Note: Calling this multiple times for the same request is allowed and will be a no-op
    /// after the first call (idempotent).
    pub async fn mark_prefill_completed(
        &self,
        request_id: &RequestId,
    ) -> Result<(), SequenceError> {
731
        self.mutate_request_worker(
732
733
            request_id,
            ActiveSequenceEventData::MarkPrefillCompleted,
734
735
736
            |seqs, rid| {
                seqs.mark_prefill_completed(rid);
            },
737
738
739
            false,
        )
        .await
740
741
    }

742
743
744
745
    /// Add an output block with optional fractional decay weight
    ///
    /// This is used during generation to track output blocks as they are created.
    /// The decay_fraction represents how "temporary" the block is based on generation progress.
746
747
748
    // TODO: output blocks are not replicated via replica_sync — add an
    // ActiveSequenceEventData variant if cross-instance accuracy matters.
    pub fn add_output_block(
749
750
751
752
753
754
755
756
757
758
759
760
        &self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        let worker = self
            .request_to_worker
            .get(request_id)
            .map(|entry| *entry)
            .ok_or_else(|| SequenceError::RequestNotFound {
                request_id: request_id.clone(),
            })?;

761
762
763
764
765
766
767
        let success = {
            let mut entry = self
                .workers
                .get_mut(&worker)
                .ok_or(SequenceError::WorkerNotFound { worker })?;
            entry.add_output_block(request_id, decay_fraction)
        };
768
769
770
771
772
773
774

        if !success {
            return Err(SequenceError::RequestNotFound {
                request_id: request_id.clone(),
            });
        }

775
        self.publish_active_load_for_worker(worker);
776
777
778
779

        Ok(())
    }

780
781
782
783
784
    /// Read active blocks/tokens from a worker and publish ActiveLoad metrics.
    /// The NATS publish is spawned as a background task to avoid blocking the caller.
    fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
        let (active_blocks, active_tokens) = {
            let Some(entry) = self.workers.get(&worker) else {
785
786
787
                tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
                return;
            };
788
            (entry.active_blocks(), entry.active_tokens())
789
790
        };

791
792
793
794
795
796
797
        WORKER_LOAD_METRICS.observe(
            worker.worker_id,
            worker.dp_rank,
            self.worker_type,
            active_blocks,
            active_tokens,
        );
798

799
800
801
802
803
804
805
        let active_load = ActiveLoad {
            worker_id: worker.worker_id,
            dp_rank: worker.dp_rank,
            active_decode_blocks: Some(active_blocks as u64),
            active_prefill_tokens: Some(active_tokens as u64),
        };

806
807
808
809
810
811
812
813
        let publisher = self.metrics_publisher.clone();
        tokio::spawn(async move {
            if let Err(e) = publisher.publish(&active_load).await {
                tracing::trace!(
                    "Failed to publish ActiveLoad to NATS for worker {worker:?}: {e:?}"
                );
            }
        });
814
815
    }

816
817
    /// Get the number of workers
    pub fn num_workers(&self) -> usize {
818
        self.workers.len()
819
820
    }

821
822
823
824
825
826
    /// Get the worker type for this router ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.worker_type
    }

827
    /// Query all workers for the number of new blocks that would be added by a token sequence
828
    pub fn new_blocks(
Yan Ru Pei's avatar
Yan Ru Pei committed
829
830
831
        &self,
        token_sequence: Vec<SequenceHash>,
    ) -> HashMap<WorkerWithDpRank, usize> {
832
833
834
835
836
        let mut results = HashMap::with_capacity(self.workers.len());
        for entry in self.workers.iter() {
            results.insert(*entry.key(), entry.value().new_blocks(&token_sequence));
        }
        results
837
838
839
    }

    /// Query all workers for the total number of blocks (new + active) that would be used by a token sequence
840
    pub fn potential_blocks(
841
842
        &self,
        token_sequence: Vec<SequenceHash>,
Yan Ru Pei's avatar
Yan Ru Pei committed
843
    ) -> HashMap<WorkerWithDpRank, usize> {
844
845
846
847
848
849
850
851
        let mut results = HashMap::with_capacity(self.workers.len());
        for entry in self.workers.iter() {
            results.insert(
                *entry.key(),
                entry.value().potential_blocks(&token_sequence),
            );
        }
        results
852
853
    }

854
855
    /// Query all workers for the potential blocks and tokens
    pub fn potential_blocks_and_tokens(
856
        &self,
857
        token_sequence: Option<Vec<SequenceHash>>,
858
        isl: usize,
859
        overlaps: OverlapScores,
Yan Ru Pei's avatar
Yan Ru Pei committed
860
861
862
863
    ) -> (
        HashMap<WorkerWithDpRank, usize>,
        HashMap<WorkerWithDpRank, usize>,
    ) {
864
865
866
        #[cfg(feature = "bench")]
        let start = Instant::now();
        #[cfg(feature = "bench")]
867
        let num_workers = self.workers.len();
868

869
870
        let mut potential_blocks = HashMap::with_capacity(self.workers.len());
        let mut potential_tokens = HashMap::with_capacity(self.workers.len());
871

872
873
        for entry in self.workers.iter() {
            let worker = *entry.key();
874
875
            let overlap = *overlaps.scores.get(&worker).unwrap_or(&0);

876
877
878
879
880
881
            let (blocks, tokens) =
                entry
                    .value()
                    .potential_blocks_and_tokens(token_sequence.as_deref(), isl, overlap);
            potential_blocks.insert(worker, blocks);
            potential_tokens.insert(worker, tokens);
882
883
        }

884
885
886
887
888
889
890
891
892
893
        #[cfg(feature = "bench")]
        {
            let total_elapsed = start.elapsed();
            tracing::info!(
                num_workers,
                total_us = total_elapsed.as_micros() as u64,
                "potential_blocks_and_tokens completed"
            );
        }

894
895
896
        (potential_blocks, potential_tokens)
    }

897
    /// Query all workers for their current number of active blocks
898
899
900
901
902
903
    pub fn active_blocks(&self) -> HashMap<WorkerWithDpRank, usize> {
        let mut results = HashMap::with_capacity(self.workers.len());
        for entry in self.workers.iter() {
            results.insert(*entry.key(), entry.value().active_blocks());
        }
        results
904
    }
905
906

    /// Query all workers for their current number of active tokens
907
908
909
910
911
912
    pub fn active_tokens(&self) -> HashMap<WorkerWithDpRank, usize> {
        let mut results = HashMap::with_capacity(self.workers.len());
        for entry in self.workers.iter() {
            results.insert(*entry.key(), entry.value().active_tokens());
        }
        results
913
    }
914
915
916
917
918
919
920
921
922

    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        let mut counts: HashMap<String, usize> = HashMap::new();
        for entry in self.request_to_lora.iter() {
            let lora_name = entry.value().clone();
            *counts.entry(lora_name).or_insert(0) += 1;
        }
        counts
    }
923
924
925
926
927
}

#[cfg(test)]
mod tests {
    // `use super::*` already brings the parent's `Arc` and `HashMap` imports into scope.
    use super::*;
    use dynamo_runtime::{DistributedRuntime, Runtime};

    #[test]
    fn test_active_sequences_shared_blocks() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        // request_3 reuses blocks 1-4 already held by request_1/request_2, so the
        // active block and token counts do not change.
        seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        // Block 4 is still referenced by request_3, so the block count stays at 4.
        seq_manager.free(&"request_2".to_string());
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_3".to_string());
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_cross_instance_sync() -> Result<()> {
        // Initialize logging once
        dynamo_runtime::logging::init();

        let block_size = 4; // arbitrary block size

        // Create runtime and distributed runtime
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

        // Create namespace and shared component for both seq_managers
        let namespace = distributed.namespace("test_cross_instance_sync")?;
        let component = namespace.component("sequences")?;

        // Create multi-worker sequence managers with:
        // - Worker 0 with dp_size=2 (dp_ranks 0 and 1)
        // - Worker 1 with dp_size=1 (dp_rank 0)
        // This gives us 3 effective workers total to test dp_rank effect
        // Both seq_managers use the same component to ensure event synchronization works
        let mut workers_with_configs = HashMap::new();

        // Create runtime config for worker 0 with dp_size=2
        let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
        config_worker_0.data_parallel_size = 2;
        workers_with_configs.insert(0, config_worker_0);

        // Create runtime config for worker 1 with dp_size=1 (default)
        let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
        workers_with_configs.insert(1, config_worker_1);

        let seq_manager_1 = Arc::new(
            ActiveSequencesMultiWorker::new(
                component.clone(),
                block_size,
                workers_with_configs.clone(),
                true,
                1,
                crate::discovery::WORKER_TYPE_DECODE,
            )
            .await?,
        );
        let seq_manager_2 = Arc::new(
            ActiveSequencesMultiWorker::new(
                component,
                block_size,
                workers_with_configs,
                true,
                2,
                crate::discovery::WORKER_TYPE_DECODE,
            )
            .await?,
        );

        // Give some time for the subscription loops to start
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // PHASE 1: Add requests using both seq_manager_1 and seq_manager_2

        // Add request_0 to worker 0, dp_rank 0: sequence [0, 1, 2]
        seq_manager_1
            .add_request(SequenceRequest {
                request_id: "request_0".to_string(),
                token_sequence: Some(vec![0, 1, 2]),
                isl: 12,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::new(0, 0),
                lora_name: None,
            })
            .await?;

        // Add request_1 to worker 0, dp_rank 1: sequence [3, 4]
        seq_manager_1
            .add_request(SequenceRequest {
                request_id: "request_1".to_string(),
                token_sequence: Some(vec![3, 4]),
                isl: 8,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::new(0, 1),
                lora_name: None,
            })
            .await?;

        // Add request_2 to worker 1, dp_rank 0: sequence [0, 1, 2, 3] using seq_manager_2
        seq_manager_2
            .add_request(SequenceRequest {
                request_id: "request_2".to_string(),
                token_sequence: Some(vec![0, 1, 2, 3]),
                isl: 16,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::new(1, 0),
                lora_name: None,
            })
            .await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
        let blocks_phase1 = seq_manager_1.active_blocks();
        let tokens_phase1 = seq_manager_1.active_tokens();

        // Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
        // We now have:
        // - Worker 0, dp_rank 0: request_0
        // - Worker 0, dp_rank 1: request_1
        // - Worker 1, dp_rank 0: request_2
        let worker_0_dp0 = WorkerWithDpRank::new(0, 0);
        let worker_0_dp1 = WorkerWithDpRank::new(0, 1);
        let worker_1_dp0 = WorkerWithDpRank::new(1, 0);

        assert_eq!(
            blocks_phase1[&worker_0_dp0], 3,
            "Worker 0 dp_rank 0 should have 3 active blocks (from request_0)"
        );
        assert_eq!(
            blocks_phase1[&worker_0_dp1], 2,
            "Worker 0 dp_rank 1 should have 2 active blocks (from request_1)"
        );
        assert_eq!(
            blocks_phase1[&worker_1_dp0], 4,
            "Worker 1 dp_rank 0 should have 4 active blocks (from request_2 added by seq_manager_2)"
        );
        assert_eq!(
            tokens_phase1[&worker_0_dp0], 12,
            "Worker 0 dp_rank 0 should have 12 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_0_dp1], 8,
            "Worker 0 dp_rank 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_1_dp0], 16,
            "Worker 1 dp_rank 0 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2

        // Free request_2 (which was added by seq_manager_2) using seq_manager_1
        seq_manager_1.free(&"request_2".to_string()).await?;

        // Free request_0 and request_1 (which were added by seq_manager_1) using seq_manager_2
        seq_manager_2.free(&"request_0".to_string()).await?;
        seq_manager_2.free(&"request_1".to_string()).await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_2 to verify everything is empty
        let blocks_phase2 = seq_manager_2.active_blocks();
        let tokens_phase2 = seq_manager_2.active_tokens();

        // Verify phase 2 results - everything should be empty for all 3 workers
        let all_workers = [
            WorkerWithDpRank::new(0, 0),
            WorkerWithDpRank::new(0, 1),
            WorkerWithDpRank::new(1, 0),
        ];

        for worker in all_workers {
            assert_eq!(
                blocks_phase2[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active blocks after all requests freed",
                worker.worker_id, worker.dp_rank
            );
            assert_eq!(
                tokens_phase2[&worker], 0,
                "Worker (id={}, dp_rank={}) should have 0 active tokens after all requests freed",
                worker.worker_id, worker.dp_rank
            );
        }

        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn test_multi_worker_no_token_sequence_sync() -> Result<()> {
        // Initialize logging once
        dynamo_runtime::logging::init();

        let block_size = 4; // arbitrary block size

        // Create runtime and distributed runtime
        let runtime = Runtime::from_current()?;
        let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;

        // Create namespace and shared component for both seq_managers
        let namespace = distributed.namespace("test_no_token_seq_sync")?;
        let component = namespace.component("sequences")?;

        // Create multi-worker sequence managers with ALL workers [0, 1, 2]
        // Both use the same component to ensure event synchronization works
        let mut workers_with_configs = HashMap::new();
        workers_with_configs.insert(
            0,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );
        workers_with_configs.insert(
            1,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );
        workers_with_configs.insert(
            2,
            crate::local_model::runtime_config::ModelRuntimeConfig::new(),
        );

        let seq_manager_1 = Arc::new(
            ActiveSequencesMultiWorker::new(
                component.clone(),
                block_size,
                workers_with_configs.clone(),
                true,
                1,
                crate::discovery::WORKER_TYPE_DECODE,
            )
            .await?,
        );
        let seq_manager_2 = Arc::new(
            ActiveSequencesMultiWorker::new(
                component,
                block_size,
                workers_with_configs,
                true,
                2,
                crate::discovery::WORKER_TYPE_DECODE,
            )
            .await?,
        );

        // Give some time for the subscription loops to start
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // PHASE 1: Add requests (without token sequences) using both seq_managers

        // Add request_0 to worker 0 with no token sequence
        seq_manager_1
            .add_request(SequenceRequest {
                request_id: "request_0".to_string(),
                token_sequence: None,
                isl: 12,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::from_worker_id(0),
                lora_name: None,
            })
            .await?;

        // Add request_1 to worker 1 with no token sequence
        seq_manager_1
            .add_request(SequenceRequest {
                request_id: "request_1".to_string(),
                token_sequence: None,
                isl: 8,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::from_worker_id(1),
                lora_name: None,
            })
            .await?;

        // Add request_2 to worker 2 with no token sequence using seq_manager_2
        seq_manager_2
            .add_request(SequenceRequest {
                request_id: "request_2".to_string(),
                token_sequence: None,
                isl: 16,
                overlap: 0,
                expected_output_tokens: None,
                worker: WorkerWithDpRank::from_worker_id(2),
                lora_name: None,
            })
            .await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_1 to verify it sees all requests including request_2 from seq_manager_2
        let tokens_phase1 = seq_manager_1.active_tokens();

        // Verify that seq_manager_1 sees all requests including request_2 from seq_manager_2
        let worker_0 = WorkerWithDpRank::from_worker_id(0);
        let worker_1 = WorkerWithDpRank::from_worker_id(1);
        let worker_2 = WorkerWithDpRank::from_worker_id(2);

        assert_eq!(
            tokens_phase1[&worker_0], 12,
            "Worker 0 should have 12 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_1], 8,
            "Worker 1 should have 8 active tokens"
        );
        assert_eq!(
            tokens_phase1[&worker_2], 16,
            "Worker 2 should have 16 active tokens (from request_2 added by seq_manager_2)"
        );

        // PHASE 2: Free requests using opposite sequence managers, verify on seq_manager_2

        // Mark prefill completed and free request_2 (which was added by seq_manager_2) using seq_manager_1
        seq_manager_1
            .mark_prefill_completed(&"request_2".to_string())
            .await?;
        seq_manager_1.free(&"request_2".to_string()).await?;

        // Mark prefill completed and free requests 0 and 1 (which were added by seq_manager_1) using seq_manager_2
        seq_manager_2
            .mark_prefill_completed(&"request_0".to_string())
            .await?;
        seq_manager_2
            .mark_prefill_completed(&"request_1".to_string())
            .await?;
        seq_manager_2.free(&"request_0".to_string()).await?;
        seq_manager_2.free(&"request_1".to_string()).await?;

        // Give some time for synchronization
        tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

        // Query seq_manager_2 to verify everything is empty
        let tokens_phase2 = seq_manager_2.active_tokens();

        // Verify phase 2 results - everything should be empty
        for worker_id in 0..=2 {
            let worker = WorkerWithDpRank::from_worker_id(worker_id);
            assert_eq!(
                tokens_phase2[&worker], 0,
                "Worker {} should have 0 active tokens after all requests freed",
                worker_id
            );
        }

        Ok(())
    }
}