delta.rs 14 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
7
8
use crate::{
    protocols::{
9
        common::{self, timing::RequestTracker},
10
11
12
13
        openai::{
            convert_backend_top_logprobs,
            nvext::{NvExtProvider, NvExtResponse, TimingInfo},
        },
14
15
16
    },
    types::TokenIdType,
};
17

18
impl NvCreateCompletionRequest {
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
36
                        continuous_usage_stats: false,
37
38
39
40
41
42
43
44
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

45
46
    // put this method on the request
    // inspect the request to extract options
47
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
48
49
50
51
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
52
            .nvext()
53
54
55
56
57
58
59
60
61
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
62

63
        let options = DeltaGeneratorOptions {
64
65
66
67
68
69
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
70
71
72
73
74
75
            continuous_usage_stats: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.continuous_usage_stats)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
76
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
77
            enable_tracking,
78
79
        };

80
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
81
82
83
84
85
86
    }
}

/// Feature switches controlling what a [`DeltaGenerator`] puts into its responses.
///
/// All switches default to `false`.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// Report token usage (derived from `stream_options.include_usage`).
    pub enable_usage: bool,
    /// Attach usage to every content chunk rather than only to the final usage chunk
    /// (derived from `stream_options.continuous_usage_stats`); only honored together
    /// with `enable_usage`.
    pub continuous_usage_stats: bool,
    /// Build logprobs for each choice (set when the request's `logprobs` is greater than zero).
    pub enable_logprobs: bool,
    /// Set when the client requested `"timing"` in `nvext.extra_fields` or sent a
    /// `query_instance_id` annotation.
    pub enable_tracking: bool,
}

pub struct DeltaGenerator {
    id: String,
    object: String,
95
    created: u32,
96
97
    model: String,
    system_fingerprint: Option<String>,
98
    usage: dynamo_async_openai::types::CompletionUsage,
99
    options: DeltaGeneratorOptions,
100
    tracker: Option<Arc<RequestTracker>>,
101
102
103
}

impl DeltaGenerator {
104
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
105
106
107
108
109
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

110
111
112
113
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

114
115
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
116
        let usage = dynamo_async_openai::types::CompletionUsage {
117
118
119
120
121
122
123
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

124
125
        let completion_id = format!("cmpl-{request_id}");

126
127
128
        // Always create request tracker for per-worker metrics (TTFT, ITL per worker_id).
        // The enable_tracking option only controls whether timing info is included in the response.
        let tracker = Some(Arc::new(RequestTracker::new()));
129

130
        Self {
131
            id: completion_id,
132
133
134
135
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
136
            usage,
137
            options,
138
            tracker,
139
140
141
        }
    }

142
143
144
145
146
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

147
    pub fn update_isl(&mut self, isl: u32) {
148
149
150
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
151
152
153
154
155
156
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
157
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
178
179
                    let converted = convert_backend_top_logprobs(top_lps, t, *tid, *lp);
                    serde_json::to_value(converted).unwrap()
Greg Clark's avatar
Greg Clark committed
180
181
182
183
                })
                .collect()
        });

184
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
185
186
187
188
189
190
191
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

192
193
    pub fn create_choice(
        &self,
194
        index: u32,
195
        text: Option<String>,
196
197
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
198
    ) -> NvCreateCompletionResponse {
199
200
        // todo - update for tool calling

201
202
203
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
204
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
205
206
            id: self.id.clone(),
            object: self.object.clone(),
207
            created: self.created,
208
209
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
210
            choices: vec![dynamo_async_openai::types::Choice {
211
                text: text.unwrap_or_default(),
212
                index,
213
                finish_reason,
Greg Clark's avatar
Greg Clark committed
214
                logprobs,
215
            }],
216
217
218
219
220
            usage: if self.options.enable_usage && self.options.continuous_usage_stats {
                Some(self.get_usage())
            } else {
                None
            },
221
222
        };

223
        NvCreateCompletionResponse { inner, nvext: None }
224
    }
225
226
227
228
229
230
231

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
232
        let usage = self.get_usage();
233
234
235
236
237
238
239
240
241
242
243

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
        };

244
        NvCreateCompletionResponse { inner, nvext: None }
245
246
247
248
249
250
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
251

252
253
254
255
256
    /// Check if continuous usage tracking is enabled
    pub fn is_continuous_usage_enabled(&self) -> bool {
        self.options.continuous_usage_stats
    }

257
258
259
260
261
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
262
263
}

264
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
265
266
267
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
268
    ) -> anyhow::Result<NvCreateCompletionResponse> {
269
270
271
272
273
274
275
276
277
278
279
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

280
281
282
283
284
285
286
287
288
289
290
        // If backend provides completion_usage, use it to update usage stats
        // This is critical for prompt embeddings where prompt_tokens comes from
        // the embedding sequence length computed by the worker
        if let Some(completion_usage) = delta.completion_usage.as_ref() {
            // Update prompt_tokens from worker if provided (e.g., for embeddings)
            self.usage.prompt_tokens = completion_usage.prompt_tokens;

            // Propagate prompt token details if provided
            if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
291
292
        }

Greg Clark's avatar
Greg Clark committed
293
294
295
296
297
298
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
299
300

        let finish_reason = delta.finish_reason.map(Into::into);
301
302

        // create choice
303
        let index = delta.index.unwrap_or(0);
304
305
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

306
307
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
308

309
310
311
312
313
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
314
315
316
317
318
        let routed_experts = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("routed_experts"))
            .cloned();
319

320
321
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
322
            self.tracker.as_ref().map(|tracker| {
323
324
325
326
327
328
329
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

330
331
332
333
334
335
        // Inject nvext if we have worker_id, token_ids, timing, or routed experts.
        if worker_id_info.is_some()
            || token_ids.is_some()
            || timing_info.is_some()
            || routed_experts.is_some()
        {
336
            let nvext_response = NvExtResponse {
337
338
                worker_id: worker_id_info.clone(),
                timing: timing_info,
339
                token_ids: token_ids.clone(),
340
                routed_experts,
341
342
343
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
344
                response.nvext = Some(nvext_json);
345
346
347
348
349
350
351
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
352
353
354
355
356
357
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into completions nvext: {} tokens",
                        tokens.len()
                    );
                }
358
359
360
            }
        }

361
        Ok(response)
362
    }
363
364

    fn get_isl(&self) -> Option<u32> {
365
        Some(self.usage.prompt_tokens)
366
    }
367
368
369
370
371
372
373
374

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
375

376
377
378
379
    fn is_continuous_usage_enabled(&self) -> bool {
        DeltaGenerator::is_continuous_usage_enabled(self)
    }

380
381
382
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
383
384
385
386

    fn tracker(&self) -> Option<std::sync::Arc<crate::protocols::common::timing::RequestTracker>> {
        self.tracker.clone()
    }
387
}