prefill_router.rs 24.4 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
// SPDX-License-Identifier: Apache-2.0

use std::sync::{Arc, OnceLock};

6
use anyhow::Result;
7
use futures::StreamExt;
8
use rand::Rng;
9
use tokio::sync::{OwnedSemaphorePermit, oneshot};
10
use tokio_util::sync::CancellationToken;
11
use tracing::Instrument;
12
13
14
15

use dynamo_runtime::{
    component::Endpoint,
    pipeline::{
16
17
        AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter,
        RouterMode, ServerStreamingEngine, SingleIn, async_trait,
18
    },
19
    protocols::{EndpointId, annotated::Annotated, maybe_error::MaybeError},
20
21
22
23
};

use crate::{
    discovery::ModelManager,
24
    kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
25
    protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
26
    protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
27
    protocols::common::timing::{RequestPhase, RequestTracker},
28
29
};

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/// Errors that can occur during prefill routing
#[derive(Debug, thiserror::Error)]
pub enum PrefillError {
    /// Prefill router has not been activated yet
    #[error("Prefill router not yet activated")]
    NotActivated,

    /// Error during prefill execution
    /// TODO: Separate prefill worker error from prefill router error
    #[error("Prefill execution failed: {0}")]
    PrefillError(String),

    /// Disaggregated params not found in prefill response
    #[error("No disaggregated params in prefill response: {0}")]
    NoDisaggregatedParams(String),
}

47
/// The inner router used by PrefillRouter
48
#[derive(Clone)]
49
50
51
52
53
54
55
enum InnerPrefillRouter {
    /// KV-aware routing using KvPushRouter
    KvRouter(Arc<KvPushRouter>),
    /// Simple routing (RoundRobin, Random, Direct)
    SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}

56
impl InnerPrefillRouter {
57
58
59
60
    /// Generate with optional direct routing to specific worker.
    /// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
    /// For SimpleRouter, target_worker triggers direct routing via router.direct().
    async fn generate_to_worker(
61
62
        &self,
        request: SingleIn<PreprocessedRequest>,
63
        target_worker: Option<u64>,
64
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
65
66
67
68
69
70
71
72
73
74
75
76
        match (self, target_worker) {
            // KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
            (InnerPrefillRouter::KvRouter(router), _) => router.generate(request).await,
            (InnerPrefillRouter::SimpleRouter(router), Some(worker_id)) => {
                router.direct(request, worker_id).await
            }
            (InnerPrefillRouter::SimpleRouter(router), None) => router.generate(request).await,
        }
    }

    /// Select next worker (for non-KV modes only)
    fn select_next_worker(&self) -> Option<u64> {
77
        match self {
78
79
            InnerPrefillRouter::SimpleRouter(router) => router.select_next_worker(),
            InnerPrefillRouter::KvRouter(_) => None,
80
81
        }
    }
82
83
84
85
86
87
88
89

    /// Peek next worker without incrementing state (for non-KV modes only)
    fn peek_next_worker(&self) -> Option<u64> {
        match self {
            InnerPrefillRouter::SimpleRouter(router) => router.peek_next_worker(),
            InnerPrefillRouter::KvRouter(_) => None,
        }
    }
90
91
}

92
93
94
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
95
///
96
97
98
99
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
100
101
pub struct PrefillRouter {
    prefill_router: OnceLock<InnerPrefillRouter>,
102
103
    model_manager: Arc<ModelManager>,
    endpoint_id: OnceLock<EndpointId>,
104
105
    cancel_token: CancellationToken,
    router_mode: RouterMode,
106
    enforce_disagg: bool,
107
108
109
110
}

impl PrefillRouter {
    /// Create a disabled prefill router that will never activate (passthrough only)
111
112
113
114
115
    pub fn disabled(
        model_manager: Arc<ModelManager>,
        router_mode: RouterMode,
        enforce_disagg: bool,
    ) -> Arc<Self> {
116
117
        Arc::new(Self {
            prefill_router: OnceLock::new(),
118
119
            model_manager,
            endpoint_id: OnceLock::new(),
120
121
            cancel_token: CancellationToken::new(),
            router_mode,
122
            enforce_disagg,
123
124
125
126
127
128
129
130
131
        })
    }

