delta.rs 13.8 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
7
8
use crate::{
    protocols::{
9
10
        common::{self, timing::RequestTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
11
12
13
    },
    types::TokenIdType,
};
14

15
impl NvCreateCompletionRequest {
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

41
42
    // put this method on the request
    // inspect the request to extract options
43
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
44
45
46
47
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
48
            .nvext()
49
50
51
52
53
54
55
56
57
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
58

59
        let options = DeltaGeneratorOptions {
60
61
62
63
64
65
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
66
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
67
            enable_tracking,
68
69
        };

70
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
71
72
73
74
75
76
77
    }
}

/// Feature toggles for a [`DeltaGenerator`], derived from the incoming request.
///
/// All toggles default to `false`.
#[derive(Clone, Debug, Default)]
pub struct DeltaGeneratorOptions {
    /// Whether a final usage-only chunk should be produced.
    pub enable_usage: bool,
    /// Whether logprobs are attached to generated choices.
    pub enable_logprobs: bool,
    /// Whether a `RequestTracker` is allocated for timing / worker-id reporting.
    pub enable_tracking: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
84
    created: u32,
85
86
    model: String,
    system_fingerprint: Option<String>,
87
    usage: dynamo_async_openai::types::CompletionUsage,
88
    options: DeltaGeneratorOptions,
89
    tracker: Option<Arc<RequestTracker>>,
90
91
92
}

impl DeltaGenerator {
93
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
94
95
96
97
98
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

99
100
101
102
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

103
104
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
105
        let usage = dynamo_async_openai::types::CompletionUsage {
106
107
108
109
110
111
112
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

113
114
        let completion_id = format!("cmpl-{request_id}");

115
116
117
        // Create request tracker if tracking is enabled
        let tracker = if options.enable_tracking {
            Some(Arc::new(RequestTracker::new()))
118
119
120
121
        } else {
            None
        };

122
        Self {
123
            id: completion_id,
124
125
126
127
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
128
            usage,
129
            options,
130
            tracker,
131
132
133
        }
    }

134
135
136
137
138
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

139
    pub fn update_isl(&mut self, isl: u32) {
140
141
142
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
143
144
145
146
147
148
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
149
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
177
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
178
179
180
181
182
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
183
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
184
185
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
186
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
187
188
189
190
191
192
193
194
195
196
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

197
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
198
199
200
201
202
203
204
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

205
206
    pub fn create_choice(
        &self,
207
        index: u32,
208
        text: Option<String>,
209
210
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
211
    ) -> NvCreateCompletionResponse {
212
213
        // todo - update for tool calling

214
215
216
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
217
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
218
219
            id: self.id.clone(),
            object: self.object.clone(),
220
            created: self.created,
221
222
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
223
            choices: vec![dynamo_async_openai::types::Choice {
224
                text: text.unwrap_or_default(),
225
                index,
226
                finish_reason,
Greg Clark's avatar
Greg Clark committed
227
                logprobs,
228
            }],
229
            usage: None, // Always None for chunks with content/choices
230
            nvext: None, // Will be populated by router layer if needed
231
232
233
        };

        NvCreateCompletionResponse { inner }
234
    }
235
236
237
238
239
240
241

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
242
        let usage = self.get_usage();
243
244
245
246
247
248
249
250
251

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
252
            nvext: None, // Will be populated by router layer if needed
253
254
255
256
257
258
259
260
261
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
262
263
264
265
266
267

    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
268
269
}

270
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
271
272
273
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
274
    ) -> anyhow::Result<NvCreateCompletionResponse> {
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
294
295
        }

Greg Clark's avatar
Greg Clark committed
296
297
298
299
300
301
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
302
303

        let finish_reason = delta.finish_reason.map(Into::into);
304
305

        // create choice
306
        let index = delta.index.unwrap_or(0);
307
308
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

309
        // Record first token time (only succeeds on first call due to OnceLock)
310
        if let Some(ref tracker) = self.tracker {
311
312
313
            tracker.record_first_token();
        }

314
315
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
316

317
318
319
320
321
322
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());

323
324
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
325
            self.tracker.as_ref().map(|tracker| {
326
327
328
329
330
331
332
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

333
334
        // Inject nvext if we have worker_id, token_ids, or timing
        if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
335
            let nvext_response = NvExtResponse {
336
337
                worker_id: worker_id_info.clone(),
                timing: timing_info,
338
                token_ids: token_ids.clone(),
339
340
341
342
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
343
344
345
346
347
348
349
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
350
351
352
353
354
355
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into completions nvext: {} tokens",
                        tokens.len()
                    );
                }
356
357
358
            }
        }

359
        Ok(response)
360
    }
361
362

    fn get_isl(&self) -> Option<u32> {
363
        Some(self.usage.prompt_tokens)
364
    }
365
366
367
368
369
370
371
372

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
373
374
375
376

    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
377
}