common.rs 10.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::pin::Pin;

6
use crate::{
7
    backend::{Backend, ExecutionContext},
8
    discovery::{ModelManager, ModelWatcher},
9
    engines::StreamingEngineAdapter,
10
    entrypoint::{EngineConfig, RouterConfig},
11
    kv_router::{KvPushRouter, KvRouter, PrefillRouter},
12
    migration::Migration,
13
    model_card::ModelDeploymentCard,
14
    preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
15
    protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
16
    request_template::RequestTemplate,
17
    types::{
18
        Annotated,
19
20
21
22
23
24
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
    },
};
25

26
use dynamo_runtime::{
27
    DistributedRuntime,
28
    component::Client,
29
    engine::{AsyncEngineStream, Data},
30
31
32
33
    pipeline::{
        Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend,
        ServiceEngine, ServiceFrontend, SingleIn, Source,
    },
34
35
36
};
use std::sync::Arc;

37
38
39
40
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    pub inspect_template: bool,
41
42
43
44
45
46
47
48
49
50
51
52
    pub card: Option<ModelDeploymentCard>,
    pub request_template: Option<RequestTemplate>,
}

impl PreparedEngine {
    pub fn has_tokenizer(&self) -> bool {
        if let Some(card) = self.card.as_ref() {
            card.has_tokenizer()
        } else {
            false
        }
    }
53
54
}

55
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
56
pub async fn prepare_engine(
57
    distributed_runtime: DistributedRuntime,
58
    engine_config: EngineConfig,
59
) -> anyhow::Result<PreparedEngine> {
60
    match engine_config {
61
        EngineConfig::Dynamic(local_model) => {
62
            let model_manager = Arc::new(ModelManager::new());
63
            let watch_obj = Arc::new(ModelWatcher::new(
64
                distributed_runtime.clone(),
65
                model_manager.clone(),
66
                RouterConfig::default(),
67
            ));
68
69
70
71
72
73
74
            let discovery = distributed_runtime.discovery();
            let discovery_stream = discovery
                .list_and_watch(
                    dynamo_runtime::discovery::DiscoveryQuery::AllModels,
                    Some(distributed_runtime.primary_token().clone()),
                )
                .await?;
75
            let inner_watch_obj = watch_obj.clone();
76
            let _watcher_task = tokio::spawn(async move {
77
                inner_watch_obj.watch(discovery_stream, None).await;
78
            });
79
            tracing::info!("Waiting for remote model..");
80

81
82
83
84
85
            // TODO: We use the first model to appear, usually we have only one
            // We should add slash commands to text input `/model <name>` to choose,
            // '/models` to list, and notifications when models are added / removed.

            let model_service_name = watch_obj.wait_for_chat_model().await;
86
            tracing::info!("Connected to {model_service_name}");
87
            let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
88
            Ok(PreparedEngine {
89
                service_name: model_service_name,
90
91
                engine,
                inspect_template: false,
92
93
                card: None,
                request_template: local_model.request_template(),
94
            })
95
        }
96
        EngineConfig::StaticFull { engine, model, .. } => {
97
            let service_name = model.service_name().to_string();
98
            tracing::debug!("Model: {service_name} with engine pre-processing");
99
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
100
101
102
103
            Ok(PreparedEngine {
                service_name,
                engine,
                inspect_template: false,
104
105
                request_template: model.request_template(),
                card: Some(model.into_card()),
106
            })
107
108
109
        }
        EngineConfig::StaticCore {
            engine: inner_engine,
110
            model,
111
            ..
112
        } => {
113
114
115
            let pipeline = build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
116
            >(model.card(), inner_engine, model.card().tokenizer_hf()?)
117
            .await?;
118

119
            let service_name = model.service_name().to_string();
120
121
122
123
124
            tracing::debug!("Model: {service_name} with Dynamo pre-processing");
            Ok(PreparedEngine {
                service_name,
                engine: pipeline,
                inspect_template: true,
125
126
                request_template: model.request_template(),
                card: Some(model.into_card()),
127
            })
128
129
130
        }
    }
}
131
132
133
134

