// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};

use crate::{
    protocols::{
        common::{self, timing::RequestTimingTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
    },
    types::TokenIdType,
};

impl NvCreateCompletionRequest {
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

39
40
    // put this method on the request
    // inspect the request to extract options
41
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
42
43
44
45
46
47
        // Check if client requested timing in extra_fields
        let enable_timing = self
            .nvext()
            .and_then(|nv| nv.extra_fields.as_ref())
            .is_some_and(|fields| fields.iter().any(|f| f == "timing"));

48
        let options = DeltaGeneratorOptions {
49
50
51
52
53
54
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
55
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
56
            enable_timing,
57
58
        };

59
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
60
61
62
63
64
65
66
    }
}

/// Feature switches controlling what a `DeltaGenerator` emits.
///
/// All switches default to `false`.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// Whether usage statistics were requested for the response.
    pub enable_usage: bool,
    /// Whether log probabilities should be attached to generated choices.
    pub enable_logprobs: bool,
    /// Whether request timing should be tracked and reported.
    pub enable_timing: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
73
    created: u32,
74
75
    model: String,
    system_fingerprint: Option<String>,
76
    usage: dynamo_async_openai::types::CompletionUsage,
77
    options: DeltaGeneratorOptions,
78
    timing_tracker: Option<RequestTimingTracker>,
79
80
81
}

impl DeltaGenerator {
82
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
83
84
85
86
87
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

88
89
90
91
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

92
93
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
94
        let usage = dynamo_async_openai::types::CompletionUsage {
95
96
97
98
99
100
101
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

102
103
        let completion_id = format!("cmpl-{request_id}");

104
105
106
107
108
109
110
        // Create timing tracker if timing is enabled
        let timing_tracker = if options.enable_timing {
            Some(RequestTimingTracker::new())
        } else {
            None
        };

111
        Self {
112
            id: completion_id,
113
114
115
116
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
117
            usage,
118
            options,
119
            timing_tracker,
120
121
122
        }
    }

123
    pub fn update_isl(&mut self, isl: u32) {
124
125
126
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
127
128
129
130
131
132
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
133
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
161
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
162
163
164
165
166
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
167
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
168
169
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
170
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
171
172
173
174
175
176
177
178
179
180
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

181
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
182
183
184
185
186
187
188
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

189
190
    pub fn create_choice(
        &self,
191
        index: u32,
192
        text: Option<String>,
193
194
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
195
    ) -> NvCreateCompletionResponse {
196
197
        // todo - update for tool calling

198
199
200
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
201
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
202
203
            id: self.id.clone(),
            object: self.object.clone(),
204
            created: self.created,
205
206
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
207
            choices: vec![dynamo_async_openai::types::Choice {
208
                text: text.unwrap_or_default(),
209
                index,
210
                finish_reason,
Greg Clark's avatar
Greg Clark committed
211
                logprobs,
212
            }],
213
            usage: None, // Always None for chunks with content/choices
214
            nvext: None, // Will be populated by router layer if needed
215
216
217
        };

        NvCreateCompletionResponse { inner }
218
    }
219
220
221
222
223
224
225

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
226
        let usage = self.get_usage();
227
228
229
230
231
232
233
234
235

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
236
            nvext: None, // Will be populated by router layer if needed
237
238
239
240
241
242
243
244
245
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
246
247
248
249
250
251

    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
252
253
}

254
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
255
256
257
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
258
    ) -> anyhow::Result<NvCreateCompletionResponse> {
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
278
279
        }

Greg Clark's avatar
Greg Clark committed
280
281
282
283
284
285
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
286
287

        let finish_reason = delta.finish_reason.map(Into::into);
288
289

        // create choice
290
        let index = delta.index.unwrap_or(0);
291
292
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

293
294
295
296
297
298
299
        // Record first token time (only succeeds on first call due to OnceLock)
        if let Some(ref tracker) = self.timing_tracker {
            tracker.record_first_token();
        }

        // Extract worker_id from disaggregated_params
        let worker_id_info = delta
300
301
302
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
303
304
305
306
307
308
309
310
311
312
313
314
315
316
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());

        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
            self.timing_tracker.as_ref().map(|tracker| {
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

        // Inject nvext if we have worker_id or timing
        if worker_id_info.is_some() || timing_info.is_some() {
317
            let nvext_response = NvExtResponse {
318
319
                worker_id: worker_id_info.clone(),
                timing: timing_info,
320
321
322
323
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
324
325
326
327
328
329
330
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
331
332
333
            }
        }

334
        Ok(response)
335
    }
336
337

    fn get_isl(&self) -> Option<u32> {
338
        Some(self.usage.prompt_tokens)
339
    }
340
341
342
343
344
345
346
347

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
348
349
350
351

    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
352
}