delta.rs 12.7 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
5
6
use crate::{
    protocols::{
7
8
        common::{self, timing::RequestTimingTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
9
10
11
    },
    types::TokenIdType,
};
12

13
impl NvCreateCompletionRequest {
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

39
40
    // put this method on the request
    // inspect the request to extract options
41
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
42
43
44
45
46
47
        // Check if client requested timing in extra_fields
        let enable_timing = self
            .nvext()
            .and_then(|nv| nv.extra_fields.as_ref())
            .is_some_and(|fields| fields.iter().any(|f| f == "timing"));

48
        let options = DeltaGeneratorOptions {
49
50
51
52
53
54
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
55
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
56
            enable_timing,
57
58
        };

59
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
60
61
62
63
64
65
66
    }
}

/// Feature switches that control what a [`DeltaGenerator`] produces.
///
/// All switches default to `false`.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// Accumulate token usage while backend deltas are processed.
    pub enable_usage: bool,
    /// Attach log probabilities to choices when the backend supplies them.
    pub enable_logprobs: bool,
    /// Create a request timing tracker when the generator is constructed.
    pub enable_timing: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
73
    created: u32,
74
75
    model: String,
    system_fingerprint: Option<String>,
76
    usage: dynamo_async_openai::types::CompletionUsage,
77
    options: DeltaGeneratorOptions,
78
    timing_tracker: Option<RequestTimingTracker>,
79
80
81
}

impl DeltaGenerator {
82
    /// Creates a generator for one completion request.
    ///
    /// # Arguments
    /// * `model` - Model name echoed on every response.
    /// * `options` - Feature switches; a timing tracker is created only when
    ///   `options.enable_timing` is set.
    /// * `request_id` - Request identifier; the response id becomes `cmpl-{request_id}`.
    ///
    /// The `created` timestamp is the current Unix time in seconds. Construction does not
    /// panic on clock problems: a system clock set before the Unix epoch yields `0`, and a
    /// value that no longer fits in `u32` (after the year 2106) saturates at `u32::MAX`.
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
        // `created` is informational, so degrade gracefully instead of panicking when the
        // clock is before the epoch or the seconds count overflows `u32`.
        let created = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| {
                u32::try_from(elapsed.as_secs()).unwrap_or(u32::MAX)
            });

        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
        let usage = dynamo_async_openai::types::CompletionUsage {
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

        // Only pay for timing bookkeeping when the request asked for it.
        let timing_tracker = options.enable_timing.then(RequestTimingTracker::new);

        Self {
            id: format!("cmpl-{request_id}"),
            object: "text_completion".to_string(),
            created,
            model,
            system_fingerprint: None,
            usage,
            options,
            timing_tracker,
        }
    }

123
    pub fn update_isl(&mut self, isl: u32) {
124
125
126
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
127
128
129
130
131
132
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
133
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
161
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
162
163
164
165
166
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
167
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
168
169
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
170
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
171
172
173
174
175
176
177
178
179
180
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

181
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
182
183
184
185
186
187
188
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

189
190
    pub fn create_choice(
        &self,
191
        index: u32,
192
        text: Option<String>,
193
194
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
195
    ) -> NvCreateCompletionResponse {
196
197
        // todo - update for tool calling

198
199
200
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
201
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
202
203
            id: self.id.clone(),
            object: self.object.clone(),
204
            created: self.created,
205
206
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
207
            choices: vec![dynamo_async_openai::types::Choice {
208
                text: text.unwrap_or_default(),
209
                index,
210
                finish_reason,
Greg Clark's avatar
Greg Clark committed
211
                logprobs,
212
            }],
213
            usage: None, // Always None for chunks with content/choices
214
            nvext: None, // Will be populated by router layer if needed
215
216
217
        };

        NvCreateCompletionResponse { inner }
218
    }
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236

    /// Builds the trailing usage-only chunk required for OpenAI compliance.
    ///
    /// Intended to be sent after the last content chunk when
    /// `stream_options.include_usage` is true. `total_tokens` is computed here as the
    /// saturating sum of prompt and completion tokens.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        let total_tokens = self
            .usage
            .prompt_tokens
            .saturating_add(self.usage.completion_tokens);

        let usage = dynamo_async_openai::types::CompletionUsage {
            total_tokens,
            ..self.usage.clone()
        };

        NvCreateCompletionResponse {
            inner: dynamo_async_openai::types::CreateCompletionResponse {
                id: self.id.clone(),
                object: self.object.clone(),
                created: self.created,
                model: self.model.clone(),
                system_fingerprint: self.system_fingerprint.clone(),
                choices: vec![], // Empty choices for usage-only chunk
                usage: Some(usage),
                nvext: None, // Will be populated by router layer if needed
            },
        }
    }

    /// Reports whether usage tracking is switched on (`options.enable_usage`).
    pub fn is_usage_enabled(&self) -> bool {
        let opts = &self.options;
        opts.enable_usage
    }
247
248
}

249
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
250
251
252
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
253
    ) -> anyhow::Result<NvCreateCompletionResponse> {
254
255
        // aggregate usage
        if self.options.enable_usage {
256
257
258
259
260
261
262
263
264
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
265
266
267
268
269
270
271
272
273
274

            // If backend provides completion_usage with prompt token details,
            // propagate the entire details struct to usage tracking
            if let Some(prompt_details) = delta
                .completion_usage
                .as_ref()
                .and_then(|usage| usage.prompt_tokens_details.as_ref())
            {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
275
276
        }

Greg Clark's avatar
Greg Clark committed
277
278
279
280
281
282
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
283
284

        let finish_reason = delta.finish_reason.map(Into::into);
285
286

        // create choice
287
        let index = delta.index.unwrap_or(0);
288
289
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

290
291
292
293
294
295
296
        // Record first token time (only succeeds on first call due to OnceLock)
        if let Some(ref tracker) = self.timing_tracker {
            tracker.record_first_token();
        }

        // Extract worker_id from disaggregated_params
        let worker_id_info = delta
297
298
299
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
300
301
302
303
304
305
306
307
308
309
310
311
312
313
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());

        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
            self.timing_tracker.as_ref().map(|tracker| {
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

        // Inject nvext if we have worker_id or timing
        if worker_id_info.is_some() || timing_info.is_some() {
314
            let nvext_response = NvExtResponse {
315
316
                worker_id: worker_id_info.clone(),
                timing: timing_info,
317
318
319
320
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
321
322
323
324
325
326
327
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
328
329
330
            }
        }

331
        Ok(response)
332
    }
333
334

    fn get_isl(&self) -> Option<u32> {
335
        Some(self.usage.prompt_tokens)
336
    }
337
338
339
340
341
342
343
344

    /// Trait entry point for the usage-only chunk; delegates to the inherent
    /// `DeltaGenerator::create_usage_chunk`.
    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        // Called through the type path so the inherent method of the same name is
        // selected rather than this trait method.
        DeltaGenerator::create_usage_chunk(self)
    }

    /// Trait entry point reporting whether usage tracking is enabled; delegates to the
    /// inherent `DeltaGenerator::is_usage_enabled`.
    fn is_usage_enabled(&self) -> bool {
        // Called through the type path so the inherent method of the same name is
        // selected rather than this trait method.
        DeltaGenerator::is_usage_enabled(self)
    }
345
}