common.rs 11.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::pin::Pin;

6
use crate::{
7
    backend::{Backend, ExecutionContext},
8
    discovery::{KvWorkerMonitor, ModelManager, ModelWatcher},
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, RouterConfig},
11
    http::service::metrics::Metrics,
12
    kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
13
    migration::Migration,
14
    model_card::ModelDeploymentCard,
15
    namespace::NamespaceFilter,
16
    preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
17
    protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
18
    request_template::RequestTemplate,
19
    types::{
20
        Annotated,
21
22
23
24
25
26
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
    },
};
27

28
use anyhow::Context as _;
29
use dynamo_runtime::{
30
    DistributedRuntime,
31
    component::Client,
32
    engine::{AsyncEngineStream, Data},
33
34
35
36
    pipeline::{
        Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend,
        ServiceEngine, ServiceFrontend, SingleIn, Source,
    },
37
38
39
};
use std::sync::Arc;

40
41
42
43
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    pub inspect_template: bool,
44
45
46
47
48
49
50
51
52
53
54
55
    pub card: Option<ModelDeploymentCard>,
    pub request_template: Option<RequestTemplate>,
}

impl PreparedEngine {
    pub fn has_tokenizer(&self) -> bool {
        if let Some(card) = self.card.as_ref() {
            card.has_tokenizer()
        } else {
            false
        }
    }
56
57
}

58
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
59
pub async fn prepare_engine(
60
    distributed_runtime: DistributedRuntime,
61
    engine_config: EngineConfig,
62
) -> anyhow::Result<PreparedEngine> {
63
    match engine_config {
64
65
66
        EngineConfig::Dynamic {
            model: local_model, ..
        } => {
67
            let model_manager = Arc::new(ModelManager::new());
68
69
            // Create metrics for migration tracking (not exposed via /metrics in Dynamic engine mode)
            let metrics = Arc::new(Metrics::new());
70
            let watch_obj = Arc::new(ModelWatcher::new(
71
                distributed_runtime.clone(),
72
                model_manager.clone(),
73
                RouterConfig::default(),
74
                local_model.migration_limit(),
75
                None,
76
                metrics,
77
            ));
78
79
80
81
82
83
84
            let discovery = distributed_runtime.discovery();
            let discovery_stream = discovery
                .list_and_watch(
                    dynamo_runtime::discovery::DiscoveryQuery::AllModels,
                    Some(distributed_runtime.primary_token().clone()),
                )
                .await?;
85
            let inner_watch_obj = watch_obj.clone();
86
87
88
89
            let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
                local_model.namespace(),
                local_model.namespace_prefix(),
            );
90
            let _watcher_task = tokio::spawn(async move {
91
92
93
                inner_watch_obj
                    .watch(discovery_stream, namespace_filter)
                    .await;
94
            });
95
            tracing::info!("Waiting for remote model..");
96

97
98
99
100
101
            // TODO: We use the first model to appear, usually we have only one
            // We should add slash commands to text input `/model <name>` to choose,
            // '/models` to list, and notifications when models are added / removed.

            let model_service_name = watch_obj.wait_for_chat_model().await;
102
            tracing::info!("Connected to {model_service_name}");
103
            let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
104
            Ok(PreparedEngine {
105
                service_name: model_service_name,
106
107
                engine,
                inspect_template: false,
108
109
                card: None,
                request_template: local_model.request_template(),
110
            })
111
        }
112
        EngineConfig::InProcessText { engine, model, .. } => {
113
            let service_name = model.service_name().to_string();
114
            tracing::debug!("Model: {service_name} with engine pre-processing");
115
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
116
117
118
119
            Ok(PreparedEngine {
                service_name,
                engine,
                inspect_template: false,
120
121
                request_template: model.request_template(),
                card: Some(model.into_card()),
122
            })
123
        }
124
        EngineConfig::InProcessTokens {
125
            engine: inner_engine,
126
            model,
127
            ..
128
        } => {
129
130
131
            let pipeline = build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
132
            >(model.card(), inner_engine, model.card().tokenizer_hf()?)
133
            .await?;
134

135
            let service_name = model.service_name().to_string();
136
137
138
139
140
            tracing::debug!("Model: {service_name} with Dynamo pre-processing");
            Ok(PreparedEngine {
                service_name,
                engine: pipeline,
                inspect_template: true,
141
142
                request_template: model.request_template(),
                card: Some(model.into_card()),
143
            })
144
145
146
        }
    }
}
147
148
149
150

pub async fn build_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    engine: ExecutionContext,
151
    hf_tokenizer: tokenizers::Tokenizer,