    pub fn new(
        activation_rx: oneshot::Receiver<Endpoint>,
        model_manager: Arc<ModelManager>,
        router_mode: RouterMode,
        kv_cache_block_size: u32,
        kv_router_config: Option<KvRouterConfig>,
132
        enforce_disagg: bool,
133
134
135
136
137
138
    ) -> Arc<Self> {
        let prefill_router = OnceLock::new();
        let cancel_token = CancellationToken::new();

        let router = Arc::new(Self {
            prefill_router,
139
140
            model_manager: model_manager.clone(),
            endpoint_id: OnceLock::new(),
141
142
            cancel_token: cancel_token.clone(),
            router_mode,
143
            enforce_disagg,
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
        });

        // Spawn background task to wait for activation
        let router_clone = router.clone();
        tokio::spawn(async move {
            tokio::select! {
                result = activation_rx => {
                    let Ok(endpoint) = result else {
                        tracing::debug!("Prefill router activation channel closed without receiving endpoint");
                        return;
                    };

                    if let Err(e) = router_clone.activate(
                        endpoint,
                        model_manager,
                        kv_cache_block_size,
                        kv_router_config,
                    ).await {
                        tracing::error!(error = %e, "Failed to activate prefill router");
                    }
                }
                _ = cancel_token.cancelled() => {
                    tracing::debug!("Prefill router activation cancelled");
                }
            }
        });

        router
    }

    /// Activate the prefill router with the provided endpoint
    async fn activate(
        &self,
        endpoint: Endpoint,
        model_manager: Arc<ModelManager>,
        kv_cache_block_size: u32,
        kv_router_config: Option<KvRouterConfig>,
    ) -> Result<()> {
        tracing::info!(
            router_mode = ?self.router_mode,
            "Activating prefill router"
        );

187
188
189
190
191
192
193
194
195
        // Store endpoint_id for later use in build_bootstrap_info
        let _ = self.endpoint_id.set(endpoint.id());

        // Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
        // This must be done before creating the router so bootstrap info is available
        model_manager
            .get_or_create_runtime_config_watcher(&endpoint)
            .await?;

196
        let inner_router = if self.router_mode.is_kv_routing() {
197
            // Create KV chooser using the endpoint
198
            let kv_chooser = model_manager
199
                .kv_chooser_for(&endpoint, kv_cache_block_size, kv_router_config)
200
201
                .await?;

202
203
204
205
            // Extract client from kv_chooser to ensure shared state
            let client = kv_chooser.client().clone();

            // Build the PushRouter for prefill with KV mode using the shared client
206
207
208
209
210
211
212
213
214
215
216
            let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
                client,
                RouterMode::KV,
                None, // busy_threshold
                None, // worker_monitor
            )
            .await?;

            // Wrap it in KvPushRouter
            InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
        } else {
217
218
219
            // Create client for simple router
            let client = endpoint.client().await?;

220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
            // Create simple push router with the frontend's router mode
            let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
                client,
                self.router_mode,
                None, // busy_threshold
                None, // worker_monitor
            )
            .await?;

            InnerPrefillRouter::SimpleRouter(Arc::new(push_router))
        };

        // Set the router (ignore error if already set)
        let _ = self.prefill_router.set(inner_router);

        tracing::info!(
            router_mode = ?self.router_mode,
            "Prefill router activated successfully"
        );

        Ok(())
    }

243
244
    /// Build bootstrap_info for disaggregated serving
    /// If preselected_worker is provided (GAIE Stage 2), use it directly.
245
    /// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
246
    async fn build_bootstrap_info(
247
        &self,
248
        req: &PreprocessedRequest,
249
        preselected_worker: Option<u64>,
250
    ) -> Option<(u64, u32, BootstrapInfo)> {
251
        let endpoint_id = self.endpoint_id.get()?;
252
253
        let prefill_router = self.prefill_router.get()?;

254
        // Worker selection
255
        let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
256
            let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
257
258
259
260
261
262
            tracing::debug!(
                worker_id = id,
                dp_rank = dp_rank,
                "Using pre-selected prefill worker for bootstrap"
            );
            (id, dp_rank)
263
264
265
266
267
268
        } else if self.router_mode.is_kv_routing() {
            // KV mode: use find_best_match
            let kv_router = match prefill_router {
                InnerPrefillRouter::KvRouter(r) => r,
                _ => return None,
            };
269
270
            // Extract LORA name from routing hints
            let lora_name = req.routing.as_ref().and_then(|r| r.lora_name.clone());
271
272
273
            match async {
                kv_router
                    .chooser
274
                    .find_best_match(None, &req.token_ids, None, false, lora_name)
275
276
277
278
                    .await
            }
            .instrument(tracing::info_span!("kv_find_best_match"))
            .await
279
280
281
282
            {
                Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
                Err(_) => return None,
            }
283
284
        } else {
            // Non-KV mode: use PushRouter's stateful selection
285
286
287
            // We use peek_next_worker instead of select_next_worker to avoid double-incrementing the counter
            // if we fall back to the original path.
            let worker_id = prefill_router.peek_next_worker()?;
288
            (worker_id, 0)
289
290
        };

