// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::error::Error as StdError;
5
6
7
8
9
10
use std::sync::Arc;

use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};

use crate::{
11
    http::service::metrics::Metrics, model_card::ModelDeploymentCard, preprocessor::BackendOutput,
12
    protocols::common::llm_backend::PreprocessedRequest,
13
14
};

15
use dynamo_runtime::error::{self, BackendError, DynamoError, ErrorType};
16
17
use dynamo_runtime::pipeline::{
    AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
18
    ServerStreamingEngine, SingleIn, async_trait,
19
};
20
use dynamo_runtime::protocols::{annotated::Annotated, maybe_error::MaybeError};
21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
/// Check if an error chain indicates the request should be migrated.
fn is_migratable(err: &(dyn StdError + 'static)) -> bool {
    const MIGRATABLE: &[ErrorType] = &[
        ErrorType::CannotConnect,
        ErrorType::Disconnected,
        ErrorType::ConnectionTimeout,
        ErrorType::Backend(BackendError::EngineShutdown),
    ];
    const NON_MIGRATABLE: &[ErrorType] = &[
        // Future: ErrorType::Cancelled, ErrorType::ValidationError, etc.
    ];
    error::match_error_chain(err, MIGRATABLE, NON_MIGRATABLE)
}

36
37
pub struct Migration {
    migration_limit: u32,
38
39
    model_name: Arc<String>,
    metrics: Arc<Metrics>,
40
41
42
}

impl Migration {
43
44
    pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> {
        tracing::debug!("model {} migration limit {}", model_name, migration_limit);
45
        Arc::new(Self {
46
47
            migration_limit,
            model_name: Arc::new(model_name),
48
            metrics,
49
        })
50
    }
51
52
53
54
55
56
57
58

    pub fn from_mdc(
        mdc: &ModelDeploymentCard,
        migration_limit: u32,
        metrics: Arc<Metrics>,
    ) -> Arc<Self> {
        Self::new(migration_limit, mdc.display_name.clone(), metrics)
    }
59
60
61
62
63
64
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
65
        ManyOut<Annotated<BackendOutput>>,
66
        SingleIn<PreprocessedRequest>,
67
        ManyOut<Annotated<BackendOutput>>,
68
69
70
71
72
    > for Migration
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
73
74
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
75
76
        let (preprocessed_request, context) = request.transfer(());
        let engine_ctx = context.context();
77
        let engine_ctx_ = engine_ctx.clone();
78
79
80
81
82
83
84
85
86
        let retry_manager = RetryManager::build(
            engine_ctx,
            preprocessed_request,
            next,
            self.migration_limit,
            self.model_name.clone(),
            self.metrics.clone(),
        )
        .await?;
87
88
89
90
91
        let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
            retry_manager
                .next()
                .await
                .map(|response| (response, retry_manager))
92
93
        })
        .fuse();
94
        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
95
96
97
98
    }
}

struct RetryManager {
99
    context: Arc<dyn AsyncEngineContext>,
100
    request: PreprocessedRequest,
101
102
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
103
    retries_left: u32,
104
105
    model_name: Arc<String>,
    metrics: Arc<Metrics>,
106
107
108
109
}

impl RetryManager {
    pub async fn build(
110
        context: Arc<dyn AsyncEngineContext>,
111
        preprocessed_request: PreprocessedRequest,
112
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
113
        retries_left: u32,
114
115
        model_name: Arc<String>,
        metrics: Arc<Metrics>,
116
117
    ) -> Result<Self> {
        let mut slf = Self {
118
            context,
119
120
121
122
            request: preprocessed_request,
            next_generate: next,
            next_stream: None,
            retries_left: retries_left + 1, // +1 to account for the initial attempt
123
124
            model_name,
            metrics,
125
126
127
128
129
        };
        slf.new_stream().await?;
        Ok(slf)
    }

130
    pub async fn next(&mut self) -> Option<Annotated<BackendOutput>> {
131
132
133
134
135
        loop {
            let response_stream = match self.next_stream.as_mut() {
                Some(stream) => stream,
                None => {
                    tracing::error!("next() called with next_stream is None - should not happen");
136
                    return Some(Annotated::from_err(DynamoError::msg("next_stream is None")));
137
138
139
                }
            };
            if let Some(response) = response_stream.next().await {
140
                // Check if this is a migratable error that should trigger stream recreation.
141
                if let Some(err) = response.err()
142
                    && is_migratable(&err)
143
                {
144
                    tracing::warn!("Stream disconnected... recreating stream... {}", err);
145
                    self.metrics.inc_migration_ongoing_request(&self.model_name);
146
147
148
149
                    if let Err(err) = self.new_stream().await {
                        tracing::warn!("Cannot recreate stream: {:#}", err);
                    } else {
                        continue;
150
151
152
153
154
155
156
157
158
159
                    }
                }
                self.track_response(&response);
                return Some(response);
            }
            return None;
        }
    }

