batch.rs 10.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
6
use crate::preprocessor::OpenAIPreprocessor;
use crate::request_template::RequestTemplate;
use crate::types::openai::chat_completions::{
7
8
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};
9
use anyhow::Context as _;
10
use dynamo_async_openai::types::FinishReason;
11
use dynamo_runtime::{DistributedRuntime, pipeline::Context, runtime::CancellationToken};
12
13
14
15
16
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::cmp;
use std::path::{Path, PathBuf};
use std::sync::Arc;
17
use std::sync::atomic::{AtomicU64, Ordering};
18
19
20
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt};

21
use crate::entrypoint::EngineConfig;
22
use crate::entrypoint::input::common;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

/// Upper bound on completion tokens requested per prompt, used when no request
/// template supplies its own `max_completion_tokens`.
/// TODO: For batch mode this should be the full context size of the model
const MAX_TOKENS: u32 = 8_192;

/// File name of the results file; it is placed in the same directory as the input file.
const OUTPUT_FILENAME: &str = "output.jsonl";

#[derive(Serialize, Deserialize, Default, Debug)]
struct Entry {
    // The input files only have this
    text: String,

    response: Option<String>,

    #[serde(default)]
    tokens_in: usize,

    #[serde(default)]
    tokens_out: usize,

    #[serde(default)]
    elapsed_ms: usize,
45
46
47
48
49
50

    #[serde(default, skip_serializing_if = "Option::is_none")]
    finish_reason: Option<FinishReason>,

    #[serde(skip, default)]
    request_id: usize,
51
52
53
}

pub async fn run(
54
    distributed_runtime: DistributedRuntime,
55
56
57
    input_jsonl: PathBuf,
    engine_config: EngineConfig,
) -> anyhow::Result<()> {
58
    let cancel_token = distributed_runtime.primary_token();
59
60
61
62
63
64
65
66
    // Check if the path exists and is a directory
    if !input_jsonl.exists() || !input_jsonl.is_file() {
        anyhow::bail!(
            "Missing or not a file: {}. Should be a JSON Lines file.",
            input_jsonl.display()
        );
    }

67
    let mut prepared_engine = common::prepare_engine(distributed_runtime, engine_config).await?;
68

69
    let pre_processor = if prepared_engine.has_tokenizer() {
70
71
72
        Some(OpenAIPreprocessor::new(
            prepared_engine.card.take().unwrap(),
        )?)
73
74
75
76
77
78
79
80
    } else {
        None
    };
    let (done_entries_tx, done_entries_rx) = tokio::sync::mpsc::channel(64);
    let dw_cancel_token = cancel_token.clone();
    let mut output_file = input_jsonl.clone();
    output_file.set_file_name(OUTPUT_FILENAME);
    tokio::spawn(async move {
81
        if let Err(err) = output_writer(dw_cancel_token, done_entries_rx, &output_file).await {
82
83
84
            tracing::error!(%err, "Failed writing output to {}", output_file.display());
        }
    });
85
    let service_name_ref = Arc::new(prepared_engine.service_name);
86
87
88
89
90
91
92
93
94
95
96
97
98

    let tokens_in = Arc::new(AtomicU64::new(0));
    let tokens_out = Arc::new(AtomicU64::new(0));
    let mut handles = vec![];
    let mut num_entries = 0;
    let input_file = tokio::fs::File::open(&input_jsonl)
        .await
        .with_context(|| input_jsonl.display().to_string())?;
    let buffered_input = tokio::io::BufReader::new(input_file);

    tracing::info!("Timer start.");
    let start = Instant::now();
    let mut lines = buffered_input.lines();
99
    let template: Option<Arc<RequestTemplate>> = prepared_engine.request_template.map(Arc::new);
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    while let Ok(Some(line)) = lines.next_line().await {
        if cancel_token.is_cancelled() {
            break;
        }
        if line.is_empty() {
            continue;
        }
        let request_id = num_entries;
        num_entries += 1;
        let mut entry: Entry = match serde_json::from_str(&line) {
            Ok(entry) => entry,
            Err(err) => {
                anyhow::bail!("Error parsing entry: '{line}'. {err}");
            }
        };
115
        entry.request_id = request_id;
116

117
        let engine = prepared_engine.engine.clone();
118
119
120
121
        let pre_processor = pre_processor.clone();
        let tokens_in = tokens_in.clone();
        let tokens_out = tokens_out.clone();
        let done_entries_tx = done_entries_tx.clone();
122
        let service_name_ref = service_name_ref.clone();
123
        let template_clone = template.clone();
124
125
        let handle = tokio::spawn(async move {
            let local_start = Instant::now();
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
            let response = match evaluate(
                request_id,
                service_name_ref.as_str(),
                engine,
                &mut entry,
                template_clone,
            )
            .await
            {
                Ok(r) => r,
                Err(err) => {
                    tracing::error!(%err, entry.text, "Failed evaluating prompt");
                    return;
                }
            };
141
142
143
144
145
146
            let local_elapsed = Instant::now() - local_start;
            entry.elapsed_ms = local_elapsed.as_millis() as usize;

            if let Some(pre) = pre_processor {
                // Note this does not include the prompt template. Probably TODO
                entry.tokens_in = match pre.tokenize(&entry.text) {
147
                    Ok(encoding) => encoding.token_ids().len(),
148
149
150
151
152
153
                    Err(err) => {
                        tracing::warn!(%err, entry.text, "Failed tokenizing prompt");
                        0
                    }
                };
                entry.tokens_out = match pre.tokenize(&response) {
154
                    Ok(encoding) => encoding.token_ids().len(),
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
                    Err(err) => {
                        tracing::warn!(%err, response, "Failed tokenizing response");
                        0
                    }
                };
                tokens_in.fetch_add(entry.tokens_in as u64, Ordering::Relaxed);
                tokens_out.fetch_add(entry.tokens_out as u64, Ordering::Relaxed);
            }
            entry.response = Some(response);

            let _ = done_entries_tx.send(entry).await;
        });
        handles.push(handle);
    }
    tokio::select! {
        _ = cancel_token.cancelled() => {
            // Don't print stats
            return Ok(());
        }
        _ = futures::future::join_all(handles) => {
        }
    }
    let elapsed = Instant::now() - start;
    let elapsed_clean = Duration::from_millis(elapsed.as_millis() as u64);
    let tokens_in = Arc::into_inner(tokens_in).unwrap().into_inner();
    let tokens_out = Arc::into_inner(tokens_out).unwrap().into_inner();
    tokio::time::sleep(Duration::from_millis(1)).await; // Let output_writer finish stdout write
    tracing::info!(
        "Ran {} files in {}. Tokens in: {} ({}/s). Tokens out: {} ({}/s)",
        num_entries,
        humantime::format_duration(elapsed_clean),
        tokens_in,
        tokens_in / cmp::max(elapsed.as_secs(), 1),
        tokens_out,
        tokens_out / cmp::max(elapsed.as_secs(), 1),
    );
191
    cancel_token.cancel(); // stop everything else
192
193
194
195
196
197

    Ok(())
}