291
292
293
294
        // Get bootstrap info from ModelManager (works for ANY mode)
        let endpoint = self
            .model_manager
            .get_disaggregated_endpoint(endpoint_id, worker_id)?;
295
296
297
        let host = endpoint.bootstrap_host?;
        let port = endpoint.bootstrap_port?;

298
        let bootstrap_room: u64 = rand::rng().random();
299
300
301
302
303
304
305

        tracing::info!(
            worker_id = worker_id,
            dp_rank = dp_rank,
            bootstrap_host = %host,
            bootstrap_port = port,
            bootstrap_room = bootstrap_room,
306
            router_mode = ?self.router_mode,
307
308
309
310
311
312
313
314
315
316
317
318
319
320
            "Built bootstrap_info upfront before prefill"
        );

        Some((
            worker_id,
            dp_rank,
            BootstrapInfo {
                bootstrap_host: host,
                bootstrap_port: port,
                bootstrap_room,
            },
        ))
    }

321
322
323
324
325
326
327
    /// Execute prefill with the given router and extract structured result.
    ///
    /// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
    ///
    /// If `phase_permit` is provided, it is dropped after the first output is received,
    /// allowing subsequent `set_phase` calls to proceed. This is used in the bootstrap
    /// optimization path to ensure `record_worker` completes before the phase changes.
328
329
330
    async fn execute_prefill(
        router: Option<InnerPrefillRouter>,
        request: SingleIn<PreprocessedRequest>,
331
        target_worker: Option<u64>,
332
        phase_permit: Option<OwnedSemaphorePermit>,
333
334
335
    ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
        let router = router.ok_or(PrefillError::NotActivated)?;
        let mut prefill_response = router
336
            .generate_to_worker(request, target_worker)
337
338
339
            .await
            .map_err(|e| PrefillError::PrefillError(e.to_string()))?;

340
341
342
343
        // Drop phase permit now - routing is complete, record_worker was called in select_worker.
        // This unblocks set_phase(Decode) in the main task without waiting for prefill output.
        drop(phase_permit);

344
        let Some(first_output) = prefill_response.next().await else {
345
346
347
            return Err(PrefillError::PrefillError(
                "Prefill router returned no output (stream ended)".to_string(),
            ));
348
349
        };

350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
        let mut prompt_tokens_details = first_output
            .data
            .as_ref()
            .and_then(|o| o.completion_usage.as_ref())
            .and_then(|u| u.prompt_tokens_details.clone());

        while let Some(next) = prefill_response.next().await {
            if let Some(o) = next.data.as_ref()
                && prompt_tokens_details.is_none()
            {
                prompt_tokens_details = o
                    .completion_usage
                    .as_ref()
                    .and_then(|u| u.prompt_tokens_details.clone());
            }
        }
366

367
        if let Some(err) = first_output.err() {
368
369
370
            return Err(PrefillError::PrefillError(format!(
                "Prefill router returned error in output: {err:?}"
            )));
371
372
373
        }

        let Some(output) = &first_output.data else {
374
375
376
            return Err(PrefillError::NoDisaggregatedParams(
                "Prefill router output has no data field".to_string(),
            ));
377
378
379
        };

        let Some(disaggregated_params) = output.disaggregated_params.clone() else {
380
381
382
            return Err(PrefillError::NoDisaggregatedParams(
                "Prefill router output missing disaggregated_params".to_string(),
            ));
383
384
        };

385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
        // Extract prefill worker ID from disaggregated_params
        let prefill_worker_id = disaggregated_params
            .get("worker_id")
            .and_then(|worker_id_json| {
                worker_id_json
                    .get("prefill_worker_id")
                    .and_then(|v| v.as_u64())
            });
        Ok((
            PrefillResult {
                disaggregated_params,
                prompt_tokens_details,
            },
            prefill_worker_id,
        ))
    }

402
403
404
405
406
407
    /// Spawn prefill as a background task.
    ///
    /// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
    ///
    /// The `phase_permit` is passed to the spawned task and dropped after the first output,
    /// allowing the main task's `set_phase(Decode)` to proceed.
