migration.rs 36.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use std::error::Error as StdError;
5
6
7
8
9
10
use std::sync::Arc;

use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};

use crate::{
11
    http::service::metrics::Metrics, model_card::ModelDeploymentCard, preprocessor::BackendOutput,
12
    protocols::common::llm_backend::PreprocessedRequest,
13
14
};

15
use dynamo_runtime::error::{self, BackendError, DynamoError, ErrorType};
16
17
use dynamo_runtime::pipeline::{
    AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
18
    ServerStreamingEngine, SingleIn, async_trait,
19
};
20
use dynamo_runtime::protocols::{annotated::Annotated, maybe_error::MaybeError};
21

22
23
24
25
26
27
28
29
/// Check if an error chain indicates the request should be migrated.
fn is_migratable(err: &(dyn StdError + 'static)) -> bool {
    const MIGRATABLE: &[ErrorType] = &[
        ErrorType::CannotConnect,
        ErrorType::Disconnected,
        ErrorType::ConnectionTimeout,
        ErrorType::Backend(BackendError::EngineShutdown),
    ];
30
    const NON_MIGRATABLE: &[ErrorType] = &[ErrorType::Cancelled, ErrorType::ResourceExhausted];
31
32
33
    error::match_error_chain(err, MIGRATABLE, NON_MIGRATABLE)
}

34
35
pub struct Migration {
    migration_limit: u32,
36
37
    model_name: Arc<String>,
    metrics: Arc<Metrics>,
38
39
40
}

impl Migration {
41
42
    pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> {
        tracing::debug!("model {} migration limit {}", model_name, migration_limit);
43
        Arc::new(Self {
44
45
            migration_limit,
            model_name: Arc::new(model_name),
46
            metrics,
47
        })
48
    }
49
50
51
52
53
54
55
56

    pub fn from_mdc(
        mdc: &ModelDeploymentCard,
        migration_limit: u32,
        metrics: Arc<Metrics>,
    ) -> Arc<Self> {
        Self::new(migration_limit, mdc.display_name.clone(), metrics)
    }
57
58
59
60
61
62
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
63
        ManyOut<Annotated<BackendOutput>>,
64
        SingleIn<PreprocessedRequest>,
65
        ManyOut<Annotated<BackendOutput>>,
66
67
68
69
70
    > for Migration
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
71
72
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    ) -> Result<ManyOut<Annotated<BackendOutput>>> {
73
74
        let (preprocessed_request, context) = request.transfer(());
        let engine_ctx = context.context();
75
        let engine_ctx_ = engine_ctx.clone();
76
77
78
79
80
81
82
83
84
        let retry_manager = RetryManager::build(
            engine_ctx,
            preprocessed_request,
            next,
            self.migration_limit,
            self.model_name.clone(),
            self.metrics.clone(),
        )
        .await?;
85
86
87
88
89
        let response_stream = stream::unfold(retry_manager, move |mut retry_manager| async move {
            retry_manager
                .next()
                .await
                .map(|response| (response, retry_manager))
90
91
        })
        .fuse();
92
        Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx_))
93
94
95
96
    }
}

struct RetryManager {
97
    context: Arc<dyn AsyncEngineContext>,
98
    request: PreprocessedRequest,
99
100
    next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
    next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
101
    retries_left: u32,
102
103
    model_name: Arc<String>,
    metrics: Arc<Metrics>,
104
105
106
107
}

