delta.rs 18.2 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
7
8
use crate::{
    protocols::{
9
        common::{self, timing::RequestTracker},
10
11
        openai::{
            convert_backend_top_logprobs,
12
            nvext::{NvExtProvider, NvExtResponseFieldSelection},
13
        },
14
15
16
    },
    types::TokenIdType,
};
17

18
impl NvCreateCompletionRequest {
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
34
                    Some(dynamo_protocols::types::ChatCompletionStreamOptions {
35
                        include_usage: true,
36
                        continuous_usage_stats: false,
37
38
39
40
41
42
43
44
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

45
46
    // put this method on the request
    // inspect the request to extract options
47
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
48
        let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext());
49

50
        let options = DeltaGeneratorOptions {
51
52
53
54
55
56
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
57
58
59
60
61
62
            continuous_usage_stats: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.continuous_usage_stats)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
63
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
64
            response_fields,
65
66
        };

67
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
68
69
70
71
72
73
    }
}

#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    pub enable_usage: bool,
74
    pub continuous_usage_stats: bool,
75
    pub enable_logprobs: bool,
76
    pub response_fields: NvExtResponseFieldSelection,
77
78
79
80
81
}

pub struct DeltaGenerator {
    id: String,
    object: String,
82
    created: u32,
83
84
    model: String,
    system_fingerprint: Option<String>,
85
    usage: dynamo_protocols::types::CompletionUsage,
86
    options: DeltaGeneratorOptions,
87
    tracker: Option<Arc<RequestTracker>>,
88
89
90
}

impl DeltaGenerator {
91
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
92
93
94
95
96
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

97
98
99
100
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

101
102
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
103
        let usage = dynamo_protocols::types::CompletionUsage {
104
105
106
107
108
109
110
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

111
112
        let completion_id = format!("cmpl-{request_id}");

113
        // Always create request tracker for per-worker metrics (TTFT, ITL per worker_id).
114
115
        // `response_fields` only controls which nvext fields are returned to the client;
        // the tracker still records timing/ITL internally for metrics.
116
        let tracker = Some(Arc::new(RequestTracker::new()));
117

118
        Self {
119
            id: completion_id,
120
121
122
123
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
124
            usage,
125
            options,
126
            tracker,
127
128
129
        }
    }

130
131
132
133
134
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

135
    pub fn update_isl(&mut self, isl: u32) {
136
137
138
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
139
140
141
142
143
144
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
145
    ) -> Option<dynamo_protocols::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
166
167
                    let converted = convert_backend_top_logprobs(top_lps, t, *tid, *lp);
                    serde_json::to_value(converted).unwrap()
Greg Clark's avatar
Greg Clark committed
168
169
170
171
                })
                .collect()
        });