408
409
410
411
    fn spawn_prefill_task(
        &self,
        prefill_request: SingleIn<PreprocessedRequest>,
        target_worker: Option<u64>,
412
        phase_permit: OwnedSemaphorePermit,
413
    ) {
414
        let router = self.prefill_router.get().cloned();
415
416
417
418
419
420
421
422
423
424
425
        // Capture current span to propagate trace context to the spawned task
        let span = tracing::Span::current();

        tokio::spawn(
            async move {
                match Self::execute_prefill(
                    router,
                    prefill_request,
                    target_worker,
                    Some(phase_permit),
                )
426
                .await
427
428
429
430
431
432
433
                {
                    Ok(_) => {
                        tracing::debug!("Prefill background task completed");
                    }
                    Err(e) => {
                        tracing::warn!("Prefill background task error: {e:?}");
                    }
434
435
                }
            }
436
437
            .instrument(span),
        );
438
439
    }

440
441
442
443
    /// Call the prefill router and extract structured prefill result and worker ID.
    ///
    /// This is the synchronous prefill path - we wait for prefill to complete before proceeding.
    /// No phase permit is needed since `record_worker` completes before we return.
444
445
446
447
    async fn call_prefill(
        &self,
        request: SingleIn<PreprocessedRequest>,
    ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
448
        // For call_prefill path, routing is handled by the router itself (no direct routing needed)
449
450
        // No phase permit needed - we wait for completion before changing phase
        Self::execute_prefill(self.prefill_router.get().cloned(), request, None, None).await
451
452
453
    }
}

454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
impl Drop for PrefillRouter {
    /// Cancel `cancel_token` so any activation task still waiting on it can exit.
    fn drop(&mut self) {
        let token = &self.cancel_token;
        tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
        token.cancel();
    }
}