impl RetryManager {
    pub async fn build(
108
        context: Arc<dyn AsyncEngineContext>,
109
        preprocessed_request: PreprocessedRequest,
110
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
111
        retries_left: u32,
112
113
        model_name: Arc<String>,
        metrics: Arc<Metrics>,
114
115
    ) -> Result<Self> {
        let mut slf = Self {
116
            context,
117
118
119
120
            request: preprocessed_request,
            next_generate: next,
            next_stream: None,
            retries_left: retries_left + 1, // +1 to account for the initial attempt
121
122
            model_name,
            metrics,
123
124
125
126
127
        };
        slf.new_stream().await?;
        Ok(slf)
    }

128
    pub async fn next(&mut self) -> Option<Annotated<BackendOutput>> {
129
130
131
132
133
        loop {
            let response_stream = match self.next_stream.as_mut() {
                Some(stream) => stream,
                None => {
                    tracing::error!("next() called with next_stream is None - should not happen");
134
                    return Some(Annotated::from_err(DynamoError::msg("next_stream is None")));
135
136
137
                }
            };
            if let Some(response) = response_stream.next().await {
138
                // Check if this is a migratable error that should trigger stream recreation.
139
                if let Some(err) = response.err()
140
                    && is_migratable(&err)
141
                {
142
                    tracing::warn!("Stream disconnected... recreating stream... {}", err);
143
                    self.metrics.inc_migration_ongoing_request(&self.model_name);
144
145
146
147
                    if let Err(err) = self.new_stream().await {
                        tracing::warn!("Cannot recreate stream: {:#}", err);
                    } else {
                        continue;
148
149
150
151
152
153
154
155
156
157
                    }
                }
                self.track_response(&response);
                return Some(response);
            }
            return None;
        }
    }

    async fn new_stream(&mut self) -> Result<()> {
158
        let mut response_stream: Option<Result<ManyOut<Annotated<BackendOutput>>>> = None;
159
160
        while self.retries_left > 0 {
            self.retries_left -= 1;
161
162
            let request = Context::with_id(self.request.clone(), self.context.id().to_string());
            self.context.link_child(request.context());
163
164
165
166
167
168
169
            if self.context.is_stopped() || self.context.is_killed() {
                tracing::debug!("Abort creating new stream after context is stopped or killed");
                return Err(Error::msg(format!(
                    "Context id {} is stopped or killed",
                    self.context.id()
                )));
            }
170
            response_stream = Some(self.next_generate.generate(request).await);
171
            if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
172
                && is_migratable(err.as_ref())
173
            {
174
                tracing::warn!("Creating new stream... retrying... {}", err);
175
                self.metrics.inc_migration_new_request(&self.model_name);
176
                continue;
177
178
179
180
181
182
183
184
            }
            break;
        }
        match response_stream {
            Some(Ok(next_stream)) => {
                self.next_stream = Some(next_stream);
                Ok(())
            }
185
            Some(Err(err)) => Err(err), // should propagate original error if any
186
            None => Err(Error::msg(
187
                "Migration limit exhausted", // should propagate original error if any
188
189
190
191
            )),
        }
    }

192
    fn track_response(&mut self, response: &Annotated<BackendOutput>) {
193
194
195
196
197
198
199
        if self.retries_left == 0 {
            return;
        }
        let llm_engine_output = match response.data.as_ref() {
            Some(output) => output,
            None => return,
        };
200
201
202
203
        if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
            self.request.stop_conditions.max_tokens =
                Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
        }
204
205
206
207
208
209
210
211
212
        for token_id in llm_engine_output.token_ids.iter() {
            self.request.token_ids.push(*token_id);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
213
    use crate::http::service::metrics::Metrics;
Greg Clark's avatar
Greg Clark committed
214
    use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
215
    use dynamo_runtime::error::{DynamoError, ErrorType};
216
    use dynamo_runtime::pipeline::AsyncEngine;
217
    use dynamo_runtime::pipeline::context::Controller;
218
219
220
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::mpsc;

221
222
    /// Model label under which the tests read back migration metrics.
    const TEST_MODEL: &str = "test-model";

223
    // Helper to create a mock preprocessed request
224
    fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
225
226
227
228
        PreprocessedRequest::builder()
            .model("mock".to_string())
            .token_ids(vec![1, 2, 3])
            .stop_conditions(StopConditions {
229
230
                max_tokens: Some(max_tokens),
                ..Default::default()
231
232
233
234
235
236
237
            })
            .sampling_options(SamplingOptions::default())
            .output_options(OutputOptions::default())
            .eos_token_ids(vec![])
            .annotations(vec![])
            .build()
            .unwrap()
238
239
240
    }

