common.rs 11.4 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::pin::Pin;

6
use crate::{
7
    backend::{Backend, ExecutionContext},
8
    discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher},
9
    engines::StreamingEngineAdapter,
10
11
12
    entrypoint::{self, EngineConfig},
    kv_router::{KvPushRouter, KvRouter},
    migration::Migration,
13
    model_card::ModelDeploymentCard,
14
    preprocessor::OpenAIPreprocessor,
15
    protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
16
    request_template::RequestTemplate,
17
    types::{
18
        Annotated,
19
20
21
22
23
24
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
    },
};
25

26
use dynamo_runtime::{
27
    DistributedRuntime, Runtime,
28
29
    component::Client,
    distributed::DistributedConfig,
30
    engine::{AsyncEngineStream, Data},
31
32
33
34
    pipeline::{
        Context, ManyOut, Operator, PushRouter, RouterMode, SegmentSource, ServiceBackend,
        ServiceEngine, ServiceFrontend, SingleIn, Source,
    },
35
36
37
};
use std::sync::Arc;

38
39
40
41
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    pub inspect_template: bool,
42
43
44
45
46
47
48
49
50
51
52
53
    pub card: Option<ModelDeploymentCard>,
    pub request_template: Option<RequestTemplate>,
}

impl PreparedEngine {
    pub fn has_tokenizer(&self) -> bool {
        if let Some(card) = self.card.as_ref() {
            card.has_tokenizer()
        } else {
            false
        }
    }
54
55
}

56
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
57
58
59
pub async fn prepare_engine(
    runtime: Runtime,
    engine_config: EngineConfig,
60
) -> anyhow::Result<PreparedEngine> {
61
    match engine_config {
62
        EngineConfig::Dynamic(local_model) => {
63
64
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

65
            let Some(etcd_client) = distributed_runtime.etcd_client() else {
66
                anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
67
            };
68
            let model_manager = Arc::new(ModelManager::new());
69
70
71
72
            let watch_obj = Arc::new(ModelWatcher::new(
                distributed_runtime,
                model_manager.clone(),
                dynamo_runtime::pipeline::RouterMode::RoundRobin,
73
                None,
74
                None,
75
            ));
76
            let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
77
            let (_prefix, _watcher, receiver) = models_watcher.dissolve();
78

79
            let inner_watch_obj = watch_obj.clone();
80
81
82
            let _watcher_task = tokio::spawn(async move {
                inner_watch_obj.watch(receiver).await;
            });
83
            tracing::info!("Waiting for remote model..");
84

85
86
87
88
89
            // TODO: We use the first model to appear, usually we have only one
            // We should add slash commands to text input `/model <name>` to choose,
            // '/models` to list, and notifications when models are added / removed.

            let model_service_name = watch_obj.wait_for_chat_model().await;
90
            tracing::info!("Connected to {model_service_name}");
91
            let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
92
            Ok(PreparedEngine {
93
                service_name: model_service_name,
94
95
                engine,
                inspect_template: false,
96
97
                card: None,
                request_template: local_model.request_template(),
98
            })
99
        }
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        EngineConfig::StaticRemote(local_model) => {
            // For now we only do ModelType.Backend
            // For batch/text we only do Chat Completions

            // The card should have been loaded at 'build' phase earlier
            let card = local_model.card();
            let router_mode = local_model.router_config().router_mode;

            let dst_config = DistributedConfig::from_settings(true);
            let distributed_runtime = DistributedRuntime::new(runtime, dst_config).await?;

            let endpoint_id = local_model.endpoint_id();
            let component = distributed_runtime
                .namespace(&endpoint_id.namespace)?
114
                .component(&endpoint_id.component)?;
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
            let client = component.endpoint(&endpoint_id.name).client().await?;

            let kv_chooser = if router_mode == RouterMode::KV {
                let model_manager = Arc::new(ModelManager::new());
                Some(
                    model_manager
                        .kv_chooser_for(
                            local_model.display_name(),
                            &component,
                            card.kv_cache_block_size,
                            Some(local_model.router_config().kv_router_config),
                        )
                        .await?,
                )
            } else {
                None
            };

            let chat_engine = entrypoint::build_routed_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
137
            >(card, &client, router_mode, None, kv_chooser.clone())
138
139
140
141
142
143
144
145
146
147
148
149
150
            .await?;

            let service_name = local_model.service_name().to_string();
            tracing::info!("Static connecting to {service_name}");
            Ok(PreparedEngine {
                service_name,
                engine: chat_engine,
                inspect_template: false,
                request_template: local_model.request_template(),
                card: Some(local_model.into_card()),
            })
        }
        EngineConfig::StaticFull { engine, model, .. } => {
151
            let service_name = model.service_name().to_string();
152
            tracing::debug!("Model: {service_name} with engine pre-processing");
153
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
154
155
156
157
            Ok(PreparedEngine {
                service_name,
                engine,
                inspect_template: false,
158
159
                request_template: model.request_template(),
                card: Some(model.into_card()),
160
            })
161
162
163
        }
        EngineConfig::StaticCore {
            engine: inner_engine,
164
            model,
165
            ..
166
        } => {
167
168
169
            let pipeline = build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
170
            >(model.card(), inner_engine)
171
            .await?;
172

173
            let service_name = model.service_name().to_string();
174
175
176
177
178
            tracing::debug!("Model: {service_name} with Dynamo pre-processing");
            Ok(PreparedEngine {
                service_name,
                engine: pipeline,
                inspect_template: true,
179
180
                request_template: model.request_template(),
                card: Some(model.into_card()),
181
            })
182
183
184
        }
    }
}
185
186
187
188
189
190
191
192
193

