// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;

use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{
    protocols::{
        common::{self, timing::RequestTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
    },
    types::TokenIdType,
};
impl NvCreateCompletionRequest {
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
33
                        continuous_usage_stats: false,
34
35
36
37
38
39
40
41
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

42
43
    // put this method on the request
    // inspect the request to extract options
44
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
45
46
47
48
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
49
            .nvext()
50
51
52
53
54
55
56
57
58
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
59

60
        let options = DeltaGeneratorOptions {
61
62
63
64
65
66
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
67
68
69
70
71
72
            continuous_usage_stats: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.continuous_usage_stats)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
73
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
74
            enable_tracking,
75
76
        };

77
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
78
79
80
81
82
83
    }
}

/// Feature switches that control how a [`DeltaGenerator`] builds completion chunks.
///
/// All switches default to `false`.
#[derive(Clone, Debug, Default)]
pub struct DeltaGeneratorOptions {
    /// Usage reporting requested (mirrors `stream_options.include_usage`).
    pub enable_usage: bool,
    /// Attach running usage to every content chunk. `create_choice` only does so when
    /// `enable_usage` is also set.
    pub continuous_usage_stats: bool,
    /// Convert backend log-probabilities into the OpenAI `logprobs` payload.
    pub enable_logprobs: bool,
    /// Allocate a per-request tracker for timing / worker-id reporting.
    pub enable_tracking: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
92
    created: u32,
93
94
    model: String,
    system_fingerprint: Option<String>,
95
    usage: dynamo_async_openai::types::CompletionUsage,
96
    options: DeltaGeneratorOptions,
97
    tracker: Option<Arc<RequestTracker>>,
98
99
100
}

impl DeltaGenerator {
101
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
102
103
104
105
106
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

107
108
109
110
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

111
112
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
113
        let usage = dynamo_async_openai::types::CompletionUsage {
114
115
116
117
118
119
120
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

121
122
        let completion_id = format!("cmpl-{request_id}");

123
124
125
        // Create request tracker if tracking is enabled
        let tracker = if options.enable_tracking {
            Some(Arc::new(RequestTracker::new()))
126
127
128
129
        } else {
            None
        };

130
        Self {
131
            id: completion_id,
132
133
134
135
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
136
            usage,
137
            options,
138
            tracker,
139
140
141
        }
    }

142
143
144
145
146
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

147
    pub fn update_isl(&mut self, isl: u32) {
148
149
150
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
151
152
153
154
155
156
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
157
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
185
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
186
187
188
189
190
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
191
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
192
193
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
194
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
195
196
197
198
199
200
201
202
203
204
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

205
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
206
207
208
209
210
211
212
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

213
214
    pub fn create_choice(
        &self,
215
        index: u32,
216
        text: Option<String>,
217
218
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
219
    ) -> NvCreateCompletionResponse {
220
221
        // todo - update for tool calling

222
223
224
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
225
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
226
227
            id: self.id.clone(),
            object: self.object.clone(),
228
            created: self.created,
229
230
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
231
            choices: vec![dynamo_async_openai::types::Choice {
232
                text: text.unwrap_or_default(),
233
                index,
234
                finish_reason,
Greg Clark's avatar
Greg Clark committed
235
                logprobs,
236
            }],
237
238
239
240
241
            usage: if self.options.enable_usage && self.options.continuous_usage_stats {
                Some(self.get_usage())
            } else {
                None
            },
242
            nvext: None, // Will be populated by router layer if needed
243
244
245
        };

        NvCreateCompletionResponse { inner }
246
    }
247
248
249
250
251
252
253

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
254
        let usage = self.get_usage();
255
256
257
258
259
260
261
262
263

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
264
            nvext: None, // Will be populated by router layer if needed
265
266
267
268
269
270
271
272
273
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
274

275
276
277
278
279
    /// Check if continuous usage tracking is enabled
    pub fn is_continuous_usage_enabled(&self) -> bool {
        self.options.continuous_usage_stats
    }

280
281
282
283
284
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
285
286
}

impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
288
289
290
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
291
    ) -> anyhow::Result<NvCreateCompletionResponse> {
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
311
312
        }

Greg Clark's avatar
Greg Clark committed
313
314
315
316
317
318
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
319
320

        let finish_reason = delta.finish_reason.map(Into::into);
321
322

        // create choice
323
        let index = delta.index.unwrap_or(0);
324
325
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

326
        // Record first token time (only succeeds on first call due to OnceLock)
327
        if let Some(ref tracker) = self.tracker {
328
329
330
            tracker.record_first_token();
        }

331
332
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
333

334
335
336
337
338
339
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());

340
341
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
342
            self.tracker.as_ref().map(|tracker| {
343
344
345
346
347
348
349
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

350
351
        // Inject nvext if we have worker_id, token_ids, or timing
        if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
352
            let nvext_response = NvExtResponse {
353
354
                worker_id: worker_id_info.clone(),
                timing: timing_info,
355
                token_ids: token_ids.clone(),
356
357
358
359
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
360
361
362
363
364
365
366
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
367
368
369
370
371
372
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into completions nvext: {} tokens",
                        tokens.len()
                    );
                }
373
374
375
            }
        }

376
        Ok(response)
377
    }
378
379

    fn get_isl(&self) -> Option<u32> {
380
        Some(self.usage.prompt_tokens)
381
    }
382
383
384
385
386
387
388
389

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
390

391
392
393
394
    fn is_continuous_usage_enabled(&self) -> bool {
        DeltaGenerator::is_continuous_usage_enabled(self)
    }

395
396
397
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
398
}