// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::pin::Pin;

6
use dynamo_llm::{
7
    backend::{Backend, ExecutionContext},
8
    discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
9
    engines::StreamingEngineAdapter,
10
    model_card::ModelDeploymentCard,
11
    preprocessor::OpenAIPreprocessor,
12
    protocols::common::llm_backend::{BackendOutput, PreprocessedRequest},
13
14
15
16
17
18
19
20
21
    types::{
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
    },
};
use dynamo_runtime::{
22
    engine::{AsyncEngineStream, Data},
23
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
24
25
26
27
    DistributedRuntime, Runtime,
};
use std::sync::Arc;

28
use crate::EngineConfig;
29

30
31
32
33
34
35
pub struct PreparedEngine {
    pub service_name: String,
    pub engine: OpenAIChatCompletionsStreamingEngine,
    pub inspect_template: bool,
}

36
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
37
38
39
pub async fn prepare_engine(
    runtime: Runtime,
    engine_config: EngineConfig,
40
) -> anyhow::Result<PreparedEngine> {
41
    match engine_config {
42
        EngineConfig::Dynamic => {
43
44
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

45
            let Some(etcd_client) = distributed_runtime.etcd_client() else {
46
                anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
47
            };
48
            let model_manager = Arc::new(ModelManager::new());
49
50
51
52
            let watch_obj = Arc::new(ModelWatcher::new(
                distributed_runtime,
                model_manager.clone(),
                dynamo_runtime::pipeline::RouterMode::RoundRobin,
53
                None,
54
            ));
55
            let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
56
            let (_prefix, _watcher, receiver) = models_watcher.dissolve();
57

58
            let inner_watch_obj = watch_obj.clone();
59
60
61
            let _watcher_task = tokio::spawn(async move {
                inner_watch_obj.watch(receiver).await;
            });
62
            tracing::info!("Waiting for remote model..");
63

64
65
66
67
68
            // TODO: We use the first model to appear, usually we have only one
            // We should add slash commands to text input `/model <name>` to choose,
            // '/models` to list, and notifications when models are added / removed.

            let model_service_name = watch_obj.wait_for_chat_model().await;
69
            let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
70
            Ok(PreparedEngine {
71
                service_name: model_service_name,
72
73
74
                engine,
                inspect_template: false,
            })
75
        }
76
77
        EngineConfig::StaticFull { engine, model } => {
            let service_name = model.service_name().to_string();
78
            tracing::debug!("Model: {service_name} with engine pre-processing");
79
            let engine = Arc::new(StreamingEngineAdapter::new(engine));
80
81
82
83
84
            Ok(PreparedEngine {
                service_name,
                engine,
                inspect_template: false,
            })
85
86
87
        }
        EngineConfig::StaticCore {
            engine: inner_engine,
88
            model,
89
        } => {
90
91
92
            let pipeline = build_pipeline::<
                NvCreateChatCompletionRequest,
                NvCreateChatCompletionStreamResponse,
93
            >(model.card(), inner_engine)
94
            .await?;
95

96
            let service_name = model.service_name().to_string();
97
98
99
100
101
102
            tracing::debug!("Model: {service_name} with Dynamo pre-processing");
            Ok(PreparedEngine {
                service_name,
                engine: pipeline,
                inspect_template: true,
            })
103
104
105
        }
    }
}
106
107
108
109
110
111
112
113
114
115
116

pub async fn build_pipeline<Req, Resp>(
    card: &ModelDeploymentCard,
    engine: ExecutionContext,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
    Req: Data,
    Resp: Data,
    OpenAIPreprocessor: Operator<
        Context<Req>,
        Pin<Box<dyn AsyncEngineStream<Annotated<Resp>>>>,
117
        Context<PreprocessedRequest>,
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
        Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
    >,
{
    let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
    let preprocessor = OpenAIPreprocessor::new((*card).clone())
        .await?
        .into_operator();
    let backend = Backend::from_mdc((*card).clone()).await?.into_operator();
    let engine = ServiceBackend::from_engine(engine);

    Ok(frontend
        .link(preprocessor.forward_edge())?
        .link(backend.forward_edge())?
        .link(engine)?
        .link(backend.backward_edge())?
        .link(preprocessor.backward_edge())?
        .link(frontend)?)
}

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_llm::types::openai::{
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
        completions::{CompletionResponse, NvCreateCompletionRequest},
    };

    /// Mock Llama 3.1 8B Instruct model directory from the `dynamo-llm` test data.
    const HF_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../lib/llm/tests/data/sample-models/mock-llama-3.1-8b-instruct"
    );

    /// Loads the mock model card and creates a core engine for building pipelines.
    async fn card_and_core_engine() -> anyhow::Result<(ModelDeploymentCard, ExecutionContext)> {
        let card = ModelDeploymentCard::load(HF_PATH).await?;
        let engine = dynamo_llm::engines::make_engine_core();
        Ok((card, engine))
    }

    #[tokio::test]
    async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        let (card, engine) = card_and_core_engine().await?;

        // The check is the `?`. Creating the preprocessor and backend from the card, and
        // linking every edge, must all succeed.
        let _pipeline = build_pipeline::<
            NvCreateChatCompletionRequest,
            NvCreateChatCompletionStreamResponse,
        >(&card, engine)
        .await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
        let (card, engine) = card_and_core_engine().await?;

        // Same check as above, for the completions request/response types.
        let _pipeline =
            build_pipeline::<NvCreateCompletionRequest, CompletionResponse>(&card, engine).await?;

        Ok(())
    }
}