// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use std::{
    io::{ErrorKind, Read, Write},
    sync::Arc,
};
use triton_distributed_llm::{
    backend::Backend,
    preprocessor::OpenAIPreprocessor,
    types::{
        openai::chat_completions::{
            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
            OpenAIChatCompletionsStreamingEngine,
        },
        Annotated,
    },
};
use triton_distributed_runtime::{
    pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
    runtime::CancellationToken,
};

use crate::EngineConfig;

/// Max response tokens for a single query. Must be less than the model's context size.
const MAX_TOKENS: u32 = 8192;

/// Output of `isatty` if the fd is indeed a TTY
const IS_A_TTY: i32 = 1;

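/// Drive an interactive text chat: resolve a streaming chat engine from the
/// given `EngineConfig`, then run the read/generate/print loop in `main_loop`.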
pub async fn run(
    cancel_token: CancellationToken,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
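    // The engine comes in one of three flavors: a dynamically discovered
    // remote client, a fully static engine, or a core engine that still needs
    // OpenAI pre/post-processing wrapped around it.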
    let (service_name, engine, inspect_template): (
        String,
        OpenAIChatCompletionsStreamingEngine,
        bool,
    ) = match engine_config {
        EngineConfig::Dynamic(client) => {
            // The service_name isn't used for text chat outside of logs,
            // so use the path. That avoids having to listen on etcd for model registration.
            let service_name = client.path();
            tracing::info!("Model: {service_name}");
            (service_name, Arc::new(client), false)
        }
        EngineConfig::StaticFull {
            service_name,
            engine,
        } => {
            tracing::info!("Model: {service_name}");
            (service_name, engine, false)
        }
        EngineConfig::StaticCore {
            service_name,
            engine: inner_engine,
            card,
        } => {
            let frontend = ServiceFrontend::<
                SingleIn<NvCreateChatCompletionRequest>,
                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
            >::new();
            let preprocessor = OpenAIPreprocessor::new(*card.clone())
                .await?
                .into_operator();
            let backend = Backend::from_mdc(*card.clone()).await?.into_operator();
            let engine = ServiceBackend::from_engine(inner_engine);

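            // Requests travel frontend -> preprocessor -> backend -> engine,
            // and responses return through the same operators in reverse.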
            let pipeline = frontend
                .link(preprocessor.forward_edge())?
                .link(backend.forward_edge())?
                .link(engine)?
                .link(backend.backward_edge())?
                .link(preprocessor.backward_edge())?
                .link(frontend)?;

            tracing::info!("Model: {service_name} with pre-processing");
            (service_name, pipeline, true)
        }
    };
    main_loop(cancel_token, &service_name, engine, inspect_template).await
}

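/// Chat REPL: read a prompt, stream the model's reply to stdout, and append
/// both to the message history so every turn carries the full conversation.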
#[allow(deprecated)]
async fn main_loop(
    cancel_token: CancellationToken,
    service_name: &str,
    engine: OpenAIChatCompletionsStreamingEngine,
    _inspect_template: bool,
) -> anyhow::Result<()> {
    tracing::info!("Ctrl-c to exit");
    let theme = dialoguer::theme::ColorfulTheme::default();

    let mut initial_prompt = if unsafe { libc::isatty(libc::STDIN_FILENO) == IS_A_TTY } {
        None
    } else {
        // Something piped in, use that as initial prompt
        let mut input = String::new();
        std::io::stdin().read_to_string(&mut input)?;
        Some(input)
    };

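    // Input history lets Up/Down recall earlier prompts; `messages` accumulates
    // the conversation that is resent with every request.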
    let mut history = dialoguer::BasicHistory::default();
    let mut messages = vec![];
    while !cancel_token.is_cancelled() {
        // User input
        let prompt = match initial_prompt.take() {
            Some(p) => p,
            None => {
                let input_ui = dialoguer::Input::<String>::with_theme(&theme)
                    .history_with(&mut history)
                    .with_prompt("User");
                match input_ui.interact_text() {
                    Ok(prompt) => prompt,
                    Err(dialoguer::Error::IO(err)) => {
                        match err.kind() {
                            ErrorKind::Interrupted => {
                                // Ctrl-C
                                // Unfortunately dialoguer could not be made to handle Ctrl-D as well
                            }
                            k => {
                                tracing::info!("IO error: {k}");
                            }
                        }
                        break;
                    }
                }
            }
        };

        // Construct messages
        let user_message = async_openai::types::ChatCompletionRequestMessage::User(
            async_openai::types::ChatCompletionRequestUserMessage {
                content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt),
                name: None,
            },
        );
        messages.push(user_message);

        // Request
        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
            .messages(messages.clone())
            .model(service_name)
            .stream(true)
            .max_tokens(MAX_TOKENS)
            .build()?;

        // TODO We cannot set min_tokens with async-openai
        // if inspect_template {
        //     // This makes the pre-processor ignore stop tokens
        //     req_builder.min_tokens(8192);
        // }

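        // Wrap the OpenAI request in the NVIDIA request type; no `nvext`
        // extensions are attached here.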
        let req = NvCreateChatCompletionRequest { inner, nvext: None };

        // Call the model
        let mut stream = engine.generate(Context::new(req)).await?;

        // Stream the output to stdout
        let mut stdout = std::io::stdout();
        let mut assistant_message = String::new();
        while let Some(item) = stream.next().await {
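            // Assumes every stream item carries data; an error-only annotation
            // would make this unwrap panic.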
            let data = item.data.as_ref().unwrap();
            let entry = data.inner.choices.first();
            let chat_comp = entry.unwrap();
            if let Some(c) = &chat_comp.delta.content {
                let _ = stdout.write_all(c.as_bytes());
                let _ = stdout.flush();
                assistant_message += c;
            }
            if chat_comp.finish_reason.is_some() {
                tracing::trace!("finish reason: {:?}", chat_comp.finish_reason.unwrap());
                break;
            }
        }
        println!();

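        // Record the assistant's reply in the history for the next turn.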
        let assistant_content =
            async_openai::types::ChatCompletionRequestAssistantMessageContent::Text(
                assistant_message,
            );

        // ALLOW: function_call is deprecated
        let assistant_message = async_openai::types::ChatCompletionRequestMessage::Assistant(
            async_openai::types::ChatCompletionRequestAssistantMessage {
                content: Some(assistant_content),
                refusal: None,
                name: None,
                audio: None,
                tool_calls: None,
                function_call: None,
            },
        );
        messages.push(assistant_message);
    }
    println!();
    Ok(())
}