// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
use anyhow::Context as _;
use dynamo_async_openai::types::FinishReason;
use dynamo_runtime::{Runtime, pipeline::Context, runtime::CancellationToken};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

use crate::entrypoint::EngineConfig;
use crate::entrypoint::input::common;

/// Upper bound on generated tokens per response, used when no request template
/// supplies its own `max_completion_tokens`.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8_192;

/// File name of the results file; it is placed next to the input JSON Lines file.
const OUTPUT_FILENAME: &str = "output.jsonl";

#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
    // The input files only have this
    text: String,

    response: Option<String>,

    #[serde(default)]
    tokens_in: usize,

    #[serde(default)]
    tokens_out: usize,

    #[serde(default)]
    elapsed_ms: usize,
45
46
47
48
49
50

    #[serde(default, skip_serializing_if = "Option::is_none")]
    finish_reason: Option<FinishReason>,

    #[serde(skip, default)]
    request_id: usize,
51
52
53
54
55
56
57
}

pub async fn run(
    runtime: Runtime,
    input_jsonl: PathBuf,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
58
    let cancel_token = runtime.primary_token();
59
60
61
62
63
64
65
66
    // Check if the path exists and is a directory
    if !input_jsonl.exists() || !input_jsonl.is_file() {
        anyhow::bail!(
            "Missing or not a file: {}. Should be a JSON Lines file.",
            input_jsonl.display()
        );
    }

67
    let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
68

69
    let pre_processor = if prepared_engine.has_tokenizer() {
70
71
72
        Some(OpenAIPreprocessor::new(
            prepared_engine.card.take().unwrap(),
        )?)
73
74
75
76
77
78
79
80
    } else {
        None
    };
    let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
    let dw_cancel_token = cancel_token.clone();
    let mut output_file = input_jsonl.clone();
    output_file.set_file_name(OUTPUT_FILENAME);
    tokio::spawn(async move {
81
        if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
82
83
84
            tracing::error!(%err, "Failed writing output to {}", output_file.display());
        }
    });
85
    let service_name_ref = Arc::new(prepared_engine.service_name);
86
87
88
89
90
91
92
93
94
95
96
97
98

    let tokens_in = Arc::new(AtomicU64::new(0));
    let tokens_out = Arc::new(AtomicU64::new(0));
    let mut handles = vec![];
    let mut num_entries = 0;
    let input_file = tokio::fs::File::open(&input_jsonl)
        .await
        .with_context(|| input_jsonl.display().to_string())?;
    let buffered_input = tokio::io::BufReader::new(input_file);

    tracing::info!("Timer start.");
    let start = Instant::now();
    let mut lines = buffered_input.lines();
99
    let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    while let Ok(Some(line)) = lines.next_line().await {
        if cancel_token.is_cancelled() {
            break;
        }
        if line.is_empty() {
            continue;
        }
        let request_id = num_entries;
        num_entries += 1;
        let mut entry: Entry = match serde_json::from_str(&line) {
            Ok(entry) => entry,
            Err(err) => {
                anyhow::bail!("Error parsing entry: '{line}'. {err}");
            }
        };
115
        entry.request_id = request_id;
116

117
        let engine = prepared_engine.engine.clone();
118
119
120
121
        let pre_processor = pre_processor.clone();
        let tokens_in = tokens_in.clone();
        let tokens_out = tokens_out.clone();
        let done_entries_tx = done_entries_tx.clone();
122
        let service_name_ref = service_name_ref.clone();
123
        let template_clone = template.clone();
124
125
        let handle = tokio::spawn(async move {
            let local_start = Instant::now();
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
            let response = match evaluate(
                request_id,
                service_name_ref.as_str(),
                engine,
                &mut entry,
                template_clone,
            )
            .await
            {
                Ok(r) => r,
                Err(err) => {
                    tracing::error!(%err, entry.text, "Failed evaluating prompt");
                    return;
                }
            };
141
142
143
144
145
146
            let local_elapsed = Instant::now() - local_start;
            entry.elapsed_ms = local_elapsed.as_millis() as usize;

            if let Some(pre) = pre_processor {
                // Note this does not include the prompt template. Probably TODO
                entry.tokens_in = match pre.tokenize(&entry.text) {
147
                    Ok(encoding) => encoding.token_ids().len(),
148
149
150
151
152
153
                    Err(err) => {
                        tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
                        0
                    }
                };
                entry.tokens_out = match pre.tokenize(&response) {
154
                    Ok(encoding) => encoding.token_ids().len(),
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
                    Err(err) => {
                        tracing::warn!(%err, response, "Failed tokenizing response");
                        0
                    }
                };
                tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
                tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
            }
            entry.response = Some(response);

            let _ = done_entries_tx.send(entry).await;
        });
        handles.push(handle);
    }
    tokio::select! {
        _ = cancel_token.cancelled() => {
            // Don't print stats
            return Ok(());
        }
        _ = futures::future::join_all(handles) => {
        }
    }
    let elapsed = Instant::now() - start;
    let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
    let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
    let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
    tracing::info!(
        "Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
        num_entries,
        humantime::format_duration(elapsed_clean),
        tokens_in,
        tokens_in / cmp::max(elapsed.as_secs(), 1),
        tokens_out,
        tokens_out / cmp::max(elapsed.as_secs(), 1),
    );