172
        Some(dynamo_protocols::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
173
174
175
176
177
178
179
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

180
181
    pub fn create_choice(
        &self,
182
        index: u32,
183
        text: Option<String>,
184
185
        finish_reason: Option<dynamo_protocols::types::CompletionFinishReason>,
        logprobs: Option<dynamo_protocols::types::Logprobs>,
186
    ) -> NvCreateCompletionResponse {
187
188
        // todo - update for tool calling

189
190
191
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
192
        let inner = dynamo_protocols::types::CreateCompletionResponse {
193
194
            id: self.id.clone(),
            object: self.object.clone(),
195
            created: self.created,
196
197
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
198
            choices: vec![dynamo_protocols::types::Choice {
199
                text: text.unwrap_or_default(),
200
                index,
201
                finish_reason,
Greg Clark's avatar
Greg Clark committed
202
                logprobs,
203
            }],
204
205
206
207
208
            usage: if self.options.enable_usage && self.options.continuous_usage_stats {
                Some(self.get_usage())
            } else {
                None
            },
209
210
        };

211
        NvCreateCompletionResponse { inner, nvext: None }
212
    }
213
214
215
216
217
218
219

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
220
        let usage = self.get_usage();
221

222
        let inner = dynamo_protocols::types::CreateCompletionResponse {
223
224
225
226
227
228
229
230
231
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
        };

232
        NvCreateCompletionResponse { inner, nvext: None }
233
234
235
236
237
238
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
239

240
241
242
243
244
    /// Check if continuous usage tracking is enabled
    pub fn is_continuous_usage_enabled(&self) -> bool {
        self.options.continuous_usage_stats
    }

245
    pub fn get_usage(&self) -> dynamo_protocols::types::CompletionUsage {
246
247
248
249
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
250
251
}

252
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
253
254
255
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
256
    ) -> anyhow::Result<NvCreateCompletionResponse> {
257
258
259
260
261
262
263
264
265
266
267
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

268
269
270
271
272
273
274
        // If backend provides completion_usage, use it to update usage stats
        // This is critical for prompt embeddings where prompt_tokens comes from
        // the embedding sequence length computed by the worker
        if let Some(completion_usage) = delta.completion_usage.as_ref() {
            // Update prompt_tokens from worker if provided (e.g., for embeddings)
            self.usage.prompt_tokens = completion_usage.prompt_tokens;

275
276
277
278
279
            // Propagate completion token details if provided
            if let Some(completion_details) = completion_usage.completion_tokens_details.as_ref() {
                self.usage.completion_tokens_details = Some(completion_details.clone());
            }

280
281
282
283
            // Propagate prompt token details if provided
            if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
284
285
        }

Greg Clark's avatar
Greg Clark committed
286
287
288
289
290
291
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
292
293

        let finish_reason = delta.finish_reason.map(Into::into);
294
295

        // create choice
296
        let index = delta.index.unwrap_or(0);
297
298
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

299
300
301
302
303
304
305
        // Record finish for timing/ITL accounting even when timing is not returned to the client.
        // Kept at call site because it's a side effect on the tracker — not a gating decision.
        if finish_reason.is_some()
            && let Some(ref tracker) = self.tracker
        {
            tracker.record_finish();
        }
306

307
308
309
310
311
312
313
314
315
        // Build the nvext response payload via the shared gating helper on
        // `NvExtResponseFieldSelection` (see `nvext.rs`). Both chat and
        // completions delta generators go through the same helper so the gating
        // rules stay in one place.
        if let Some(nvext_response) = self.options.response_fields.build_response_nvext(
            self.tracker.as_ref(),
            delta.disaggregated_params.as_ref(),
            finish_reason.is_some(),
        ) && let Ok(nvext_json) = serde_json::to_value(&nvext_response)
316
        {
317
318
319
320
321
322
323
324
325
326
327
328
329
            response.nvext = Some(nvext_json);
            if let Some(ref info) = nvext_response.worker_id {
                tracing::debug!(
                    "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                    info.prefill_worker_id,
                    info.decode_worker_id
                );
            }
            if let Some(ref tokens) = nvext_response.token_ids {
                tracing::debug!(
                    "Injected token_ids into completions nvext: {} tokens",
                    tokens.len()
                );
330
331
332
            }
        }

333
        Ok(response)
334
    }
335
336

    fn get_isl(&self) -> Option<u32> {
337
        Some(self.usage.prompt_tokens)
338
    }
339
340
341
342
343
344
345
346

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
347

348
349
350
351
    fn is_continuous_usage_enabled(&self) -> bool {
        DeltaGenerator::is_continuous_usage_enabled(self)
    }

352
    fn get_usage(&self) -> dynamo_protocols::types::CompletionUsage {
353
354
        DeltaGenerator::get_usage(self)
    }
355
356
357
358

    fn tracker(&self) -> Option<std::sync::Arc<crate::protocols::common::timing::RequestTracker>> {
        self.tracker.clone()
    }