    async fn new_stream(&mut self) -> Result<()> {
160
        let mut response_stream: Option<Result<ManyOut<Annotated<BackendOutput>>>> = None;
161
162
        while self.retries_left > 0 {
            self.retries_left -= 1;
163
164
            let request = Context::with_id(self.request.clone(), self.context.id().to_string());
            self.context.link_child(request.context());
165
166
167
168
169
170
171
            if self.context.is_stopped() || self.context.is_killed() {
                tracing::debug!("Abort creating new stream after context is stopped or killed");
                return Err(Error::msg(format!(
                    "Context id {} is stopped or killed",
                    self.context.id()
                )));
            }
172
            response_stream = Some(self.next_generate.generate(request).await);
173
            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
174
                && is_migratable(err.as_ref())
175
            {
176
                tracing::warn!("Creating new stream... retrying... {}", err);
177
                self.metrics.inc_migration_new_request(&self.model_name);
178
                continue;
179
180
181
182
183
184
185
186
            }
            break;
        }
        match response_stream {
            Some(Ok(next_stream)) => {
                self.next_stream = Some(next_stream);
                Ok(())
            }
187
            Some(Err(err)) => Err(err), // should propagate original error if any
188
            None => Err(Error::msg(
189
                "Migration limit exhausted", // should propagate original error if any
190
191
192
193
            )),
        }
    }

194
    fn track_response(&mut self, response: &Annotated<BackendOutput>) {
195
196
197
198
199
200
201
        if self.retries_left == 0 {
            return;
        }
        let llm_engine_output = match response.data.as_ref() {
            Some(output) => output,
            None => return,
        };
202
203
204
205
        if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
            self.request.stop_conditions.max_tokens =
                Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
        }
206
207
208
209
210
211
212
213
214
        for token_id in llm_engine_output.token_ids.iter() {
            self.request.token_ids.push(*token_id);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
215
    use crate::http::service::metrics::Metrics;
Greg Clark's avatar
Greg Clark committed
216
    use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
217
    use dynamo_runtime::error::{DynamoError, ErrorType};
218
    use dynamo_runtime::pipeline::AsyncEngine;
219
    use dynamo_runtime::pipeline::context::Controller;
220
221
222
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::mpsc;

223
224
    const TEST_MODEL: &str = "test-model";

225
    // Helper to create a mock preprocessed request
226
    fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
227
228
229
230
        PreprocessedRequest::builder()
            .model("mock".to_string())
            .token_ids(vec![1, 2, 3])
            .stop_conditions(StopConditions {
231
232
                max_tokens: Some(max_tokens),
                ..Default::default()
233
234
235
236
237
238
239
            })
            .sampling_options(SamplingOptions::default())
            .output_options(OutputOptions::default())
            .eos_token_ids(vec![])
            .annotations(vec![])
            .build()
            .unwrap()
240
241
242
    }

    // Helper to create mock LLM engine output
243
244
    fn create_mock_output(token_id: u32) -> Annotated<BackendOutput> {
        Annotated::from_data(BackendOutput {
245
            token_ids: vec![token_id],
246
247
            tokens: vec![],
            text: Some(format!("token_{token_id}")),
248
249
            cum_log_probs: None,
            log_probs: None,
Greg Clark's avatar
Greg Clark committed
250
            top_logprobs: None,
251
            finish_reason: None,
252
            stop_reason: None,
253
            index: None,
254
            disaggregated_params: None,
255
            completion_usage: None,
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
        })
    }

    /// Scripted behaviors for `MockEngine`, selecting how each `generate` call responds.
    #[derive(Debug, Clone)]
    enum MockBehavior {
        /// Always succeeds with all responses
        Success,
        /// Fails on the first call with a `CannotConnect` ("no responders") error, then
        /// succeeds on subsequent calls
        FailThenSuccess,
        /// First call streams responses up to index `fail_after` and then emits a
        /// `Disconnected` error; later calls stream the remaining responses
        MidStreamFail { fail_after: usize },
        /// First call streams responses up to index `fail_after` and then emits a
        /// `Disconnected` error; later calls fail with a `CannotConnect` error
        MidStreamFailAlways { fail_after: usize },
        /// First call streams responses up to index `fail_after` and then emits a
        /// `Disconnected` error; later calls return a stream that yields only a
        /// `Disconnected` error
        MidStreamFailAlwaysStreamError { fail_after: usize },
        /// Always fails with a `CannotConnect` ("no responders") error (same as the first
        /// call of `FailThenSuccess`)
        AlwaysFail,
    }

