batch.rs 10.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
7
8
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
9
10
use anyhow::Context as _;
use async_openai::types::FinishReason;
11
12
13
14
15
16
17
18
19
20
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

21
22
use crate::entrypoint::input::common;
use crate::entrypoint::EngineConfig;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

/// Max tokens in each response.
/// Used as the completion-token limit when no request template supplies one.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8192;

/// File name of the results file; it is created in the same directory as the input file.
const OUTPUT_FILENAME: &str = "output.jsonl";

#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
    // The input files only have this
    text: String,

    response: Option<String>,

    #[serde(default)]
    tokens_in: usize,

    #[serde(default)]
    tokens_out: usize,

    #[serde(default)]
    elapsed_ms: usize,
45
46
47
48
49
50

    #[serde(default, skip_serializing_if = "Option::is_none")]
    finish_reason: Option<FinishReason>,

    #[serde(skip, default)]
    request_id: usize,
51
52
53
54
55
56
57
}

pub async fn run(
    runtime: Runtime,
    input_jsonl: PathBuf,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
58
    let cancel_token = runtime.primary_token();
59
60
61
62
63
64
65
66
    // Check if the path exists and is a directory
    if !input_jsonl.exists() || !input_jsonl.is_file() {
        anyhow::bail!(
            "Missing or not a file: {}. Should be a JSON Lines file.",
            input_jsonl.display()
        );
    }

67
    let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
68

69
70
    let pre_processor = if prepared_engine.has_tokenizer() {
        Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?)
71
72
73
74
75
76
77
78
    } else {
        None
    };
    let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
    let dw_cancel_token = cancel_token.clone();
    let mut output_file = input_jsonl.clone();
    output_file.set_file_name(OUTPUT_FILENAME);
    tokio::spawn(async move {
79
        if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
80
81
82
            tracing::error!(%err, "Failed writing output to {}", output_file.display());
        }
    });
83
    let service_name_ref = Arc::new(prepared_engine.service_name);
84
85
86
87
88
89
90
91
92
93
94
95
96

    let tokens_in = Arc::new(AtomicU64::new(0));
    let tokens_out = Arc::new(AtomicU64::new(0));
    let mut handles = vec![];
    let mut num_entries = 0;
    let input_file = tokio::fs::File::open(&input_jsonl)
        .await
        .with_context(|| input_jsonl.display().to_string())?;
    let buffered_input = tokio::io::BufReader::new(input_file);

    tracing::info!("Timer start.");
    let start = Instant::now();
    let mut lines = buffered_input.lines();
97
    let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    while let Ok(Some(line)) = lines.next_line().await {
        if cancel_token.is_cancelled() {
            break;
        }
        if line.is_empty() {
            continue;
        }
        let request_id = num_entries;
        num_entries += 1;
        let mut entry: Entry = match serde_json::from_str(&line) {
            Ok(entry) => entry,
            Err(err) => {
                anyhow::bail!("Error parsing entry: '{line}'. {err}");
            }
        };
113
        entry.request_id = request_id;
114

115
        let engine = prepared_engine.engine.clone();
116
117
118
119
        let pre_processor = pre_processor.clone();
        let tokens_in = tokens_in.clone();
        let tokens_out = tokens_out.clone();
        let done_entries_tx = done_entries_tx.clone();
120
        let service_name_ref = service_name_ref.clone();
121
        let template_clone = template.clone();
122
123
        let handle = tokio::spawn(async move {
            let local_start = Instant::now();
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
            let response = match evaluate(
                request_id,
                service_name_ref.as_str(),
                engine,
                &mut entry,
                template_clone,
            )
            .await
            {
                Ok(r) => r,
                Err(err) => {
                    tracing::error!(%err, entry.text, "Failed evaluating prompt");
                    return;
                }
            };
139
140
141
142
143
144
            let local_elapsed = Instant::now() - local_start;
            entry.elapsed_ms = local_elapsed.as_millis() as usize;

            if let Some(pre) = pre_processor {
                // Note this does not include the prompt template. Probably TODO
                entry.tokens_in = match pre.tokenize(&entry.text) {
145
                    Ok(encoding) => encoding.token_ids().len(),
146
147
148
149
150
151
                    Err(err) => {
                        tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
                        0
                    }
                };
                entry.tokens_out = match pre.tokenize(&response) {
152
                    Ok(encoding) => encoding.token_ids().len(),
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
                    Err(err) => {
                        tracing::warn!(%err, response, "Failed tokenizing response");
                        0
                    }
                };
                tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
                tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
            }
            entry.response = Some(response);

            let _ = done_entries_tx.send(entry).await;
        });
        handles.push(handle);
    }
    tokio::select! {
        _ = cancel_token.cancelled() => {
            // Don't print stats
            return Ok(());
        }
        _ = futures::future::join_all(handles) => {
        }
    }
    let elapsed = Instant::now() - start;
    let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
    let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
    let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
    tracing::info!(
        "Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
        num_entries,
        humantime::format_duration(elapsed_clean),
        tokens_in,
        tokens_in / cmp::max(elapsed.as_secs(), 1),
        tokens_out,
        tokens_out / cmp::max(elapsed.as_secs(), 1),
    );
