// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
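
//! Text chat against a chat-completions engine, streaming replies to stdout.
//! With a piped prompt (`echo "Hello" | dynamo-run ..`) it answers once and
//! exits; otherwise it runs an interactive loop until Ctrl-C.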

use dynamo_llm::{
    backend::Backend,
    preprocessor::OpenAIPreprocessor,
    types::{
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
    },
};
use dynamo_runtime::{
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
    runtime::CancellationToken,
    DistributedRuntime, Runtime,
};
use futures::StreamExt;
use std::{
    io::{ErrorKind, Write},
    sync::Arc,
};

use crate::EngineConfig;

/// Max response tokens for each single query. Must be less than model context size.
const MAX_TOKENS: u32 = 8192;

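/// Run text chat against the configured engine: answer `single_prompt` once
/// if stdin was piped, otherwise start an interactive prompt loop.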
pub async fn run(
    runtime: Runtime,
    cancel_token: CancellationToken,
    single_prompt: Option<String>,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
    let (service_name, engine, inspect_template): (
        String,
        OpenAIChatCompletionsStreamingEngine,
        bool,
    ) = match engine_config {
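        // Dynamic: discover and call a remote model over the distributed runtime.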
        EngineConfig::Dynamic(endpoint_id) => {
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

            let endpoint = distributed_runtime
                .namespace(endpoint_id.namespace)?
                .component(endpoint_id.component)?
                .endpoint(endpoint_id.name);

            let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
            tracing::info!("Waiting for remote model...");
            client.wait_for_endpoints().await?;
            tracing::info!("Model discovered");

            // The service_name is only used in logs for text chat, so use the
            // endpoint subject. That avoids listening on etcd for model registration.
            let service_name = endpoint.subject();
            (service_name, Arc::new(client), false)
        }
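        // StaticFull: the in-process engine already speaks OpenAI chat completions.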
        EngineConfig::StaticFull {
            service_name,
            engine,
        } => {
            tracing::debug!("Model: {service_name}");
            (service_name, engine, false)
        }
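        // StaticCore: wrap a token-level engine with OpenAI pre- and post-processing.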
        EngineConfig::StaticCore {
            service_name,
            engine: inner_engine,
            card,
        } => {
            let frontend = ServiceFrontend::<
                SingleIn<NvCreateChatCompletionRequest>,
                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
            >::new();
            let preprocessor = OpenAIPreprocessor::new(*card.clone())
                .await?
                .into_operator();
            let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
            let engine = ServiceBackend::from_engine(inner_engine);

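            // Requests flow frontend -> preprocessor -> backend -> engine;
            // responses stream back through the same operators in reverse.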
            let pipeline = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(engine)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            tracing::debug!("Model: {service_name} with pre-processing");
            (service_name, pipeline, true)
        }
        EngineConfig::None => unreachable!(),
    };
    main_loop(
        cancel_token,
        &service_name,
        engine,
        single_prompt,
        inspect_template,
    )
    .await
}

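/// Interactive loop: read a prompt, stream the reply to stdout, and keep the
/// conversation history; exits on Ctrl-C or after a single piped prompt.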
#[allow(deprecated)]
async fn main_loop(
    cancel_token: CancellationToken,
    service_name: &str,
    engine: OpenAIChatCompletionsStreamingEngine,
    mut initial_prompt: Option<String>,
    _inspect_template: bool,
) -> anyhow::Result<()> {
    if initial_prompt.is_none() {
        tracing::info!("Ctrl-c to exit");
    }
    let theme = dialoguer::theme::ColorfulTheme::default();

    // Initial prompt is the pipe case: `echo "Hello" | dynamo-run ..`
    // We run that single prompt and exit
    let single = initial_prompt.is_some();
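    // Conversation state: input history for the prompt UI plus the OpenAI
    // message list sent with every request.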
    let mut history = dialoguer::BasicHistory::default();
    let mut messages = vec![];
    while !cancel_token.is_cancelled() {
        // User input
        let prompt = match initial_prompt.take() {
            Some(p) => p,
            None => {
                let input_ui = dialoguer::Input::<String>::with_theme(&theme)
                    .history_with(&mut history)
                    .with_prompt("User");
                match input_ui.interact_text() {
                    Ok(prompt) => prompt,
                    Err(dialoguer::Error::IO(err)) => {
                        match err.kind() {
                            ErrorKind::Interrupted => {
                                // Ctrl-C
                                // Unfortunately I could not make dialoguer handle Ctrl-d
                            }
                            k => {
                                tracing::info!("IO error: {k}");
                            }
                        }
                        break;
                    }
                }
            }
        };

        // Construct messages
        let user_message = async_openai::types::ChatCompletionRequestMessage::User(
            async_openai::types::ChatCompletionRequestUserMessage {
                content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt),
                name: None,
            },
        );
        messages.push(user_message);

        // Request
        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
            .messages(messages.clone())
            .model(service_name)
            .stream(true)
            .max_tokens(MAX_TOKENS)
            .build()?;

        // TODO We cannot set min_tokens with async-openai
        // if inspect_template {
        //     // This makes the pre-processor ignore stop tokens
        //     req_builder.min_tokens(8192);
        // }

        let req = NvCreateChatCompletionRequest { inner, nvext: None };

        // Call the model
        let mut stream = engine.generate(Context::new(req)).await?;

        // Stream the output to stdout
        let mut stdout = std::io::stdout();
        let mut assistant_message = String::new();
        while let Some(item) = stream.next().await {
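            // Each Annotated item is either a data chunk, an "error" event,
            // or some other annotation for debugging.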
            match (item.data.as_ref(), item.event.as_deref()) {
                (Some(data), _) => {
                    // Normal case: a content delta. Skip any chunk without
                    // choices instead of unwrapping and panicking.
                    let Some(chat_comp) = data.inner.choices.first() else {
                        continue;
                    };
                    if let Some(c) = &chat_comp.delta.content {
                        // write_all avoids dropping output on a partial write
                        let _ = stdout.write_all(c.as_bytes());
                        let _ = stdout.flush();
                        assistant_message += c;
                    }
                    if chat_comp.finish_reason.is_some() {
                        tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
                        break;
                    }
                }
                (None, Some("error")) => {
                    // There's only one error but we loop in case that changes
                    for err in item.comment.unwrap_or_default() {
                        tracing::error!("Engine error: {err}");
                    }
                }
                (None, Some(annotation)) => {
                    tracing::debug!("Annotation. {annotation}: {:?}", item.comment);
                }
                _ => {
                    unreachable!("Event from engine with no data, no error, no annotation.");
                }
            }
        }
        println!();

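        // Record the assistant's reply in the history so the next turn
        // carries the full conversation.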
        let assistant_content =
            async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(
                assistant_message,
            );

        // ALLOW: function_call is deprecated
        let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
            async_openai::types::ChatCompletionRequestAssistantMessage {
                content: Some(assistant_content),
                refusal: None,
                name: None,
                audio: None,
                tool_calls: None,
                function_call: None,
            },
        );
        messages.push(assistant_message);

        if single {
            break;
        }
    }
    println!();
    Ok(())
}