// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use dynemo_llm::{
17
18
19
20
    backend::Backend,
    preprocessor::OpenAIPreprocessor,
    types::{
        openai::chat_completions::{
21
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
22
23
24
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
25
26
    },
};
27
use dynemo_runtime::{
28
29
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
    runtime::CancellationToken,
30
    DistributedRuntime, Runtime,
31
};
32
33
34
35
36
use futures::StreamExt;
use std::{
    io::{ErrorKind, Read, Write},
    sync::Arc,
};
37
38
39
40

use crate::EngineConfig;

/// Max response tokens for each single query. Must be less than model context size.
const MAX_TOKENS: u32 = 8192;

/// Output of `isatty` if the fd is indeed a TTY
const IS_A_TTY: i32 = 1;

pub async fn run(
47
    runtime: Runtime,
48
49
50
51
52
53
54
55
    cancel_token: CancellationToken,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
    let (service_name, engine, inspect_template): (
        String,
        OpenAIChatCompletionsStreamingEngine,
        bool,
    ) = match engine_config {
56
57
58
59
60
61
62
63
64
65
66
67
68
        EngineConfig::Dynamic(endpoint_id) => {
            let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;

            let endpoint = distributed_runtime
                .namespace(endpoint_id.namespace)?
                .component(endpoint_id.component)?
                .endpoint(endpoint_id.name);

            let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
            tracing::info!("Waiting for remote model..");
            client.wait_for_endpoints().await?;
            tracing::info!("Model discovered");

69
70
            // The service_name isn't used for text chat outside of logs,
            // so use the path. That avoids having to listen on etcd for model registration.
71
            let service_name = endpoint.subject();
72
73
            (service_name, Arc::new(client), false)
        }
74
75
76
77
78
79
80
        EngineConfig::StaticFull {
            service_name,
            engine,
        } => {
            tracing::info!("Model: {service_name}");
            (service_name, engine, false)
        }
81
82
83
84
85
86
        EngineConfig::StaticCore {
            service_name,
            engine: inner_engine,
            card,
        } => {
            let frontend = ServiceFrontend::<
87
                SingleIn<NvCreateChatCompletionRequest>,
88
                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
            >::new();
            let preprocessor = OpenAIPreprocessor::new(*card.clone())
                .await?
                .into_operator();
            let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
            let engine = ServiceBackend::from_engine(inner_engine);

            let pipeline = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(engine)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            tracing::info!("Model: {service_name} with pre-processing");
            (service_name, pipeline, true)
        }
107
        EngineConfig::None => unreachable!(),
108
109
110
111
    };
    main_loop(cancel_token, &service_name, engine, inspect_template).await
}

#[allow(deprecated)]
113
114
115
116
async fn main_loop(
    cancel_token: CancellationToken,
    service_name: &str,
    engine: OpenAIChatCompletionsStreamingEngine,
Paul Hendricks's avatar
Paul Hendricks committed
117
    _inspect_template: bool,
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
) -> anyhow::Result<()> {
    tracing::info!("Ctrl-c to exit");
    let theme = dialoguer::theme::ColorfulTheme::default();

    let mut initial_prompt = if unsafe { libc::isatty(libc::STDIN_FILENO) == IS_A_TTY } {
        None
    } else {
        // Something piped in, use that as initial prompt
        let mut input = String::new();
        std::io::stdin().read_to_string(&mut input).unwrap();
        Some(input)
    };

    let mut history = dialoguer::BasicHistory::default();
    let mut messages = vec![];
    while !cancel_token.is_cancelled() {
        // User input
        let prompt = match initial_prompt.take() {
            Some(p) => p,
            None => {
                let input_ui = dialoguer::Input::<String>::with_theme(&theme)
                    .history_with(&mut history)
                    .with_prompt("User");
                match input_ui.interact_text() {
                    Ok(prompt) => prompt,
                    Err(dialoguer::Error::IO(err)) => {
                        match err.kind() {
                            ErrorKind::Interrupted => {
                                // Ctrl-C
                                // Unfortunately I could not make dialoguer handle Ctrl-d
                            }
                            k => {
                                tracing::info!("IO error: {k}");
                            }
                        }
                        break;
                    }
                }
            }
        };
Paul Hendricks's avatar
Paul Hendricks committed
158
159
160
161
162
163
164
165
166

        // Construct messages
        let user_message = async_openai::types::ChatCompletionRequestMessage::User(
            async_openai::types::ChatCompletionRequestUserMessage {
                content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt),
                name: None,
            },
        );
        messages.push(user_message);
167
168

        // Request
Paul Hendricks's avatar
Paul Hendricks committed
169
170
        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
            .messages(messages.clone())
171
172
            .model(service_name)
            .stream(true)
Paul Hendricks's avatar
Paul Hendricks committed
173
174
175
176
177
178
179
180
181
            .max_tokens(MAX_TOKENS)
            .build()?;

        // TODO We cannot set min_tokens with async-openai
        // if inspect_template {
        //     // This makes the pre-processor ignore stop tokens
        //     req_builder.min_tokens(8192);
        // }

182
        let req = NvCreateChatCompletionRequest { inner, nvext: None };
183
184
185
186
187
188
189
190
191

        // Call the model
        let mut stream = engine.generate(Context::new(req)).await?;

        // Stream the output to stdout
        let mut stdout = std::io::stdout();
        let mut assistant_message = String::new();
        while let Some(item) = stream.next().await {
            let data = item.data.as_ref().unwrap();
Paul Hendricks's avatar
Paul Hendricks committed
192
            let entry = data.inner.choices.first();
193
194
195
196
197
198
199
200
201
202
203
204
205
            let chat_comp = entry.as_ref().unwrap();
            if let Some(c) = &chat_comp.delta.content {
                let _ = stdout.write(c.as_bytes());
                let _ = stdout.flush();
                assistant_message += c;
            }
            if chat_comp.finish_reason.is_some() {
                tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
                break;
            }
        }
        println!();

Paul Hendricks's avatar
Paul Hendricks committed
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
        let assistant_content =
            async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(
                assistant_message,
            );

        // ALLOW: function_call is deprecated
        let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
            async_openai::types::ChatCompletionRequestAssistantMessage {
                content: Some(assistant_content),
                refusal: None,
                name: None,
                audio: None,
                tool_calls: None,
                function_call: None,
            },
        );
        messages.push(assistant_message);
223
224
225
226
    }
    println!();
    Ok(())
}