    // Unified mock server streaming engine that can simulate different scenarios
    struct MockEngine {
        behavior: MockBehavior,
        num_responses: usize,
        token_offset: u32,
        call_count: Arc<AtomicU32>,
281
        context_id: String,
282
283
284
    }

    impl MockEngine {
285
286
287
288
289
290
        fn new(
            behavior: MockBehavior,
            num_responses: usize,
            token_offset: u32,
            context_id: String,
        ) -> Self {
291
292
293
294
295
            Self {
                behavior,
                num_responses,
                token_offset,
                call_count: Arc::new(AtomicU32::new(0)),
296
                context_id,
297
298
299
300
301
302
            }
        }
    }

    #[async_trait]
    impl
303
304
        AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, anyhow::Error>
        for MockEngine
305
306
307
308
    {
        async fn generate(
            &self,
            request: SingleIn<PreprocessedRequest>,
309
        ) -> Result<ManyOut<Annotated<BackendOutput>>> {
310
            let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
311
312
313
314
315
316
317
318
            let (preprocessed_request, context) = request.transfer(());

            // Assert that the context_id matches the expected one
            assert_eq!(
                context.id().to_string(),
                self.context_id,
                "Context ID mismatch"
            );
319
320
321
322
323
324
325
326

            // Calculate how many responses we've already generated based on request token_ids
            // Initial request has [1, 2, 3], so anything beyond that are generated responses
            let initial_tokens = 3; // [1, 2, 3]
            let responses_already_generated = preprocessed_request
                .token_ids
                .len()
                .saturating_sub(initial_tokens);
327
328
329
330
331
332
333
334
335
336
337
338

            // Assert that max_tokens reflects the expected remaining tokens
            let expected_max_tokens =
                self.num_responses
                    .saturating_sub(responses_already_generated) as u32;
            assert_eq!(
                preprocessed_request.stop_conditions.max_tokens,
                Some(expected_max_tokens),
                "max_tokens should be {} but got {:?}",
                expected_max_tokens,
                preprocessed_request.stop_conditions.max_tokens
            );
339
340
341
342
343
344
345
346
347
348

            match &self.behavior {
                MockBehavior::Success => {
                    // Always succeed with remaining responses
                    self.send_responses(responses_already_generated, self.num_responses)
                        .await
                }
                MockBehavior::FailThenSuccess => {
                    if call_num == 0 {
                        // First call - return "No responders available" error to trigger retry
349
350
351
352
353
354
                        return Err(anyhow::anyhow!(
                            DynamoError::builder()
                                .error_type(ErrorType::CannotConnect)
                                .message("no responders")
                                .build()
                        ));
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
                    } else {
                        // Subsequent calls - succeed with remaining responses
                        self.send_responses(responses_already_generated, self.num_responses)
                            .await
                    }
                }
                MockBehavior::MidStreamFail { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
378
379
380
381
382
383
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
                            let _ = tx.send(error_response).await;
                        });
                    } else {
                        // Second call - send remaining responses from where we left off
                        tokio::spawn(async move {
                            for i in responses_already_generated..num_responses {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                        });
                    }

                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
399
                    let ctx = Arc::new(Controller::new(self.context_id.clone()));
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
                        Box::pin(stream),
                        ctx,
                    ))
                }
                MockBehavior::MidStreamFailAlways { fail_after } => {
                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        let (tx, rx) = mpsc::channel(1);
                        let token_offset = self.token_offset;
                        let fail_after = *fail_after;
                        let num_responses = self.num_responses;

                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
422
423
424
425
426
427
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
428
429
430
431
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
432
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
433
434
435
436
437
438
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - always fail with NoResponders error (same as AlwaysFail)
439
440
441
442
443
444
                        Err(anyhow::anyhow!(
                            DynamoError::builder()
                                .error_type(ErrorType::CannotConnect)
                                .message("no responders")
                                .build()
                        ))
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
                    }
                }
                MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
