batch.rs 10.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
7
8
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
9
use anyhow::Context as _;
10
use dynamo_async_openai::types::FinishReason;
11
12
13
14
15
16
17
18
19
20
use dynamo_runtime::{pipeline::Context, runtime::CancellationToken, Runtime};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

21
22
use crate::entrypoint::input::common;
use crate::entrypoint::EngineConfig;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

/// Max tokens in each response.
/// Used as the `max_completion_tokens` of a request when no request template is supplied.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8192;

/// File name of the results file. `run` creates it in the same directory as the input
/// file by replacing the input path's file name with this one.
const OUTPUT_FILENAME: &str = "output.jsonl";

#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
    // The input files only have this
    text: String,

    response: Option<String>,

    #[serde(default)]
    tokens_in: usize,

    #[serde(default)]
    tokens_out: usize,

    #[serde(default)]
    elapsed_ms: usize,
45
46
47
48
49
50

    #[serde(default, skip_serializing_if = "Option::is_none")]
    finish_reason: Option<FinishReason>,

    #[serde(skip, default)]
    request_id: usize,
51
52
53
54
55
56
57
}

pub async fn run(
    runtime: Runtime,
    input_jsonl: PathBuf,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
58
    let cancel_token = runtime.primary_token();
59
60
61
62
63
64
65
66
    // Check if the path exists and is a directory
    if !input_jsonl.exists() || !input_jsonl.is_file() {
        anyhow::bail!(
            "Missing or not a file: {}. Should be a JSON Lines file.",
            input_jsonl.display()
        );
    }

67
    let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
68

69
70
    let pre_processor = if prepared_engine.has_tokenizer() {
        Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?)
71
72
73
74
75
76
77
78
    } else {
        None
    };
    let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
    let dw_cancel_token = cancel_token.clone();
    let mut output_file = input_jsonl.clone();
    output_file.set_file_name(OUTPUT_FILENAME);
    tokio::spawn(async move {
79
        if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
80
81
82
            tracing::error!(%err, "Failed writing output to {}", output_file.display());
        }
    });
83
    let service_name_ref = Arc::new(prepared_engine.service_name);
84
85
86
87
88
89
90
91
92
93
94
95
96

    let tokens_in = Arc::new(AtomicU64::new(0));
    let tokens_out = Arc::new(AtomicU64::new(0));
    let mut handles = vec![];
    let mut num_entries = 0;
    let input_file = tokio::fs::File::open(&input_jsonl)
        .await
        .with_context(|| input_jsonl.display().to_string())?;
    let buffered_input = tokio::io::BufReader::new(input_file);

    tracing::info!("Timer start.");
    let start = Instant::now();
    let mut lines = buffered_input.lines();
97
    let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    while let Ok(Some(line)) = lines.next_line().await {
        if cancel_token.is_cancelled() {
            break;
        }
        if line.is_empty() {
            continue;
        }
        let request_id = num_entries;
        num_entries += 1;
        let mut entry: Entry = match serde_json::from_str(&line) {
            Ok(entry) => entry,
            Err(err) => {
                anyhow::bail!("Error parsing entry: '{line}'. {err}");
            }
        };
113
        entry.request_id = request_id;
114

115
        let engine = prepared_engine.engine.clone();
116
117
118
119
        let pre_processor = pre_processor.clone();
        let tokens_in = tokens_in.clone();
        let tokens_out = tokens_out.clone();
        let done_entries_tx = done_entries_tx.clone();
120
        let service_name_ref = service_name_ref.clone();
121
        let template_clone = template.clone();
122
123
        let handle = tokio::spawn(async move {
            let local_start = Instant::now();
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
            let response = match evaluate(
                request_id,
                service_name_ref.as_str(),
                engine,
                &mut entry,
                template_clone,
            )
            .await
            {
                Ok(r) => r,
                Err(err) => {
                    tracing::error!(%err, entry.text, "Failed evaluating prompt");
                    return;
                }
            };
139
140
141
142
143
144
            let local_elapsed = Instant::now() - local_start;
            entry.elapsed_ms = local_elapsed.as_millis() as usize;

            if let Some(pre) = pre_processor {
                // Note this does not include the prompt template. Probably TODO
                entry.tokens_in = match pre.tokenize(&entry.text) {
145
                    Ok(encoding) => encoding.token_ids().len(),
146
147
148
149
150
151
                    Err(err) => {
                        tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
                        0
                    }
                };
                entry.tokens_out = match pre.tokenize(&response) {
152
                    Ok(encoding) => encoding.token_ids().len(),
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
                    Err(err) => {
                        tracing::warn!(%err, response, "Failed tokenizing response");
                        0
                    }
                };
                tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
                tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
            }
            entry.response = Some(response);

            let _ = done_entries_tx.send(entry).await;
        });
        handles.push(handle);
    }
    tokio::select! {
        _ = cancel_token.cancelled() => {
            // Don't print stats
            return Ok(());
        }
        _ = futures::future::join_all(handles) => {
        }
    }
    let elapsed = Instant::now() - start;
    let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
    let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
    let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
    tracing::info!(
        "Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
        num_entries,
        humantime::format_duration(elapsed_clean),
        tokens_in,
        tokens_in / cmp::max(elapsed.as_secs(), 1),
        tokens_out,
        tokens_out / cmp::max(elapsed.as_secs(), 1),
    );