    // Helper to create mock LLM engine output
241
242
    fn create_mock_output(token_id: u32) -> Annotated<BackendOutput> {
        Annotated::from_data(BackendOutput {
243
            token_ids: vec![token_id],
244
245
            tokens: vec![],
            text: Some(format!("token_{token_id}")),
246
247
            cum_log_probs: None,
            log_probs: None,
Greg Clark's avatar
Greg Clark committed
248
            top_logprobs: None,
249
            finish_reason: None,
250
            stop_reason: None,
251
            index: None,
252
            disaggregated_params: None,
253
            completion_usage: None,
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
        })
    }

    /// Scripted scenarios that the mock engine plays out across successive `generate` calls.
    #[derive(Debug, Clone)]
    enum MockBehavior {
        /// Every call streams all remaining responses.
        Success,
        /// The first call is rejected with a `CannotConnect` ("no responders") error;
        /// later calls stream the remaining responses.
        FailThenSuccess,
        /// The first call streams up to `fail_after` responses and then a `Disconnected`
        /// error; later calls stream the rest.
        MidStreamFail { fail_after: usize },
        /// Like `MidStreamFail`, except that every later call is rejected with `CannotConnect`.
        MidStreamFailAlways { fail_after: usize },
        /// Like `MidStreamFail`, except that every later stream yields only a `Disconnected`
        /// error.
        MidStreamFailAlwaysStreamError { fail_after: usize },
        /// Every call is rejected with `CannotConnect` ("no responders").
        AlwaysFail,
    }

    // Unified mock server streaming engine that can simulate different scenarios
    struct MockEngine {
        behavior: MockBehavior,
        num_responses: usize,
        token_offset: u32,
        call_count: Arc<AtomicU32>,
279
        context_id: String,
280
281
282
    }

    impl MockEngine {
283
284
285
286
287
288
        fn new(
            behavior: MockBehavior,
            num_responses: usize,
            token_offset: u32,
            context_id: String,
        ) -> Self {
289
290
291
292
293
            Self {
                behavior,
                num_responses,
                token_offset,
                call_count: Arc::new(AtomicU32::new(0)),
294
                context_id,
295
296
297
298
299
300
            }
        }
    }

    #[async_trait]
    impl
301
302
        AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, anyhow::Error>
        for MockEngine
303
304
305
306
    {
        async fn generate(
            &self,
            request: SingleIn<PreprocessedRequest>,
307
        ) -> Result<ManyOut<Annotated<BackendOutput>>> {
308
            let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
309
310
311
312
313
314
315
316
            let (preprocessed_request, context) = request.transfer(());

            // Assert that the context_id matches the expected one
            assert_eq!(
                context.id().to_string(),
                self.context_id,
                "Context ID mismatch"
            );
317
318
319
320
321
322
323
324

            // Calculate how many responses we've already generated based on request token_ids
            // Initial request has [1, 2, 3], so anything beyond that are generated responses
            let initial_tokens = 3; // [1, 2, 3]
            let responses_already_generated = preprocessed_request
                .token_ids
                .len()
                .saturating_sub(initial_tokens);
325
326
327
328
329
330
331
332
333
334
335
336

            // Assert that max_tokens reflects the expected remaining tokens
            let expected_max_tokens =
                self.num_responses
                    .saturating_sub(responses_already_generated) as u32;
            assert_eq!(
                preprocessed_request.stop_conditions.max_tokens,
                Some(expected_max_tokens),
                "max_tokens should be {} but got {:?}",
                expected_max_tokens,
                preprocessed_request.stop_conditions.max_tokens
            );
337
338
339
340
341
342
343
344
345
346

            match &self.behavior {
                MockBehavior::Success => {
                    // Always succeed with remaining responses
                    self.send_responses(responses_already_generated, self.num_responses)
                        .await
                }
                MockBehavior::FailThenSuccess => {
                    if call_num == 0 {
                        // First call - return "No responders available" error to trigger retry
347
348
349
350
351
352
                        return Err(anyhow::anyhow!(
                            DynamoError::builder()
                                .error_type(ErrorType::CannotConnect)
                                .message("no responders")
                                .build()
                        ));
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
                    } else {
                        // Subsequent calls - succeed with remaining responses
                        self.send_responses(responses_already_generated, self.num_responses)
                            .await
                    }
                }
                MockBehavior::MidStreamFail { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
376
377
378
379
380
381
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
                            let _ = tx.send(error_response).await;
                        });
                    } else {
                        // Second call - send remaining responses from where we left off
                        tokio::spawn(async move {
                            for i in responses_already_generated..num_responses {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                        });
                    }

                    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
397
                    let ctx = Arc::new(Controller::new(self.context_id.clone()));
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
                    Ok(dynamo_runtime::pipeline::ResponseStream::new(
                        Box::pin(stream),
                        ctx,
                    ))
                }
                MockBehavior::MidStreamFailAlways { fail_after } => {
                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        let (tx, rx) = mpsc::channel(1);
                        let token_offset = self.token_offset;
                        let fail_after = *fail_after;
                        let num_responses = self.num_responses;

                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
420
421
422
423
424
425
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
426
427
428
429
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
430
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
431
432
433
434
435
436
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - always fail with NoResponders error (same as AlwaysFail)
437
438
439
440
441
442
                        Err(anyhow::anyhow!(
                            DynamoError::builder()
                                .error_type(ErrorType::CannotConnect)
                                .message("no responders")
                                .build()
                        ))
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
                    }
                }
                MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
                    let (tx, rx) = mpsc::channel(1);
                    let token_offset = self.token_offset;
                    let fail_after = *fail_after;
                    let num_responses = self.num_responses;