464
465
466
467
468
469
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
470
471
472
473
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
474
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
475
476
477
478
479
480
481
482
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - immediately send stream error (no successful responses)
                        tokio::spawn(async move {
                            // Send the stream error immediately
483
484
485
486
487
488
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
489
490
491
492
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
493
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
494
495
496
497
498
499
500
501
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    }
                }
                MockBehavior::AlwaysFail => {
                    // Always fail with NoResponders error (same as FailThenSuccess first call)
502
503
504
505
506
507
                    Err(anyhow::anyhow!(
                        DynamoError::builder()
                            .error_type(ErrorType::CannotConnect)
                            .message("no responders")
                            .build()
                    ))
508
509
510
511
512
513
514
515
516
517
                }
            }
        }
    }

    impl MockEngine {
        /// Returns a stream that yields one mock output per index in `start..end`, with
        /// token id `token_offset + 1 + i`, produced by a spawned task.
        async fn send_responses(
            &self,
            start: usize,
            end: usize,
        ) -> Result<ManyOut<Annotated<BackendOutput>>> {
            let (tx, rx) = mpsc::channel(1);
            let token_offset = self.token_offset;

            tokio::spawn(async move {
                for i in start..end {
                    let response = create_mock_output(token_offset + 1 + i as u32);
                    // Stop producing once the receiver has been dropped.
                    if tx.send(response).await.is_err() {
                        break;
                    }
                }
            });

            let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
            let ctx = Arc::new(Controller::new(self.context_id.clone()));
            Ok(dynamo_runtime::pipeline::ResponseStream::new(
                Box::pin(stream),
                ctx,
            ))
        }
    }

    /// Test case 1: No migration needed
    /// Tests the normal case where the RetryManager successfully processes all responses
    /// from a single stream without any failures or need for retries/migration.
    /// Expected behavior: All 10 responses should be received successfully.
    #[tokio::test]
    async fn test_retry_manager_no_migration() {
546
        dynamo_runtime::logging::init();
547
        let context_id = uuid::Uuid::new_v4().to_string();
548
        let request = create_mock_request(10);
549
550
551
552
553
554
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::Success,
            10,
            100,
            context_id.clone(),
        ));
555
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
556
557
            mock_engine;

558
        let ctx = Arc::new(Controller::new(context_id.clone()));
559
560
561
562
563
564
565
566
567
568
569
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            0,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
570
571
572
573
574
575
576
577
578
579
580
581
582

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
583
584
585

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
586
587
588
589
590
591
592
593
594
595
    }

    /// Test case 2: New request migration
    /// Tests the scenario where a worker becomes unreachable for new requests initially,
    /// triggering the RetryManager to retry the request. The MockEngine with FailThenSuccess
    /// fails on the first call with a "No responders available" error, then succeeds
    /// on subsequent calls, simulating a worker becoming available after initial failure.
    /// Expected behavior: All 10 responses should be received successfully after retry.
    #[tokio::test]
    async fn test_retry_manager_new_request_migration() {
596
        dynamo_runtime::logging::init();
597
        let context_id = uuid::Uuid::new_v4().to_string();
598
        let request = create_mock_request(10);
599
600
601
602
603
604
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::FailThenSuccess,
            10,
            100,
            context_id.clone(),
        ));
605
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
606
607
            mock_engine;

608
        let ctx = Arc::new(Controller::new(context_id.clone()));
609
610
611
612
613
614
615
616
617
618
619
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
620
621
622
623
624
625
626
627
628
629
630
631
632

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
633
634
635

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 1);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
636
637
638
639
640
641
642
643
644
645
    }

    /// Test case 3: Ongoing request migration
    /// Tests the scenario where a worker fails mid-stream during an ongoing request.
    /// This simulates a connection being lost after partial response delivery, requiring
    /// the RetryManager to detect the failure (via "Stream ended before generation completed" error),
    /// create a new stream, and continue from where it left off.
    /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration() {
646
647
        dynamo_runtime::logging::init();

648
        let context_id = uuid::Uuid::new_v4().to_string();
649
        let request = create_mock_request(10);
650
651
652
653
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFail { fail_after: 5 },
            10,
            100,
654
            context_id.clone(),
655
        ));
656
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
657
658
            mock_engine;

659
        let ctx = Arc::new(Controller::new(context_id.clone()));
660
661
662
663
664
665
666
667
668
669
670
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received all 10 responses (5 from first stream + 5 from second stream)
        assert_eq!(responses.len(), 10);

        // Check that we received responses from both streams
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
687
688
689

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
690
691
692
693
694
695
696
697
    }

    /// Test case 4: New request migration - indefinite failure
    /// Tests the scenario where a worker becomes unreachable for new requests indefinitely.
    /// The RetryManager should exhaust all retries and return the original error from the first attempt.
    /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
    #[tokio::test]
    async fn test_retry_manager_new_request_migration_indefinite_failure() {
698
        dynamo_runtime::logging::init();
699
        let context_id = uuid::Uuid::new_v4().to_string();
700
        let request = create_mock_request(0);
701
702
703
704
705
706
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::AlwaysFail,
            0,
            100,
            context_id.clone(),
        ));
