// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

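//! Interactive text chat: read prompts from the terminal (or piped stdin),
//! stream the model's reply to stdout, and resend the accumulated
//! conversation on every turn.
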
use dynamo_llm::{
    backend::Backend,
    preprocessor::OpenAIPreprocessor,
    types::{
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
    },
};
use dynamo_runtime::{
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
    runtime::CancellationToken,
    DistributedRuntime, Runtime,
};
use futures::StreamExt;
use std::{
    io::{ErrorKind, Read, Write},
    sync::Arc,
};

use crate::EngineConfig;

/// Maximum response tokens per query. Must be less than the model's context size.
const MAX_TOKENS: u32 = 8192;

/// Output of `isatty` if the fd is indeed a TTY
const IS_A_TTY: i32 = 1;

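/// Select the chat engine described by `engine_config` (remote endpoint,
/// fully static, or a core engine wrapped in a pre/post-processing pipeline),
/// then hand it to the interactive loop.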
pub async fn run(
    runtime: Runtime,
    cancel_token: CancellationToken,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
    let (service_name, engine, inspect_template): (
        String,
        OpenAIChatCompletionsStreamingEngine,
        bool,
    ) = match engine_config {
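        // Remote worker: discover the endpoint through the distributed
        // runtime and call it over the network.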
        EngineConfig::Dynamic(endpoint_id) => {
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

            let endpoint = distributed_runtime
                .namespace(endpoint_id.namespace)?
                .component(endpoint_id.component)?
                .endpoint(endpoint_id.name);

            let client = endpoint
                .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
                .await?;
            tracing::info!("Waiting for remote model..");
            client.wait_for_endpoints().await?;
            tracing::info!("Model discovered");

            // For text chat the service_name only appears in logs, so use the
            // endpoint subject. That avoids having to listen on etcd for model
            // registration.
            let service_name = endpoint.subject();
            (service_name, Arc::new(client), false)
        }
        EngineConfig::StaticFull {
            service_name,
            engine,
        } => {
            tracing::info!("Model: {service_name}");
            (service_name, engine, false)
        }
        EngineConfig::StaticCore {
            service_name,
            engine: inner_engine,
            card,
        } => {
            let frontend = ServiceFrontend::<
                SingleIn<NvCreateChatCompletionRequest>,
                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
            >::new();
            let preprocessor = OpenAIPreprocessor::new(*card.clone())
                .await?
                .into_operator();
            let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
            let engine = ServiceBackend::from_engine(inner_engine);

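            // Requests flow frontend -> preprocessor -> backend -> engine on
            // the forward edges; responses return through the same stages in
            // reverse on the backward edges.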
            let pipeline = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(engine)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            tracing::info!("Model: {service_name} with pre-processing");
            (service_name, pipeline, true)
        }
        EngineConfig::None => unreachable!(),
    };
    main_loop(cancel_token, &service_name, engine, inspect_template).await
}

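/// Read-eval-print loop: prompt the user, send the full message history to
/// the engine, stream the assistant's answer, and append it to the history.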
#[allow(deprecated)]
async fn main_loop(
    cancel_token: CancellationToken,
    service_name: &str,
    engine: OpenAIChatCompletionsStreamingEngine,
    _inspect_template: bool,
) -> anyhow::Result<()> {
    tracing::info!("Ctrl-c to exit");
    let theme = dialoguer::theme::ColorfulTheme::default();

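    // SAFETY: `isatty` only queries the file descriptor; it has no
    // memory-safety preconditions.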
    let mut initial_prompt = if unsafe { libc::isatty(libc::STDIN_FILENO) == IS_A_TTY } {
        None
    } else {
        // Something piped in, use that as initial prompt
        let mut input = String::new();
        std::io::stdin().read_to_string(&mut input)?;
        Some(input)
    };

    let mut history = dialoguer::BasicHistory::default();
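    // Full conversation history; resent with every request so the model
    // keeps context across turns.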
    let mut messages = vec![];
    while !cancel_token.is_cancelled() {
        // User input
        let prompt = match initial_prompt.take() {
            Some(p) => p,
            None => {
                let input_ui = dialoguer::Input::<String>::with_theme(&theme)
                    .history_with(&mut history)
                    .with_prompt("User");
                match input_ui.interact_text() {
                    Ok(prompt) => prompt,
                    Err(dialoguer::Error::IO(err)) => {
                        match err.kind() {
                            ErrorKind::Interrupted => {
                                // Ctrl-C
                                // Unfortunately I could not make dialoguer handle Ctrl-d
                            }
                            k => {
                                tracing::info!("IO error: {k}");
                            }
                        }
                        break;
                    }
                }
            }
        };

        // Construct messages
        let user_message = async_openai::types::ChatCompletionRequestMessage::User(
            async_openai::types::ChatCompletionRequestUserMessage {
                content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt),
                name: None,
            },
        );
        messages.push(user_message);

        // Request
        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
            .messages(messages.clone())
            .model(service_name)
            .stream(true)
            .max_tokens(MAX_TOKENS)
            .build()?;

        // TODO We cannot set min_tokens with async-openai
        // if inspect_template {
        //     // This makes the pre-processor ignore stop tokens
        //     req_builder.min_tokens(8192);
        // }

        let req = NvCreateChatCompletionRequest { inner, nvext: None };

        // Call the model
        let mut stream = engine.generate(Context::new(req)).await?;

        // Stream the output to stdout
        let mut stdout = std::io::stdout();
        let mut assistant_message = String::new();
        while let Some(item) = stream.next().await {
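            // Each `Annotated` item carries either data (a normal delta), an
            // "error" event with details in `comment`, or another annotation.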
            match (item.data.as_ref(), item.event.as_deref()) {
                (Some(data), _) => {
                    // Normal case
                    let chat_comp = data
                        .inner
                        .choices
                        .first()
                        .expect("stream response contains at least one choice");
                    if let Some(c) = &chat_comp.delta.content {
                        let _ = stdout.write(c.as_bytes());
                        let _ = stdout.flush();
                        assistant_message += c;
                    }
                    if chat_comp.finish_reason.is_some() {
                        tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
                        break;
                    }
                }
                (None, Some("error")) => {
                    // There's only one error but we loop in case that changes
                    for err in item.comment.unwrap_or_default() {
                        tracing::error!("Engine error: {err}");
                    }
                }
                (None, Some(annotation)) => {
                    tracing::debug!("Annotation. {annotation}: {:?}", item.comment);
                }
                _ => {
                    unreachable!("Event from engine with no data, no error, no annotation.");
                }
            }
        }
        println!();

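        // Store the complete streamed reply as an assistant message so the
        // next turn includes it in the context.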
        let assistant_content =
            async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(
                assistant_message,
            );

        // ALLOW: function_call is deprecated
        let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
            async_openai::types::ChatCompletionRequestAssistantMessage {
                content: Some(assistant_content),
                refusal: None,
                name: None,
                audio: None,
                tool_calls: None,
                function_call: None,
            },
        );
        messages.push(assistant_message);
    }
    println!();
    Ok(())
}