#[async_trait]
impl
    Operator<
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
        SingleIn<PreprocessedRequest>,
        ManyOut<Annotated<LLMEngineOutput>>,
    > for PrefillRouter
{
    async fn generate(
        &self,
        request: SingleIn<PreprocessedRequest>,
        next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
        // Extract request data while preserving context
476
        let (mut req, context) = request.into_parts();
477
        let request_id = context.id().to_string();
478
        let engine_ctx = context.context();
479

Yan Ru Pei's avatar
Yan Ru Pei committed
480
481
482
        // Save original max_tokens for decode
        let original_max_tokens = req.stop_conditions.max_tokens;

483
484
        // If prefill router is not activated, skip directly to decode
        if self.prefill_router.get().is_none() {
485
486
487
488
489
490
491
492
493
494
495
496
            if self.enforce_disagg {
                return Err(anyhow::anyhow!(PrefillError::NotActivated));
            }
            return next.generate(context.map(|_| req)).await;
        }

        // Ensure tracker exists for routing decisions in disaggregated mode.
        // Create one if not provided by the upstream DeltaGenerator.
        if req.tracker.is_none() {
            req.tracker = Some(Arc::new(RequestTracker::new()));
        }
        let tracker = req.tracker.as_ref().unwrap();
497
        let prefill_phase_permit = tracker.set_phase(RequestPhase::Prefill).await;
498
499
500
        tracker.record_prefill_start();

        // Prepare prefill request with max_tokens = 1 (clone after tracker is set)
Yan Ru Pei's avatar
Yan Ru Pei committed
501
502
        let mut prefill_req = req.clone();
        prefill_req.stop_conditions.max_tokens = Some(1);
503

504
505
        // Try build_bootstrap_info optimization: if we can get bootstrap info upfront,
        // spawn prefill in background and proceed to decode immediately.
506
507
508
509
        let preselected_worker = prefill_req
            .routing
            .as_ref()
            .and_then(|r| r.prefill_worker_id);
510

511
512
513
514
        let prefill_result = async {
            if let Some((worker_id, dp_rank, bootstrap_info)) = self
                .build_bootstrap_info(&prefill_req, preselected_worker)
                .await
515
            {
516
517
518
519
520
521
522
523
                // Bootstrap optimization path: spawn prefill in background
                // We successfully used the peeked worker, so we must now advance the router state
                // to ensure the next request gets a different worker.
                if !self.router_mode.is_kv_routing()
                    && let Some(router) = self.prefill_router.get()
                {
                    router.select_next_worker();
                }
524

525
526
527
528
                let routing = prefill_req.routing_mut();
                routing.prefill_worker_id = Some(worker_id);
                routing.dp_rank = Some(dp_rank);
                prefill_req.bootstrap_info = Some(bootstrap_info.clone());
529

530
531
                let prefill_context = Context::with_id(prefill_req, request_id.clone());
                engine_ctx.link_child(prefill_context.context());
532

533
534
535
                // Pass phase permit to spawned task - it drops after first output (record_worker complete)
                // This allows set_phase(Decode) below to proceed only after prefill routing is done
                self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
536

537
538
539
540
                Ok((None, Some(worker_id), Some(bootstrap_info)))
            } else {
                // Original prefill path: wait for prefill to complete
                tracing::debug!("Using original prefill path");
541

542
543
544
                // Drop the phase permit before calling call_prefill - we wait for completion
                // so there's no race with set_phase(Decode) below
                drop(prefill_phase_permit);
545

546
547
                let prefill_context = Context::with_id(prefill_req, request_id.clone());
                engine_ctx.link_child(prefill_context.context());
548

549
550
551
552
553
554
555
                self.call_prefill(prefill_context)
                    .await
                    .map(|(result, worker_id)| (Some(result), worker_id, None))
            }
        }
        .instrument(tracing::info_span!("prefill_routing"))
        .await;
556
557
558
559
560
561
562
563
564
565
566
567

        // Abort if cancelled during prefill
        if engine_ctx.is_stopped() || engine_ctx.is_killed() {
            tracing::debug!("Abort entering decode after context is stopped or killed");
            return Err(anyhow::anyhow!(
                "Context id {} is stopped or killed",
                engine_ctx.id()
            ));
        }

        // Handle prefill result
        match prefill_result {
568
569
            Ok((maybe_prefill_result, _prefill_worker_id, bootstrap_info)) => {
                tracing::debug!("Prefill completed, proceeding to decode");
570

571
572
573
                // Set phase to Decode for the decode request.
                // In bootstrap path, this blocks until the spawned prefill task drops its permit
                // (after first output / record_worker completes), ensuring correct phase for routing.
574
                if let Some(ref tracker) = req.tracker {
575
576
                    let _decode_permit = tracker.set_phase(RequestPhase::Decode).await;
                    // Permit is dropped immediately - decode proceeds, no need to hold it
577
578
                }

579
                let mut decode_req = req;
580

581
                // Update request with prefill result
582
                if let Some(prefill_result) = maybe_prefill_result {
583
584
585
                    decode_req.prefill_result = Some(prefill_result);
                }

Yan Ru Pei's avatar
Yan Ru Pei committed
586
587
                // Restore original max_tokens for decode
                decode_req.stop_conditions.max_tokens = original_max_tokens;
588

589
590
591
592
593
                // Inject bootstrap_info for decode worker
                if let Some(info) = bootstrap_info {
                    decode_req.bootstrap_info = Some(info);
                }

594
595
596
597
598
599
600
601
                // Set router_config_override for decode: overlap_score_weight = 0
                let existing_override = decode_req.router_config_override.take();
                decode_req.router_config_override = Some(RouterConfigOverride {
                    overlap_score_weight: Some(0.0),
                    ..existing_override.unwrap_or_default()
                });

                // Map the modified request through with preserved context
602
                let decode_request = context.map(|_| decode_req);
603
604
                next.generate(decode_request).await
            }
605
606
607
608
609
610
611
612
613
614
            Err(PrefillError::NotActivated) => {
                if self.enforce_disagg {
                    tracing::error!(
                        "Prefill router not activated, but disaggregated mode is enforced. Failing request."
                    );
                    return Err(anyhow::anyhow!(PrefillError::NotActivated));
                }
                tracing::debug!("Prefill router not activated, falling back to decode-only");
                next.generate(context.map(|_| req)).await
            }
615
            Err(e) => {
616
617
618
619
620
621
622
                if self.enforce_disagg {
                    tracing::error!(
                        error = %e,
                        "Remote prefill failed, but disaggregated mode is enforced. Failing request."
                    );
                    return Err(anyhow::anyhow!(e));
                }
623
624
625
626
                tracing::warn!(
                    error = %e,
                    "Remote prefill failed, falling back to decode-only. This may impact performance in disaggregated deployments. Verify prefill workers are healthy and accessible."
                );
627
628
629
630
631
                next.generate(context.map(|_| req)).await
            }
        }
    }
}