pub async fn build_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    engine: ExecutionContext,
135
    hf_tokenizer: tokenizers::Tokenizer,
136
137
138
139
140
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
141
142
143
144
145
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
146
147
{
    let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
148
149
150
151
152
    let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
    let preprocessor =
        OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?
            .into_operator();
    let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
153
154
155
156
157
158
159
160
161
162
163
    let engine = ServiceBackend::from_engine(engine);

    Ok(frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(engine)?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?)
}

164
#[allow(clippy::too_many_arguments)]
165
166
167
168
pub async fn build_routed_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    client: &Client,
    router_mode: RouterMode,
169
    busy_threshold: Option<f64>,
170
    chooser: Option<Arc<KvRouter>>,
171
    hf_tokenizer: tokenizers::Tokenizer,
172
    prefill_chooser: Option<Arc<PrefillRouter>>,
173
    enforce_disagg: bool,
174
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
175
176
177
178
179
180
181
182
183
184
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
{
185
186
187
    let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
    let preprocessor =
        OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?;
188
189
190
191
192
193
194
    build_routed_pipeline_with_preprocessor(
        card,
        client,
        router_mode,
        busy_threshold,
        chooser,
        preprocessor,
195
        hf_tokenizer,
196
        prefill_chooser,
197
        enforce_disagg,
198
199
200
201
    )
    .await
}

202
#[allow(clippy::too_many_arguments)]
203
204
205
206
207
208
209
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
    card: &ModelDeploymentCard,
    client: &Client,
    router_mode: RouterMode,
    busy_threshold: Option<f64>,
    chooser: Option<Arc<KvRouter>>,
    preprocessor: Arc<OpenAIPreprocessor>,
210
    hf_tokenizer: tokenizers::Tokenizer,
211
    prefill_chooser: Option<Arc<PrefillRouter>>,
212
    enforce_disagg: bool,
213
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
214
215
216
217
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
218
219
220
221
222
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
223
224
{
    let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
225
    let preprocessor_op = preprocessor.into_operator();
226
227
    let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
    let migration = Migration::from_mdc(card).into_operator();
228

229
230
231
232
233
234
235
236
237
238
    // For KV routing, use the client from the chooser to ensure shared state
    let router_client = if router_mode == RouterMode::KV {
        let Some(ref chooser) = chooser else {
            anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
        };
        chooser.client().clone()
    } else {
        client.clone()
    };

239
240
241
    // Create worker monitor only if busy_threshold is set
    let worker_monitor = busy_threshold.map(|threshold| {
        Arc::new(crate::discovery::KvWorkerMonitor::new(
242
            Arc::new(router_client.clone()),
243
244
245
246
            threshold,
        )) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
    });

247
248
    let router =
        PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
249
            router_client,
250
251
            router_mode,
            busy_threshold,
252
            worker_monitor,
253
254
        )
        .await?;
255

256
257
258
259
260
261
262
263
264
265
266
267
268
    let service_backend = match router_mode {
        RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
            ServiceBackend::from_engine(Arc::new(router))
        }
        RouterMode::KV => {
            let Some(chooser) = chooser else {
                anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
            };
            let kv_push_router = KvPushRouter::new(router, chooser);
            ServiceBackend::from_engine(Arc::new(kv_push_router))
        }
    };

269
    // Use the provided prefill chooser, or create a disabled one if not provided
270
271
    let prefill_chooser =
        prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg));
272
273
274
    let prefill_op = prefill_chooser.into_operator();

    // Link with prefill chooser including backward edge for response flow
275
    let engine = frontend
276
        .link(preprocessor_op.forward_edge())?
277
278
        .link(backend.forward_edge())?
        .link(migration.forward_edge())?
279
        .link(prefill_op.forward_edge())?
280
        .link(service_backend)?
281
        .link(prefill_op.backward_edge())?
282
283
        .link(migration.backward_edge())?
        .link(backend.backward_edge())?
284
        .link(preprocessor_op.backward_edge())?
285
        .link(frontend)?;
286

287
288
    Ok(engine)
}