191
    cancel_token.cancel(); // stop everything else
192
193
194
195
196
197

    Ok(())
}

// Run a single prompt through the engine
async fn evaluate(
198
199
    request_id: usize,
    service_name: &str,
200
    engine: OpenAIChatCompletionsStreamingEngine,
201
    entry: &mut Entry,
202
    template: Option<Arc<RequestTemplate>>,
203
) -> anyhow::Result<String> {
204
205
206
    let user_message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
        dynamo_async_openai::types::ChatCompletionRequestUserMessage {
            content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
207
                entry.text.clone(),
208
209
210
211
            ),
            name: None,
        },
    );
212
    let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
213
        .messages(vec![user_message])
214
215
216
217
218
        .model(
            template
                .as_ref()
                .map_or_else(|| service_name.to_string(), |t| t.model.clone()),
        )
219
        .stream(true)
220
221
222
223
224
225
        .max_completion_tokens(
            template
                .as_ref()
                .map_or(MAX_TOKENS, |t| t.max_completion_tokens),
        )
        .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
226
        .build()?;
227
228
229
230
    let req = NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
231
        chat_template_args: None,
232
        unsupported_fields: Default::default(),
233
    };
234
235
236
237
238
239
    let mut stream = engine.generate(Context::new(req)).await?;
    let mut output = String::new();
    while let Some(item) = stream.next().await {
        match (item.data.as_ref(), item.event.as_deref()) {
            (Some(data), _) => {
                // Normal case
240
                let choice = data.choices.first();
241
                let chat_comp = choice.as_ref().unwrap();
242
243
244
                if let Some(c) = &chat_comp.delta.content {
                    output += c;
                }
245
                entry.finish_reason = chat_comp.finish_reason;
246
                if chat_comp.finish_reason.is_some() {
247
248
249
250
251
                    tracing::trace!(
                        request_id,
                        "finish reason: {:?}",
                        chat_comp.finish_reason.unwrap()
                    );
252
253
254
255
                    break;
                }
            }
            (None, Some("error")) => {
256
                tracing::error!(request_id, "the error case");
257
258
                // There's only one error but we loop in case that changes
                for err in item.comment.unwrap_or_default() {
259
                    tracing::error!(request_id, "Engine error: {err}");
260
261
262
                }
            }
            (None, Some(annotation)) => {
263
                tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
            }
            _ => {
                unreachable!("Event from engine with no data, no error, no annotation.");
            }
        }
    }
    Ok(output)
}

async fn output_writer(
    cancel_token: CancellationToken,
    mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
    output_file: &Path,
) -> anyhow::Result<()> {
    let mut num_completed = 0;
    let mut f = tokio::fs::File::create(output_file).await?;
    loop {
281
        let entry = tokio::select! {
282
283
284
            _ = cancel_token.cancelled() => {
                break;
            }
285
286
287
288
289
            maybe_entry = entries_rx.recv() => {
                match maybe_entry {
                    Some(entry) => entry,
                    None => {break;}
                }
290
291
292
293
294
295
296
297
298
            }
        };
        let mut s = serde_json::to_string(&entry)?;
        s.push('\n');
        f.write_all(s.as_bytes()).await?;

        num_completed += 1;
        // TODO: Progress bar. We'd have to count the lines in the input first,
        // and the input maybe be large
299
        tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
300
301
302
    }
    Ok(())
}