                    if call_num == 0 {
                        // First call - send some responses then an error to simulate disconnection
                        tokio::spawn(async move {
                            // Send responses from current position to fail_after
                            for i in responses_already_generated..fail_after.min(num_responses) {
                                let response = create_mock_output(token_offset + 1 + i as u32);
                                if tx.send(response).await.is_err() {
                                    break;
                                }
                            }
                            // Send the specific error that triggers retry logic
462
463
464
465
466
467
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
468
469
470
471
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
472
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
473
474
475
476
477
478
479
480
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    } else {
                        // Subsequent calls - immediately send stream error (no successful responses)
                        tokio::spawn(async move {
                            // Send the stream error immediately
481
482
483
484
485
486
                            let error_response = Annotated::from_err(
                                DynamoError::builder()
                                    .error_type(ErrorType::Disconnected)
                                    .message("Stream ended before generation completed")
                                    .build(),
                            );
487
488
489
490
                            let _ = tx.send(error_response).await;
                        });

                        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
491
                        let ctx = Arc::new(Controller::new(self.context_id.clone()));
492
493
494
495
496
497
498
499
                        Ok(dynamo_runtime::pipeline::ResponseStream::new(
                            Box::pin(stream),
                            ctx,
                        ))
                    }
                }
                MockBehavior::AlwaysFail => {
                    // Always fail with NoResponders error (same as FailThenSuccess first call)
500
501
502
503
504
505
                    Err(anyhow::anyhow!(
                        DynamoError::builder()
                            .error_type(ErrorType::CannotConnect)
                            .message("no responders")
                            .build()
                    ))
506
507
508
509
510
511
512
513
514
515
                }
            }
        }
    }

    impl MockEngine {
        /// Stream mock tokens for positions `start..end` from a background task.
        ///
        /// The stream ends cleanly, with no trailing error, once the range is exhausted.
        async fn send_responses(
            &self,
            start: usize,
            end: usize,
        ) -> Result<ManyOut<Annotated<BackendOutput>>> {
            let offset = self.token_offset;
            let (sender, receiver) = mpsc::channel(1);

            tokio::spawn(async move {
                for position in start..end {
                    let item = create_mock_output(offset + 1 + position as u32);
                    if sender.send(item).await.is_err() {
                        break;
                    }
                }
            });

            let controller = Arc::new(Controller::new(self.context_id.clone()));
            Ok(dynamo_runtime::pipeline::ResponseStream::new(
                Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver)),
                controller,
            ))
        }
    }

    /// Test case 1: No migration needed
    /// Tests the normal case where the RetryManager successfully processes all responses
    /// from a single stream without any failures or need for retries/migration.
    /// Expected behavior: All 10 responses should be received successfully.
    #[tokio::test]
    async fn test_retry_manager_no_migration() {
544
        dynamo_runtime::logging::init();
545
        let context_id = uuid::Uuid::new_v4().to_string();
546
        let request = create_mock_request(10);
547
548
549
550
551
552
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::Success,
            10,
            100,
            context_id.clone(),
        ));