// Run a single prompt through the engine
async fn evaluate(
198
199
    request_id: usize,
    service_name: &str,
200
    engine: OpenAIChatCompletionsStreamingEngine,
201
    entry: &mut Entry,
202
    template: Option<Arc<RequestTemplate>>,
203
) -> anyhow::Result<String> {
204
205
206
    let user_message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
        dynamo_async_openai::types::ChatCompletionRequestUserMessage {
            content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
207
                entry.text.clone(),
208
209
210
211
            ),
            name: None,
        },
    );
212
    let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
213
        .messages(vec![user_message])
214
215
216
217
218
        .model(
            template
                .as_ref()
                .map_or_else(|| service_name.to_string(), |t| t.model.clone()),
        )
219
        .stream(true)
220
221
222
223
224
225
        .max_completion_tokens(
            template
                .as_ref()
                .map_or(MAX_TOKENS, |t| t.max_completion_tokens),
        )
        .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
226
        .build()?;
227
228
229
230
    let req = NvCreateChatCompletionRequest {
        inner,
        common: Default::default(),
        nvext: None,
231
        chat_template_args: None,
232
        media_io_kwargs: None,
233
        unsupported_fields: Default::default(),
234
    };
235
236
237
238
239
240
    let mut stream = engine.generate(Context::new(req)).await?;
    let mut output = String::new();
    while let Some(item) = stream.next().await {
        match (item.data.as_ref(), item.event.as_deref()) {
            (Some(data), _) => {
                // Normal case
241
                let choice = data.choices.first();
242
                let chat_comp = choice.as_ref().unwrap();
243
244
245
                if let Some(c) = &chat_comp.delta.content {
                    output += c;
                }
246
                entry.finish_reason = chat_comp.finish_reason;
247
                if chat_comp.finish_reason.is_some() {
248
249
250
251
252
                    tracing::trace!(
                        request_id,
                        "finish reason: {:?}",
                        chat_comp.finish_reason.unwrap()
                    );
253
254
255
256
                    break;
                }
            }
            (None, Some("error")) => {
257
                tracing::error!(request_id, "the error case");
258
259
                // There's only one error but we loop in case that changes
                for err in item.comment.unwrap_or_default() {
260
                    tracing::error!(request_id, "Engine error: {err}");
261
262
263
                }
            }
            (None, Some(annotation)) => {
264
                tracing::debug!(request_id, "Annotation. {annotation}: {:?}", item.comment);
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
            }
            _ => {
                unreachable!("Event from engine with no data, no error, no annotation.");
            }
        }
    }
    Ok(output)
}

async fn output_writer(
    cancel_token: CancellationToken,
    mut entries_rx: tokio::sync::mpsc::Receiver<Entry>,
    output_file: &Path,
) -> anyhow::Result<()> {
    let mut num_completed = 0;
    let mut f = tokio::fs::File::create(output_file).await?;
    loop {
282
        let entry = tokio::select! {
283
284
285
            _ = cancel_token.cancelled() => {
                break;
            }
286
287
288
289
290
            maybe_entry = entries_rx.recv() => {
                match maybe_entry {
                    Some(entry) => entry,
                    None => {break;}
                }
291
292
293
294
295
296
297
298
299
            }
        };
        let mut s = serde_json::to_string(&entry)?;
        s.push('\n');
        f.write_all(s.as_bytes()).await?;

        num_completed += 1;
        // TODO: Progress bar. We'd have to count the lines in the input first,
        // and the input maybe be large
300
        tracing::info!(entry.request_id, entry.tokens_out, "Saved {num_completed}");
301
302
303
    }
    Ok(())
}