// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
7
8
use crate::{
    protocols::{
9
10
        common::{self, timing::RequestTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
11
12
13
    },
    types::TokenIdType,
};
14

15
impl NvCreateCompletionRequest {
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
33
                        continuous_usage_stats: false,
34
35
36
37
38
39
40
41
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

42
43
    // put this method on the request
    // inspect the request to extract options
44
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
45
46
47
48
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
49
            .nvext()
50
51
52
53
54
55
56
57
58
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
59

60
        let options = DeltaGeneratorOptions {
61
62
63
64
65
66
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
67
68
69
70
71
72
            continuous_usage_stats: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.continuous_usage_stats)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
73
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
74
            enable_tracking,
75
76
        };

77
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
78
79
80
81
82
83
    }
}

/// Per-request switches that control what a [`DeltaGenerator`] emits.
///
/// All flags default to `false`.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// Mirrors `stream_options.include_usage` of the request.
    pub enable_usage: bool,
    /// Mirrors `stream_options.continuous_usage_stats` of the request. Together with
    /// `enable_usage`, this makes every content chunk carry usage statistics.
    pub continuous_usage_stats: bool,
    /// Whether log probabilities should be attached to generated choices.
    pub enable_logprobs: bool,
    /// Set when the client asked for timing information or a `query_instance_id`
    /// annotation is present on the request.
    pub enable_tracking: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
92
    created: u32,
93
94
    model: String,
    system_fingerprint: Option<String>,
95
    usage: dynamo_async_openai::types::CompletionUsage,
96
    options: DeltaGeneratorOptions,
97
    tracker: Option<Arc<RequestTracker>>,
98
99
100
}

impl DeltaGenerator {
101
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
102
103
104
105
106
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

107
108
109
110
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

111
112
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
113
        let usage = dynamo_async_openai::types::CompletionUsage {
114
115
116
117
118
119
120
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

121
122
        let completion_id = format!("cmpl-{request_id}");

123
124
125
        // Always create request tracker for per-worker metrics (TTFT, ITL per worker_id).
        // The enable_tracking option only controls whether timing info is included in the response.
        let tracker = Some(Arc::new(RequestTracker::new()));
126

127
        Self {
128
            id: completion_id,
129
130
131
132
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
133
            usage,
134
            options,
135
            tracker,
136
137
138
        }
    }

139
140
141
142
143
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

144
    pub fn update_isl(&mut self, isl: u32) {
145
146
147
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
148
149
150
151
152
153
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
154
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
182
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
183
184
185
186
187
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
188
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
189
190
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
191
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
192
193
194
195
196
197
198
199
200
201
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

202
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
203
204
205
206
207
208
209
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

210
211
    pub fn create_choice(
        &self,
212
        index: u32,
213
        text: Option<String>,
214
215
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
216
    ) -> NvCreateCompletionResponse {
217
218
        // todo - update for tool calling

219
220
221
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
222
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
223
224
            id: self.id.clone(),
            object: self.object.clone(),
225
            created: self.created,
226
227
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
228
            choices: vec![dynamo_async_openai::types::Choice {
229
                text: text.unwrap_or_default(),
230
                index,
231
                finish_reason,
Greg Clark's avatar
Greg Clark committed
232
                logprobs,
233
            }],
234
235
236
237
238
            usage: if self.options.enable_usage && self.options.continuous_usage_stats {
                Some(self.get_usage())
            } else {
                None
            },
239
            nvext: None, // Will be populated by router layer if needed
240
241
242
        };

        NvCreateCompletionResponse { inner }
243
    }
244
245
246
247
248
249
250

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
251
        let usage = self.get_usage();
252
253
254
255
256
257
258
259
260

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
261
            nvext: None, // Will be populated by router layer if needed
262
263
264
265
266
267
268
269
270
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
271

272
273
274
275
276
    /// Check if continuous usage tracking is enabled
    pub fn is_continuous_usage_enabled(&self) -> bool {
        self.options.continuous_usage_stats
    }

277
278
279
280
281
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
282
283
}

284
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
285
286
287
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
288
    ) -> anyhow::Result<NvCreateCompletionResponse> {
289
290
291
292
293
294
295
296
297
298
299
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

300
301
302
303
304
305
306
307
308
309
310
        // If backend provides completion_usage, use it to update usage stats
        // This is critical for prompt embeddings where prompt_tokens comes from
        // the embedding sequence length computed by the worker
        if let Some(completion_usage) = delta.completion_usage.as_ref() {
            // Update prompt_tokens from worker if provided (e.g., for embeddings)
            self.usage.prompt_tokens = completion_usage.prompt_tokens;

            // Propagate prompt token details if provided
            if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
311
312
        }

Greg Clark's avatar
Greg Clark committed
313
314
315
316
317
318
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
319
320

        let finish_reason = delta.finish_reason.map(Into::into);
321
322

        // create choice
323
        let index = delta.index.unwrap_or(0);
324
325
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

326
        // Record first token time (only succeeds on first call due to OnceLock)
327
        if let Some(ref tracker) = self.tracker {
328
329
330
            tracker.record_first_token();
        }

331
332
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
333

334
335
336
337
338
339
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());

340
341
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
342
            self.tracker.as_ref().map(|tracker| {
343
344
345
346
347
348
349
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

350
351
        // Inject nvext if we have worker_id, token_ids, or timing
        if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
352
            let nvext_response = NvExtResponse {
353
354
                worker_id: worker_id_info.clone(),
                timing: timing_info,
355
                token_ids: token_ids.clone(),
356
357
358
359
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
360
361
362
363
364
365
366
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
367
368
369
370
371
372
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into completions nvext: {} tokens",
                        tokens.len()
                    );
                }
373
374
375
            }
        }

376
        Ok(response)
377
    }
378
379

    fn get_isl(&self) -> Option<u32> {
380
        Some(self.usage.prompt_tokens)
381
    }
382
383
384
385
386
387
388
389

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
390

391
392
393
394
    fn is_continuous_usage_enabled(&self) -> bool {
        DeltaGenerator::is_continuous_usage_enabled(self)
    }

395
396
397
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
398
399
400
401

    fn tracker(&self) -> Option<std::sync::Arc<crate::protocols::common::timing::RequestTracker>> {
        self.tracker.clone()
    }
402
}