189
    cancel_token.cancel(); // stop everything else
190
191
192
193
194
195

    Ok(())
}

// Run a single prompt through the engine
async fn evaluate(
196
197
    request_id: usize,
    service_name: &str,
198
    engine: OpenAIChatCompletionsStreamingEngine,
199
    entry: &mut Entry,
200
    template: Option<Arc<RequestTemplate>>,
201
) -> anyhow::Result<String> {
202
203
204
    let user_message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
        dynamo_async_openai::types::ChatCompletionRequestUserMessage {
            content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
205
                entry.text.clone(),
206
207
208
209
            ),
            name: None,
        },
    );
210
    let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
211
        .messages(vec![user_message])
212
213
214
215
216
        .model(
            template
                .as_ref()
                .map_or_else(|| service_name.to_string(), |t| t.model.clone()),
        )
217
        .stream(true)
218
219
220
221
222
223
        .max_completion_tokens(
            template
                .as_ref()
                .map_or(MAX_TOKENS, |t| t.max_completion_tokens),
        )
        .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
224
        .build()?;
225
226
227
228
229
    let req = NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
    };
230
231
232
233
234
235
    let mut stream = engine.generate(Context::new(req)).await?;
    let mut output = String::new();
    while let Some(item) = stream.next().await {
        match (item.data.as_ref(), item.event.as_deref()) {
            (Some(data), _) => {
                // Normal case
236
                let choice = data.choices.first();
237
                let chat_comp = choice.as_ref().unwrap();
238
239
240
                if let Some(c) = &chat_comp.delta.content {
                    output += c;
                }
241
                entry.finish_reason = chat_comp.finish_reason;
242
                if chat_comp.finish_reason.is_some() {
243
244
245
246
247
                    tracing::trace!(
                        request_id,
                        "finish reason: {:?}",
                        chat_comp.finish_reason.unwrap()
                    );
248
249
250
251
                    break;
                }
            }
            (None, Some("error")) => {
252
                tracing::error!(request_id, "the error case");
253
254
                // There's only one error but we loop in case that changes
                for err in item.comment.unwrap_or_default() {
255
                    tracing::error!(request_id, "Engine error: {err}");
256
257
258
                }
            }
            (None, Some(annotation)) => {
259
                tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
            }
            _ => {
                unreachable!("Event from engine with no data, no error, no annotation.");
            }
        }
    }
    Ok(output)
}

async fn output_writer(
    cancel_token: CancellationToken,
    mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
    output_file: &Path,
) -> anyhow::Result<()> {
    let mut num_completed = 0;
    let mut f = tokio::fs::File::create(output_file).await?;
    loop {
277
        let entry = tokio::select! {
278
279
280
            _ = cancel_token.cancelled() => {
                break;
            }
281
282
283
284
285
            maybe_entry = entries_rx.recv() => {
                match maybe_entry {
                    Some(entry) => entry,
                    None => {break;}
                }
286
287
288
289
290
291
292
293
294
            }
        };
        let mut s = serde_json::to_string(&entry)?;
        s.push('\n');
        f.write_all(s.as_bytes()).await?;

        num_completed += 1;
        // TODO: Progress bar. We'd have to count the lines in the input first,
        // and the input maybe be large
295
        tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
296
297
298
    }
    Ok(())
}