553
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
554
555
            mock_engine;

556
        let ctx = Arc::new(Controller::new(context_id.clone()));
557
558
559
560
561
562
563
564
565
566
567
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            0,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
568
569
570
571
572
573
574
575
576
577
578
579
580

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
581
582
583

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
584
585
586
587
588
589
590
591
592
593
    }

    /// Test case 2: New request migration
    /// Tests the scenario where a worker becomes unreachable for new requests initially,
    /// triggering the RetryManager to retry the request. The MockEngine with FailThenSuccess
    /// fails on the first call with a "No responders available" error, then succeeds
    /// on subsequent calls, simulating a worker becoming available after initial failure.
    /// Expected behavior: All 10 responses should be received successfully after retry.
    #[tokio::test]
    async fn test_retry_manager_new_request_migration() {
594
        dynamo_runtime::logging::init();
595
        let context_id = uuid::Uuid::new_v4().to_string();
596
        let request = create_mock_request(10);
597
598
599
600
601
602
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::FailThenSuccess,
            10,
            100,
            context_id.clone(),
        ));
603
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
604
605
            mock_engine;

606
        let ctx = Arc::new(Controller::new(context_id.clone()));
607
608
609
610
611
612
613
614
615
616
617
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
618
619
620
621
622
623
624
625
626
627
628
629
630

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        assert_eq!(responses.len(), 10);
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
631
632
633

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 1);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
634
635
636
637
638
639
640
641
642
643
    }

    /// Test case 3: Ongoing request migration
    /// Tests the scenario where a worker fails mid-stream during an ongoing request.
    /// This simulates a connection being lost after partial response delivery, requiring
    /// the RetryManager to detect the failure (via "Stream ended before generation completed" error),
    /// create a new stream, and continue from where it left off.
    /// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration() {
644
645
        dynamo_runtime::logging::init();

646
        let context_id = uuid::Uuid::new_v4().to_string();
647
        let request = create_mock_request(10);
648
649
650
651
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFail { fail_after: 5 },
            10,
            100,
652
            context_id.clone(),
653
        ));
654
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
655
656
            mock_engine;

657
        let ctx = Arc::new(Controller::new(context_id.clone()));
658
659
660
661
662
663
664
665
666
667
668
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await
        .expect("Failed to build RetryManager");
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684

        let mut responses = Vec::new();
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received all 10 responses (5 from first stream + 5 from second stream)
        assert_eq!(responses.len(), 10);

        // Check that we received responses from both streams
        for (i, response) in responses.iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
            }
        }
685
686
687

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
688
689
690
691
692
693
694
695
    }

    /// Test case 4: New request migration - indefinite failure
    /// Tests the scenario where a worker becomes unreachable for new requests indefinitely.
    /// The RetryManager should exhaust all retries and return the original error from the first attempt.
    /// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
    #[tokio::test]
    async fn test_retry_manager_new_request_migration_indefinite_failure() {
696
        dynamo_runtime::logging::init();
697
        let context_id = uuid::Uuid::new_v4().to_string();
698
        let request = create_mock_request(0);
699
700
701
702
703
704
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::AlwaysFail,
            0,
            100,
            context_id.clone(),
        ));
705
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
706
707
708
            mock_engine;

        // Should fail to build due to initial stream creation failure after exhausting all 3 retries
709
        let ctx = Arc::new(Controller::new(context_id.clone()));
710
711
712
713
714
715
716
717
718
719
        let metrics = Arc::new(Metrics::new());
        let retry_manager_result = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await;
720
721
722
723
724

        assert!(retry_manager_result.is_err());
        if let Err(error) = retry_manager_result {
            assert!(error.to_string().contains("no responders"));
        }
725
726
727

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 4); // 3 retries + 1 final failure
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
728
729
730
731
732
733
734
735
    }

    /// Test case 5: Ongoing request migration - indefinite failure
    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests.
    /// The RetryManager should exhaust all retries and return the original stream disconnection error.
    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
736
        dynamo_runtime::logging::init();
737
        let context_id = uuid::Uuid::new_v4().to_string();
738
        let request = create_mock_request(10);
