common.rs 7.07 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::pin::Pin;

6
use crate::{
7
    backend::{Backend, ExecutionContext},
8
    discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
9
    engines::StreamingEngineAdapter,
10
    entrypoint::EngineConfig,
11
    model_card::ModelDeploymentCard,
12
    preprocessor::OpenAIPreprocessor,
13
    protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
14
    request_template::RequestTemplate,
15
16
17
18
19
20
21
22
23
    types::{
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
    },
};
use dynamo_runtime::{
24
    engine::{AsyncEngineStream, Data},
25
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
26
27
28
29
    DistributedRuntime, Runtime,
};
use std::sync::Arc;

30
31
32
33
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    pub inspect_template: bool,
34
35
36
37
38
39
40
41
42
43
44
45
    pub card: Option<ModelDeploymentCard>,
    pub request_template: Option<RequestTemplate>,
}

impl PreparedEngine {
    pub fn has_tokenizer(&self) -> bool {
        if let Some(card) = self.card.as_ref() {
            card.has_tokenizer()
        } else {
            false
        }
    }
46
47
}

48
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
49
50
51
pub async fn prepare_engine(
    runtime: Runtime,
    engine_config: EngineConfig,
52
) -> anyhow::Result<PreparedEngine> {
53
    match engine_config {
54
        EngineConfig::Dynamic(local_model) => {
55
56
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

57
            let Some(etcd_client) = distributed_runtime.etcd_client() else {
58
                anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
59
            };
60
            let model_manager = Arc::new(ModelManager::new());
61
62
63
64
            let watch_obj = Arc::new(ModelWatcher::new(
                distributed_runtime,
                model_manager.clone(),
                dynamo_runtime::pipeline::RouterMode::RoundRobin,
65
                None,
66
            ));
67
            let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
68
            let (_prefix, _watcher, receiver) = models_watcher.dissolve();
69

70
            let inner_watch_obj = watch_obj.clone();
71
72
73
            let _watcher_task = tokio::spawn(async move {
                inner_watch_obj.watch(receiver).await;
            });
74
            tracing::info!("Waiting for remote model..");
75

76
77
78
79
80
            // TODO: We use the first model to appear, usually we have only one
            // We should add slash commands to text input `/model <name>` to choose,
            // '/models` to list, and notifications when models are added / removed.

            let model_service_name = watch_obj.wait_for_chat_model().await;
81
            let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
82
            Ok(PreparedEngine {
83
                service_name: model_service_name,
84
85
                engine,
                inspect_template: false,
86
87
                card: None,
                request_template: local_model.request_template(),
88
            })
89
        }
90
91
        EngineConfig::StaticFull { engine, model } => {
            let service_name = model.service_name().to_string();
92
            tracing::debug!("Model: {service_name} with engine pre-processing");
93
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
94
95
96
97
            Ok(PreparedEngine {
                service_name,
                engine,
                inspect_template: false,
98
99
                request_template: model.request_template(),
                card: Some(model.into_card()),
100
            })
101
102
103
        }
        EngineConfig::StaticCore {
            engine: inner_engine,
104
            model,
105
        } => {
106
107
108
            let pipeline = build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
109
            >(model.card(), inner_engine)
110
            .await?;
111

112
            let service_name = model.service_name().to_string();
113
114
115
116
117
            tracing::debug!("Model: {service_name} with Dynamo pre-processing");
            Ok(PreparedEngine {
                service_name,
                engine: pipeline,
                inspect_template: true,
118
119
                request_template: model.request_template(),
                card: Some(model.into_card()),
120
            })
121
122
123
        }
    }
}
124
125
126
127
128
129
130
131
132
133
134

pub async fn build_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    engine: ExecutionContext,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
        Context<Req>,
        Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
135
        Context<PreprocessedRequest>,
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
        Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
    >,
{
    let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
    let preprocessor = OpenAIPreprocessor::new((*card).clone())
        .await?
        .into_operator();
    let backend = Backend::from_mdc((*card).clone()).await?.into_operator();
    let engine = ServiceBackend::from_engine(engine);

    Ok(frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(engine)?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
    };

    const HF_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/data/sample-models/mock-llama-3.1-8b-instruct"
    );

    /// Loads the mock Llama 3.1 model card from the crate's test data.
    async fn mock_card() -> anyhow::Result<ModelDeploymentCard> {
        Ok(ModelDeploymentCard::load(HF_PATH).await?)
    }

    #[tokio::test]
    async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        let card = mock_card().await?;
        let core_engine = crate::engines::make_engine_core();

        // Chat-completions pipeline around the core engine.
        let pipeline = build_pipeline::<
            NvCreateChatCompletionRequest,
            NvCreateChatCompletionStreamResponse,
        >(&card, core_engine)
        .await?;

        // Building succeeded and we hold a live handle to the pipeline.
        assert!(Arc::strong_count(&pipeline) >= 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        let card = mock_card().await?;
        let core_engine = crate::engines::make_engine_core();

        // Legacy completions pipeline around the core engine.
        let pipeline = build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(
            &card,
            core_engine,
        )
        .await?;

        // Building succeeded and we hold a live handle to the pipeline.
        assert!(Arc::strong_count(&pipeline) >= 1);
        Ok(())
    }
}