152
153
154
155
156
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
157
158
159
160
161
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
162
163
{
    let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
164
165
166
167
168
    let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
    let preprocessor =
        OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?
            .into_operator();
    let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
169
170
171
172
173
174
175
176
177
178
179
    let engine = ServiceBackend::from_engine(engine);

    Ok(frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(engine)?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?)
}

180
#[allow(clippy::too_many_arguments)]
181
182
183
pub async fn build_routed_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    client: &Client,
184
    model_manager: Arc<crate::discovery::ModelManager>,
185
    router_mode: RouterMode,
186
    worker_monitor: Option<KvWorkerMonitor>,
187
    chooser: Option<Arc<KvRouter>>,
188
    hf_tokenizer: tokenizers::Tokenizer,
189
    prefill_chooser: Option<Arc<PrefillRouter>>,
190
    enforce_disagg: bool,
191
    migration_limit: u32,
192
    metrics: Arc<Metrics>,
193
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
194
195
196
197
198
199
200
201
202
203
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
{
204
205
    let PromptFormatter::OAI(formatter) =
        PromptFormatter::from_mdc(card).context("PromptFormatter.from_mdc")?;
206
    let preprocessor =
207
208
        OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())
            .context("OpenAIPreprocessor.new_with_parts")?;
209
210
211
    build_routed_pipeline_with_preprocessor(
        card,
        client,
212
        model_manager,
213
        router_mode,
214
        worker_monitor,
215
216
        chooser,
        preprocessor,
217
        hf_tokenizer,
218
        prefill_chooser,
219
        enforce_disagg,
220
        migration_limit,
221
        metrics,
222
223
224
225
    )
    .await
}

226
#[allow(clippy::too_many_arguments)]
227
228
229
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
    card: &ModelDeploymentCard,
    client: &Client,
230
    model_manager: Arc<crate::discovery::ModelManager>,
231
    router_mode: RouterMode,
232
    worker_monitor: Option<KvWorkerMonitor>,
233
234
    chooser: Option<Arc<KvRouter>>,
    preprocessor: Arc<OpenAIPreprocessor>,
235
    hf_tokenizer: tokenizers::Tokenizer,
236
    prefill_chooser: Option<Arc<PrefillRouter>>,
237
    enforce_disagg: bool,
238
    migration_limit: u32,
239
    metrics: Arc<Metrics>,
240
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
241
242
243
244
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
245
246
247
248
249
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
250
251
{
    let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
252
    let preprocessor_op = preprocessor.into_operator();
253
    let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
254
    let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
255

256
257
258
259
260
261
262
263
264
265
    // For KV routing, use the client from the chooser to ensure shared state
    let router_client = if router_mode == RouterMode::KV {
        let Some(ref chooser) = chooser else {
            anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
        };
        chooser.client().clone()
    } else {
        client.clone()
    };

266
    // Get threshold value and wrap monitor for PushRouter
267
268
269
270
    // Note: PushRouter uses active_decode_blocks_threshold for its internal logic
    let threshold_value = worker_monitor
        .as_ref()
        .map(|m| m.active_decode_blocks_threshold());
271
272
    let monitor_arc =
        worker_monitor.map(|m| Arc::new(m) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>);
273

274
275
    let router =
        PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
276
            router_client,
277
            router_mode,
278
279
            threshold_value,
            monitor_arc,
280
281
        )
        .await?;
282

283
    let service_backend = match router_mode {
284
285
286
287
        RouterMode::Direct => {
            ServiceBackend::from_engine(Arc::new(DirectRoutingRouter::new(router)))
        }
        RouterMode::Random | RouterMode::RoundRobin => {
288
289
290
291
292
293
294
295
296
297
298
            ServiceBackend::from_engine(Arc::new(router))
        }
        RouterMode::KV => {
            let Some(chooser) = chooser else {
                anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
            };
            let kv_push_router = KvPushRouter::new(router, chooser);
            ServiceBackend::from_engine(Arc::new(kv_push_router))
        }
    };

299
    // Use the provided prefill chooser, or create a disabled one if not provided
300
301
    let prefill_chooser = prefill_chooser
        .unwrap_or_else(|| PrefillRouter::disabled(model_manager, router_mode, enforce_disagg));
302
303
304
    let prefill_op = prefill_chooser.into_operator();

    // Link with prefill chooser including backward edge for response flow
305
    let engine = frontend
306
        .link(preprocessor_op.forward_edge())?
307
        .link(migration.forward_edge())?
308
        .link(backend.forward_edge())?
309
        .link(prefill_op.forward_edge())?
310
        .link(service_backend)?
311
        .link(prefill_op.backward_edge())?
312
        .link(backend.backward_edge())?
313
        .link(migration.backward_edge())?
314
        .link(preprocessor_op.backward_edge())?
315
        .link(frontend)?;
316

317
318
    Ok(engine)
}