739
740
741
742
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlways { fail_after: 3 },
            10,
            100,
743
            context_id.clone(),
744
        ));
745
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
746
747
            mock_engine;

748
        let ctx = Arc::new(Controller::new(context_id.clone()));
749
750
751
752
753
754
755
756
757
758
759
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        ) // 3 retries
        .await
        .expect("Failed to build RetryManager");
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778

        let mut responses = Vec::new();

        // Collect all responses (both successful and error responses)
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received 4 total responses: 3 successful + 1 error
        assert_eq!(responses.len(), 4);

        // First 3 responses should be successful with tokens 101, 102, 103
        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
            }
        }

779
        // 4th response should be a Disconnected error after retries are exhausted
780
        let error_response = &responses[3];
781
782
        let err = error_response.err().expect("expected error response");
        assert_eq!(err.error_type(), ErrorType::Disconnected);
783
784
785

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 3); // 2 retries + 1 final failure
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1); // initial ongoing failure retry
786
787
788
789
790
791
792
793
    }

    /// Test case 6: Ongoing request migration - indefinite failure with stream errors
    /// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests,
    /// and all retry attempts also fail with stream errors instead of NATS errors.
    /// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
    #[tokio::test]
    async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
794
        dynamo_runtime::logging::init();
795
        let context_id = uuid::Uuid::new_v4().to_string();
796
        let request = create_mock_request(10);
797
798
799
800
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
            10,
            100,
801
            context_id.clone(),
802
        ));
803
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
804
805
            mock_engine;

806
        let ctx = Arc::new(Controller::new(context_id.clone()));
807
808
809
810
811
812
813
814
815
816
817
        let metrics = Arc::new(Metrics::new());
        let mut retry_manager = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        ) // 3 retries
        .await
        .expect("Failed to build RetryManager");
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836

        let mut responses = Vec::new();

        // Collect all responses (both successful and error responses)
        while let Some(response) = retry_manager.next().await {
            responses.push(response);
        }

        // Should have received 4 total responses: 3 successful + 1 error
        assert_eq!(responses.len(), 4);

        // First 3 responses should be successful with tokens 101, 102, 103
        for (i, response) in responses[0..3].iter().enumerate() {
            assert!(response.err().is_none());
            if let Some(output) = &response.data {
                assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
            }
        }

837
        // 4th response should be a Disconnected error after retries are exhausted
838
        let error_response = &responses[3];
839
840
        let err = error_response.err().expect("expected error response");
        assert_eq!(err.error_type(), ErrorType::Disconnected);
841
842
843

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 4); // 3 retries + 1 final failure
844
    }
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860

    /// Test case 7: Request cancelled when creating new stream
    /// Tests the scenario where context.stop_generating() is called when creating a new stream.
    /// The RetryManager should detect that the context is stopped and abort creating new streams.
    /// Expected behavior: Should fail to build RetryManager with "Context is stopped or killed" error.
    #[tokio::test]
    async fn test_retry_manager_context_stopped_before_stream() {
        dynamo_runtime::logging::init();
        let context_id = uuid::Uuid::new_v4().to_string();
        let request = create_mock_request(10);
        let mock_engine = Arc::new(MockEngine::new(
            MockBehavior::Success,
            10,
            100,
            context_id.clone(),
        ));
861
        let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
862
863
864
865
866
867
868
869
            mock_engine;

        let ctx = Arc::new(Controller::new(context_id.clone()));

        // Stop the context before building RetryManager
        ctx.stop_generating();

        // Should fail to build due to stopped context
870
871
872
873
874
875
876
877
878
879
        let metrics = Arc::new(Metrics::new());
        let retry_manager_result = RetryManager::build(
            ctx,
            request,
            next_generate,
            3,
            Arc::new(TEST_MODEL.to_string()),
            metrics.clone(),
        )
        .await;
880
881
882
883
884
885
886
887
888

        assert!(retry_manager_result.is_err());
        if let Err(error) = retry_manager_result {
            assert!(
                error
                    .to_string()
                    .contains(&format!("Context id {} is stopped or killed", context_id))
            );
        }
889
890
891

        assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
        assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 0);
892
    }
893
}