707
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
708
709
710
            mock_engine;

        // Should fail to build due to initial stream creation failure after exhausting all 3 retries
711
        let ctx = Arc::new(Controller::new(context_id.clone()));
712
713
714
715
716
717
718
719
720
721
        let metrics = Arc::new(Metrics::new());
        let retry_manager_result = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await;
722
723
724
725
726

        assert!(retry_manager_result.is_err());
        if let Err(error) = retry_manager_result {
            assert!(error.to_string().contains("no responders"));
        }
727
728
729

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 4); // 3 retries + 1 final failure
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
730
731
732
733
734
735
736
737
    }

    /// Test case 5: Ongoing request migration - indefinite failure
    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests.
    /// The RetryManager should exhaust all retries and return the original stream disconnection error.
    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
738
        dynamo_runtime::logging::init();
739
        let context_id = uuid::Uuid::new_v4().to_string();
740
        let request = create_mock_request(10);
741
742
743
744
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlways { fail_after: 3 },
            10,
            100,
745
            context_id.clone(),
746
        ));
747
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
748
749
            mock_engine;

750
        let ctx = Arc::new(Controller::new(context_id.clone()));
751
752
753
754
755
756
757
758
759
760
761
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        ) // 3 retries
        .await
        .expect("Failed to build RetryManager");
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780

        let mut responses = Vec::new();

        // Collect all responses (both successful and error responses)
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received 4 total responses: 3 successful + 1 error
        assert_eq!(responses.len(), 4);

        // First 3 responses should be successful with tokens 101, 102, 103
        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
            }
        }

781
        // 4th response should be a Disconnected error after retries are exhausted
782
        let error_response = &responses[3];
783
784
        let err = error_response.err().expect("expected error response");
        assert_eq!(err.error_type(), ErrorType::Disconnected);
785
786
787

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 3); // 2 retries + 1 final failure
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1); // initial ongoing failure retry
788
789
790
791
792
793
794
795
    }

    /// Test case 6: Ongoing request migration - indefinite failure with stream errors
    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests,
    /// and all retry attempts also fail with stream errors instead of NATS errors.
    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
796
        dynamo_runtime::logging::init();
797
        let context_id = uuid::Uuid::new_v4().to_string();
798
        let request = create_mock_request(10);
799
800
801
802
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
            10,
            100,
803
            context_id.clone(),
804
        ));
805
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
806
807
            mock_engine;

808
        let ctx = Arc::new(Controller::new(context_id.clone()));
809
810
811
812
813
814
815
816
817
818
819
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        ) // 3 retries
        .await
        .expect("Failed to build RetryManager");
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838

        let mut responses = Vec::new();

        // Collect all responses (both successful and error responses)
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received 4 total responses: 3 successful + 1 error
        assert_eq!(responses.len(), 4);

        // First 3 responses should be successful with tokens 101, 102, 103
        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
            }
        }

839
        // 4th response should be a Disconnected error after retries are exhausted
840
        let error_response = &responses[3];
841
842
        let err = error_response.err().expect("expected error response");
        assert_eq!(err.error_type(), ErrorType::Disconnected);
843
844
845

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 4); // 3 retries + 1 final failure
846
    }
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862

    /// Test case 7: Request cancelled when creating new stream
    /// Tests the scenario where context.stop_generating() is called when creating a new stream.
    /// The RetryManager should detect that the context is stopped and abort creating new streams.
    /// Expected behavior: Should fail to build RetryManager with "Context is stopped or killed" error.
    #[tokio::test]
    async fn test_retry_manager_context_stopped_before_stream() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::Success,
            10,
            100,
            context_id.clone(),
        ));
863
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
864
865
866
867
868
869
870
871
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));

        // Stop the context before building RetryManager
        ctx.stop_generating();

        // Should fail to build due to stopped context
872
873
874
875
876
877
878
879
880
881
        let metrics = Arc::new(Metrics::new());
        let retry_manager_result = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await;
882
883
884
885
886
887
888
889
890

        assert!(retry_manager_result.is_err());
        if let Err(error) = retry_manager_result {
            assert!(
                error
                    .to_string()
                    .contains(&format!("Context id {} is stopped or killed", context_id))
            );
        }
891
892
893

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
894
    }
895
}