batch.rs 11 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
6
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
7
8
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
9
use anyhow::Context as _;
10
use dynamo_async_openai::types::{ChatCompletionMessageContent, FinishReason};
11
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
12
13
14
15
16
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::Arc;
17
use std::sync::atomic::{AtomicU64, Ordering};
18
19
20
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

21
use crate::entrypoint::EngineConfig;
22
use crate::entrypoint::input::common;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

/// Max tokens in each response.
///
/// Used as the `max_completion_tokens` of a request when no request template
/// is supplied; a template's own `max_completion_tokens` takes precedence.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8192;

/// File name of the results file. `run` replaces the file name of the input
/// path with this one, so the output lands in the same directory as the input.
const OUTPUT_FILENAME: &str = "output.jsonl";

#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
    // The input files only have this
    text: String,

    response: Option<String>,

    #[serde(default)]
    tokens_in: usize,

    #[serde(default)]
    tokens_out: usize,

    #[serde(default)]
    elapsed_ms: usize,
45
46
47
48
49
50

    #[serde(default, skip_serializing_if = "Option::is_none")]
    finish_reason: Option<FinishReason>,

    #[serde(skip, default)]
    request_id: usize,
51
52
53
}

pub async fn run(
54
    distributed_runtime: DistributedRuntime,
55
56
57
    input_jsonl: PathBuf,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
58
    let cancel_token = distributed_runtime.primary_token();
59
60
61
62
63
64
65
66
    // Check if the path exists and is a directory
    if !input_jsonl.exists() || !input_jsonl.is_file() {
        anyhow::bail!(
            "Missing or not a file: {}. Should be a JSON Lines file.",
            input_jsonl.display()
        );
    }

67
    let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?;
68

69
    let pre_processor = if prepared_engine.has_tokenizer() {
70
71
72
        Some(OpenAIPreprocessor::new(
            prepared_engine.card.take().unwrap(),
        )?)
73
74
75
76
77
78
79
80
    } else {
        None
    };
    let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
    let dw_cancel_token = cancel_token.clone();
    let mut output_file = input_jsonl.clone();
    output_file.set_file_name(OUTPUT_FILENAME);
    tokio::spawn(async move {
81
        if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
82
83
84
            tracing::error!(%err, "Failed writing output to {}", output_file.display());
        }
    });
85
    let service_name_ref = Arc::new(prepared_engine.service_name);
86
87
88
89
90
91
92
93
94
95
96
97
98

    let tokens_in = Arc::new(AtomicU64::new(0));
    let tokens_out = Arc::new(AtomicU64::new(0));
    let mut handles = vec![];
    let mut num_entries = 0;
    let input_file = tokio::fs::File::open(&input_jsonl)
        .await
        .with_context(|| input_jsonl.display().to_string())?;
    let buffered_input = tokio::io::BufReader::new(input_file);

    tracing::info!("Timer start.");
    let start = Instant::now();
    let mut lines = buffered_input.lines();
99
    let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    while let Ok(Some(line)) = lines.next_line().await {
        if cancel_token.is_cancelled() {
            break;
        }
        if line.is_empty() {
            continue;
        }
        let request_id = num_entries;
        num_entries += 1;
        let mut entry: Entry = match serde_json::from_str(&line) {
            Ok(entry) => entry,
            Err(err) => {
                anyhow::bail!("Error parsing entry: '{line}'. {err}");
            }
        };
115
        entry.request_id = request_id;
116

117
        let engine = prepared_engine.engine.clone();
118
119
120
121
        let pre_processor = pre_processor.clone();
        let tokens_in = tokens_in.clone();
        let tokens_out = tokens_out.clone();
        let done_entries_tx = done_entries_tx.clone();
122
        let service_name_ref = service_name_ref.clone();
123
        let template_clone = template.clone();
124
125
        let handle = tokio::spawn(async move {
            let local_start = Instant::now();
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
            let response = match evaluate(
                request_id,
                service_name_ref.as_str(),
                engine,
                &mut entry,
                template_clone,
            )
            .await
            {
                Ok(r) => r,
                Err(err) => {
                    tracing::error!(%err, entry.text, "Failed evaluating prompt");
                    return;
                }
            };
141
142
143
144
145
146
            let local_elapsed = Instant::now() - local_start;
            entry.elapsed_ms = local_elapsed.as_millis() as usize;

            if let Some(pre) = pre_processor {
                // Note this does not include the prompt template. Probably TODO
                entry.tokens_in = match pre.tokenize(&entry.text) {
147
                    Ok(encoding) => encoding.token_ids().len(),
148
149
150
151
152
153
                    Err(err) => {
                        tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
                        0
                    }
                };
                entry.tokens_out = match pre.tokenize(&response) {
154
                    Ok(encoding) => encoding.token_ids().len(),
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
                    Err(err) => {
                        tracing::warn!(%err, response, "Failed tokenizing response");
                        0
                    }
                };
                tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
                tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
            }
            entry.response = Some(response);

            let _ = done_entries_tx.send(entry).await;
        });
        handles.push(handle);
    }
    tokio::select! {
        _ = cancel_token.cancelled() => {
            // Don't print stats
            return Ok(());
        }
        _ = futures::future::join_all(handles) => {
        }
    }
    let elapsed = Instant::now() - start;
    let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
    let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
    let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
    tracing::info!(
        "Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
        num_entries,
        humantime::format_duration(elapsed_clean),
        tokens_in,
        tokens_in / cmp::max(elapsed.as_secs(), 1),
        tokens_out,
        tokens_out / cmp::max(elapsed.as_secs(), 1),
    );