189
    cancel_token.cancel(); // stop everything else
190
191
192
193
194
195

    Ok(())
}

// Run a single prompt through the engine
async fn evaluate(
196
197
    request_id: usize,
    service_name: &str,
198
    engine: OpenAIChatCompletionsStreamingEngine,
199
    entry: &mut Entry,
200
    template: Option<Arc<RequestTemplate>>,
201
202
203
204
) -> anyhow::Result<String> {
    let user_message = async_openai::types::ChatCompletionRequestMessage::User(
        async_openai::types::ChatCompletionRequestUserMessage {
            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
205
                entry.text.clone(),
206
207
208
209
210
211
            ),
            name: None,
        },
    );
    let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
        .messages(vec![user_message])
212
213
214
215
216
        .model(
            template
                .as_ref()
                .map_or_else(|| service_name.to_string(), |t| t.model.clone()),
        )
217
        .stream(true)
218
219
220
221
222
223
        .max_completion_tokens(
            template
                .as_ref()
                .map_or(MAX_TOKENS, |t| t.max_completion_tokens),
        )
        .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
224
225
226
227
228
229
230
231
        .build()?;
    let req = NvCreateChatCompletionRequest { inner, nvext: None };
    let mut stream = engine.generate(Context::new(req)).await?;
    let mut output = String::new();
    while let Some(item) = stream.next().await {
        match (item.data.as_ref(), item.event.as_deref()) {
            (Some(data), _) => {
                // Normal case
232
233
                let choice = data.inner.choices.first();
                let chat_comp = choice.as_ref().unwrap();
234
235
236
                if let Some(c) = &chat_comp.delta.content {
                    output += c;
                }
237
                entry.finish_reason = chat_comp.finish_reason;
238
                if chat_comp.finish_reason.is_some() {
239
240
241
242
243
                    tracing::trace!(
                        request_id,
                        "finish reason: {:?}",
                        chat_comp.finish_reason.unwrap()
                    );
244
245
246
247
                    break;
                }
            }
            (None, Some("error")) => {
248
                tracing::error!(request_id, "the error case");
249
250
                // There's only one error but we loop in case that changes
                for err in item.comment.unwrap_or_default() {
251
                    tracing::error!(request_id, "Engine error: {err}");
252
253
254
                }
            }
            (None, Some(annotation)) => {
255
                tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
            }
            _ => {
                unreachable!("Event from engine with no data, no error, no annotation.");
            }
        }
    }
    Ok(output)
}

async fn output_writer(
    cancel_token: CancellationToken,
    mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
    output_file: &Path,
) -> anyhow::Result<()> {
    let mut num_completed = 0;
    let mut f = tokio::fs::File::create(output_file).await?;
    loop {
273
        let entry = tokio::select! {
274
275
276
            _ = cancel_token.cancelled() => {
                break;
            }
277
278
279
280
281
            maybe_entry = entries_rx.recv() => {
                match maybe_entry {
                    Some(entry) => entry,
                    None => {break;}
                }
282
283
284
285
286
287
288
289
290
            }
        };
        let mut s = serde_json::to_string(&entry)?;
        s.push('\n');
        f.write_all(s.as_bytes()).await?;

        num_completed += 1;
        // TODO: Progress bar. We'd have to count the lines in the input first,
        // and the input maybe be large
291
        tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
292
293
294
    }
    Ok(())
}