359
}
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::{self, llm_backend::BackendOutput, timing::WORKER_TYPE_PREFILL};
    use crate::protocols::openai::DeltaGeneratorExt;
    use dynamo_protocols::types::{CreateCompletionRequestArgs, Prompt};

    /// Minimal completion request: model + string prompt, no nvext, no stream_options.
    fn create_test_request() -> NvCreateCompletionRequest {
        let inner = CreateCompletionRequestArgs::default()
            .model("test-model")
            .prompt(Prompt::String("test".to_string()))
            .build()
            .expect("completion request");

        NvCreateCompletionRequest {
            inner,
            common: Default::default(),
            nvext: None,
            metadata: None,
            unsupported_fields: Default::default(),
        }
    }

    /// The minimal request with the given nvext attached.
    fn make_request_with_nvext(
        nvext: crate::protocols::openai::nvext::NvExt,
    ) -> NvCreateCompletionRequest {
        let mut request = create_test_request();
        request.nvext = Some(nvext);
        request
    }

    /// A single-token final backend output carrying disaggregated params
    /// (token_ids and routed_experts).
    fn final_backend_output() -> BackendOutput {
        BackendOutput {
            token_ids: vec![1],
            tokens: vec![Some("hello".to_string())],
            text: Some("hello".to_string()),
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason: Some(common::FinishReason::Stop),
            stop_reason: None,
            index: Some(0),
            completion_usage: None,
            disaggregated_params: Some(serde_json::json!({
                "token_ids": [11, 22, 33],
                "routed_experts": {"layer_0": [1, 3]}
            })),
        }
    }

    #[test]
    fn test_plain_request_without_extra_fields_omits_nvext() {
        let request = create_test_request();
        let mut generator = request.response_generator("req-no-nvext".to_string());
        let tracker = generator.tracker().expect("tracker");
        tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);

        let response = generator
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");

        assert!(response.nvext.is_none());
    }

    #[test]
    fn test_timing_extra_field_emits_timing_on_final_chunk() {
        use crate::protocols::openai::nvext::NvExt;
        let nvext = NvExt::builder()
            .extra_fields(vec!["timing".to_string()])
            .build()
            .unwrap();
        let mut generator =
            make_request_with_nvext(nvext).response_generator("req-timing".to_string());

        let response = generator
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");

        let nvext_json = response.nvext.expect("nvext present for timing request");
        assert!(
            nvext_json.get("timing").is_some(),
            "timing should be emitted when extra_fields=[\"timing\"]"
        );
        assert!(nvext_json.get("worker_id").is_none());
        assert!(nvext_json.get("token_ids").is_none());
        assert!(nvext_json.get("routed_experts").is_none());
    }

    #[test]
    fn test_query_instance_id_emits_worker_id_and_token_ids() {
        use crate::protocols::openai::nvext::NvExt;
        let nvext = NvExt::builder()
            .annotations(vec!["query_instance_id:abc".to_string()])
            .build()
            .unwrap();
        let mut generator =
            make_request_with_nvext(nvext).response_generator("req-qid".to_string());
        let tracker = generator.tracker().expect("tracker");
        tracker.record_worker(42, Some(0), WORKER_TYPE_PREFILL);

        let response = generator
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");

        let nvext_json = response
            .nvext
            .expect("nvext present for query_instance_id flow");
        assert!(nvext_json.get("worker_id").is_some());
        assert_eq!(
            nvext_json.get("token_ids"),
            Some(&serde_json::json!([11, 22, 33]))
        );
        // timing is NOT auto-enabled for query_instance_id — it is gated by `extra_fields: ["timing"]`.
        assert!(nvext_json.get("timing").is_none());
        assert!(nvext_json.get("routed_experts").is_none());
    }

    #[test]
    fn test_routed_experts_extra_field_emits_routed_experts() {
        use crate::protocols::openai::nvext::NvExt;
        let nvext = NvExt::builder()
            .extra_fields(vec!["routed_experts".to_string()])
            .build()
            .unwrap();
        let mut generator =
            make_request_with_nvext(nvext).response_generator("req-experts".to_string());

        let response = generator
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");

        let nvext_json = response
            .nvext
            .expect("nvext present for routed_experts request");
        assert_eq!(
            nvext_json.get("routed_experts"),
            Some(&serde_json::json!({"layer_0": [1, 3]}))
        );
        assert!(nvext_json.get("worker_id").is_none());
        assert!(nvext_json.get("timing").is_none());
        assert!(nvext_json.get("token_ids").is_none());
    }

    #[test]
    fn test_usage_accumulates_and_usage_chunk_reports_totals() {
        let request = create_test_request();
        let mut generator = request.response_generator("req-usage".to_string());
        generator.update_isl(5);

        let response = generator
            .choice_from_postprocessor(final_backend_output())
            .expect("choice generation");
        // No stream_options on the request: usage is off and not attached to content chunks.
        assert!(!generator.is_usage_enabled());
        assert!(response.inner.usage.is_none());

        let usage_chunk = generator.create_usage_chunk();
        assert!(usage_chunk.inner.choices.is_empty());
        let usage = usage_chunk.inner.usage.expect("usage chunk carries usage");
        assert_eq!(usage.prompt_tokens, 5);
        assert_eq!(usage.completion_tokens, 1);
        assert_eq!(usage.total_tokens, 6);
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_turns_usage_on() {
        let mut request = create_test_request();
        request.enable_usage_for_nonstreaming(false);

        let generator = request.response_generator("req-nonstream".to_string());
        assert!(generator.is_usage_enabled());
        assert!(!generator.is_continuous_usage_enabled());
    }

    #[test]
    fn test_enable_usage_for_streaming_request_is_noop() {
        let mut request = create_test_request();
        request.enable_usage_for_nonstreaming(true);

        assert!(request.inner.stream_options.is_none());
    }
}