191
    cancel_token.cancel(); // stop everything else
192
193
194
195
196
197

    Ok(())
}

// Run a single prompt through the engine
async fn evaluate(
198
199
    request_id: usize,
    service_name: &str,
200
    engine: OpenAIChatCompletionsStreamingEngine,
201
    entry: &mut Entry,
202
    template: Option<Arc<RequestTemplate>>,
203
) -> anyhow::Result<String> {
204
205
206
    let user_message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
        dynamo_async_openai::types::ChatCompletionRequestUserMessage {
            content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
207
                entry.text.clone(),
208
209
210
211
            ),
            name: None,
        },
    );
212
    let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
213
        .messages(vec![user_message])
214
215
216
217
218
        .model(
            template
                .as_ref()
                .map_or_else(|| service_name.to_string(), |t| t.model.clone()),
        )
219
        .stream(true)
220
221
222
223
224
225
        .max_completion_tokens(
            template
                .as_ref()
                .map_or(MAX_TOKENS, |t| t.max_completion_tokens),
        )
        .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
226
        .build()?;
227
228
229
230
    let req = NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
231
        chat_template_args: None,
232
        media_io_kwargs: None,
233
        unsupported_fields: Default::default(),
234
    };
235
236
237
238
239
240
    let mut stream = engine.generate(Context::new(req)).await?;
    let mut output = String::new();
    while let Some(item) = stream.next().await {
        match (item.data.as_ref(), item.event.as_deref()) {
            (Some(data), _) => {
                // Normal case
241
                let choice = data.choices.first();
242
                let chat_comp = choice.as_ref().unwrap();
243
                if let Some(c) = &chat_comp.delta.content {
244
245
246
247
248
249
250
251
252
                    match c {
                        ChatCompletionMessageContent::Text(text) => {
                            output += text;
                        }
                        ChatCompletionMessageContent::Parts(_) => {
                            // Multimodal content - skip for now in batch processing
                            // (ayushag) TODO: Handle multimodal content in batch mode
                        }
                    }
253
                }
254
                entry.finish_reason = chat_comp.finish_reason;
255
                if chat_comp.finish_reason.is_some() {
256
257
258
259
260
                    tracing::trace!(
                        request_id,
                        "finish reason: {:?}",
                        chat_comp.finish_reason.unwrap()
                    );
261
262
263
264
                    break;
                }
            }
            (None, Some("error")) => {
265
                tracing::error!(request_id, "the error case");
266
267
                // There's only one error but we loop in case that changes
                for err in item.comment.unwrap_or_default() {
268
                    tracing::error!(request_id, "Engine error: {err}");
269
270
271
                }
            }
            (None, Some(annotation)) => {
272
                tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
            }
            _ => {
                unreachable!("Event from engine with no data, no error, no annotation.");
            }
        }
    }
    Ok(output)
}

async fn output_writer(
    cancel_token: CancellationToken,
    mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
    output_file: &Path,
) -> anyhow::Result<()> {
    let mut num_completed = 0;
    let mut f = tokio::fs::File::create(output_file).await?;
    loop {
290
        let entry = tokio::select! {
291
292
293
            _ = cancel_token.cancelled() => {
                break;
            }
294
295
296
297
298
            maybe_entry = entries_rx.recv() => {
                match maybe_entry {
                    Some(entry) => entry,
                    None => {break;}
                }
299
300
301
302
303
304
305
306
307
            }
        };
        let mut s = serde_json::to_string(&entry)?;
        s.push('\n');
        f.write_all(s.as_bytes()).await?;

        num_completed += 1;
        // TODO: Progress bar. We'd have to count the lines in the input first,
        // and the input maybe be large
308
        tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
309
310
311
    }
    Ok(())
}