pub async fn build_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    engine: ExecutionContext,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
194
195
196
197
198
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
{
    let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
    let preprocessor = OpenAIPreprocessor::new((*card).clone())
        .await?
        .into_operator();
    let backend = Backend::from_mdc((*card).clone()).await?.into_operator();
    let engine = ServiceBackend::from_engine(engine);

    Ok(frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(engine)?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?)
}

216
217
218
219
pub async fn build_routed_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    client: &Client,
    router_mode: RouterMode,
220
    busy_threshold: Option<f64>,
221
222
223
224
225
226
    chooser: Option<Arc<KvRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
227
228
229
230
231
            Context<Req>,
            Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
            Context<PreprocessedRequest>,
            Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
        >,
232
233
234
235
236
{
    let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
    let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
    let backend = Backend::from_mdc(card.clone()).await?.into_operator();
    let migration = Migration::from_mdc(card.clone()).await?.into_operator();
237
238
239
240
241
242
243
    let router =
        PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
            client.clone(),
            router_mode,
            busy_threshold,
        )
        .await?;
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
    let service_backend = match router_mode {
        RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
            ServiceBackend::from_engine(Arc::new(router))
        }
        RouterMode::KV => {
            let Some(chooser) = chooser else {
                anyhow::bail!("RouterMode::KV requires KVRouter to not be null");
            };
            let kv_push_router = KvPushRouter::new(router, chooser);
            ServiceBackend::from_engine(Arc::new(kv_push_router))
        }
    };

    let engine = frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(migration.forward_edge())?
        .link(service_backend)?
        .link(migration.backward_edge())?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?;
    Ok(engine)
}

269
270
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
    };

    /// Mock Llama 3.1 8B Instruct model directory from the crate's test data.
    const HF_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/data/sample-models/mock-llama-3.1-8b-instruct"
    );

    /// `build_pipeline` succeeds for chat-completions request/response types.
    #[tokio::test]
    async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        // Create test model card
        let card = ModelDeploymentCard::load(HF_PATH).await?;
        let engine = crate::engines::make_engine_core();

        // Build pipeline for chat completions
        let pipeline = build_pipeline::<
            NvCreateChatCompletionRequest,
            NvCreateChatCompletionStreamResponse,
        >(&card, engine)
        .await?;

        // Verify pipeline was created
        assert!(Arc::strong_count(&pipeline) >= 1);

        Ok(())
    }

    /// `build_pipeline` succeeds for legacy completions request/response types.
    #[tokio::test]
    async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        // Create test model card
        let card = ModelDeploymentCard::load(HF_PATH).await?;
        let engine = crate::engines::make_engine_core();

        // Build pipeline for completions
        let pipeline =
            build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(&card, engine)
                .await?;

        // Verify pipeline was created
        assert!(Arc::strong_count(&pipeline) >= 1);

        Ok(())
    }
}