delta.rs 20.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
7
use crate::{
8
    local_model::runtime_config::ModelRuntimeConfig,
9
    protocols::{
10
11
        common::{self, timing::RequestTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
12
    },
Greg Clark's avatar
Greg Clark committed
13
14
    types::TokenIdType,
};
15

16
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
17
impl NvCreateChatCompletionRequest {
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

41
42
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
43
44
45
    /// # Arguments
    /// * `request_id` - The request ID to use for the chat completion response ID.
    ///
46
47
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
48
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
49
50
51
52
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
53
            .nvext()
54
55
56
57
58
59
60
61
62
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
63

64
        let options = DeltaGeneratorOptions {
65
66
67
68
69
70
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
71
72
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
73
            enable_tracking,
74
            runtime_config: ModelRuntimeConfig::default(),
75
76
        };

77
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
78
79
80
    }
}

81
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
82
83
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
84
    /// Determines whether token usage statistics should be included in the response.
85
    pub enable_usage: bool,
86
    /// Determines whether log probabilities should be included in the response.
87
    pub enable_logprobs: bool,
88
89
    /// Determines whether request tracking (timing, KV hit rate) should be enabled.
    pub enable_tracking: bool,
90

91
    pub runtime_config: ModelRuntimeConfig,
92
93
}

94
/// Generates incremental chat completion responses in a streaming fashion.
95
pub struct DeltaGenerator {
96
    /// Unique identifier for the chat completion session.
97
    id: String,
98
    /// Object type, representing a streamed chat completion response.
99
    object: String,
100
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
101
    created: u32,
102
    model: String,
103
    /// Optional system fingerprint for version tracking.
104
    system_fingerprint: Option<String>,
105
    /// Optional service tier information for the response.
106
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
107
    /// Tracks token usage for the completion request.
108
    usage: dynamo_async_openai::types::CompletionUsage,
109
    /// Counter tracking the number of messages issued.
110
    msg_counter: u64,
111
    /// Configuration options for response generation.
112
    options: DeltaGeneratorOptions,
113
114
    /// Optional request tracker for per-request metrics (shared with PreprocessedRequest).
    tracker: Option<Arc<RequestTracker>>,
115
116
117
}

impl DeltaGenerator {
118
119
120
121
122
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
123
    /// * `request_id` - The request ID to use for the chat completion response.
124
125
126
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
127
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
128
129
130
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
131
132
133
134
135
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
136

137
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
138
139
140
141
142
143
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
144

145
146
        let chatcmpl_id = format!("chatcmpl-{request_id}");

147
148
149
        // Create request tracker if tracking is enabled
        let tracker = if options.enable_tracking {
            Some(Arc::new(RequestTracker::new()))
150
151
152
153
        } else {
            None
        };

154
        Self {
155
            id: chatcmpl_id,
156
157
158
159
160
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
161
            usage,
162
163
            msg_counter: 0,
            options,
164
            tracker,
165
166
167
        }
    }

168
169
170
171
172
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

173
174
175
176
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
177
    pub fn update_isl(&mut self, isl: u32) {
178
179
180
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
181
182
183
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
184
        token_ids: &[TokenIdType],
Greg Clark's avatar
Greg Clark committed
185
186
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
187
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
188
189
190
191
192
193
194
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
195
            .map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
Greg Clark's avatar
Greg Clark committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
215
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
216
217
218
219
220
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
221
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
222
223
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
224
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
225
226
227
228
229
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
230
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
231
232
233
234
235
236
237
238
239
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

240
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
241
242
243
244
245
            content,
            refusal: None,
        })
    }

246
247
248
249
250
251
252
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
253
    /// * `stop_reason` - Optional stop string or token that triggered the stop.
254
255
    ///
    /// # Returns
256
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
257
    #[allow(deprecated)]
258
    pub fn create_choice(
259
        &mut self,
Paul Hendricks's avatar
Paul Hendricks committed
260
        index: u32,
261
        text: Option<String>,
262
263
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
264
        stop_reason: Option<dynamo_async_openai::types::StopReason>,
265
    ) -> NvCreateChatCompletionStreamResponse {
266
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
267
            content: text,
268
269
            function_call: None,
            tool_calls: None,
270
            role: if self.msg_counter == 0 {
271
                Some(dynamo_async_openai::types::Role::Assistant)
272
273
274
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
275
            refusal: None,
276
            reasoning_content: None,
277
278
        };

279
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
280
281
282
            index,
            delta,
            finish_reason,
283
            stop_reason,
Paul Hendricks's avatar
Paul Hendricks committed
284
285
286
287
288
            logprobs,
        };

        let choices = vec![choice];

289
290
291
292
293
294
295
296
297
298
299
300
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices,
            usage: None, // Always None for chunks with content/choices
            service_tier: self.service_tier.clone(),
301
            nvext: None, // Will be populated by router layer if needed
302
        }
303
304
305
306
307
308
309
310
    }

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
311
        let usage = self.get_usage();
312

313
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
314
315
316
317
318
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
319
320
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
321
            service_tier: self.service_tier.clone(),
322
            nvext: None,
323
324
        }
    }
325
326
327
328
329

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
330
331
332
333
334
335

    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
336
337
}

338
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
339
/// it to transform backend responses into OpenAI-style streaming responses.
340
341
342
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
343
344
345
346
347
348
349
350
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
351
352
353
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
354
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
374
375
        }

Greg Clark's avatar
Greg Clark committed
376
377
        let logprobs = self.create_logprobs(
            delta.tokens,
378
            &delta.token_ids,
Greg Clark's avatar
Greg Clark committed
379
380
381
            delta.log_probs,
            delta.top_logprobs,
        );
382

383
        // Map backend finish reasons to OpenAI's finish reasons.
384
        let finish_reason = match delta.finish_reason {
385
386
387
388
389
390
391
392
393
394
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
395
            Some(common::FinishReason::ContentFilter) => {
396
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
397
            }
398
399
400
401
402
403
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

404
        // Create the streaming response.
405
        let index = 0;
406
407
408
409
410
411
412
        let mut stream_response = self.create_choice(
            index,
            delta.text,
            finish_reason,
            logprobs,
            delta.stop_reason,
        );
413

414
        // Record first token time (only succeeds on first call due to OnceLock)
415
        if let Some(ref tracker) = self.tracker {
416
417
418
            tracker.record_first_token();
        }

419
420
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
421

422
423
424
425
426
427
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());

428
429
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
430
            self.tracker.as_ref().map(|tracker| {
431
432
433
434
435
436
437
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

438
439
        // Inject nvext if we have worker_id, token_ids, or timing
        if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
440
            let nvext_response = NvExtResponse {
441
442
                worker_id: worker_id_info.clone(),
                timing: timing_info,
443
                token_ids: token_ids.clone(),
444
445
446
447
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                stream_response.nvext = Some(nvext_json);
448
449
450
451
452
453
454
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
455
456
457
458
459
460
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into chat completion nvext: {} tokens",
                        tokens.len()
                    );
                }
461
462
            }
        }
Paul Hendricks's avatar
Paul Hendricks committed
463

464
        Ok(stream_response)
465
    }
466
467
468
469

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
470
471
472
473
474
475
476
477

    fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
478
479
480
481

    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
482
}
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
        CreateChatCompletionRequest,
    };

    /// Builds a minimal non-streaming request with a single user message.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let messages = vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
                name: None,
            },
        )];

        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages,
                stream: Some(false),
                stream_options: None,
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            media_io_kwargs: None,
            unsupported_fields: Default::default(),
        }
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        // Test that non-streaming requests get usage enabled
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert!(
            request.inner.stream_options.is_some(),
            "Non-streaming request should have stream_options created"
        );
        assert!(
            request.inner.stream_options.unwrap().include_usage,
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        // Test that streaming requests are not modified
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request.inner.stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_overrides_existing_options() {
        // An explicit include_usage=false must be overridden for non-streaming requests
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(false);

        assert!(
            request
                .inner
                .stream_options
                .as_ref()
                .is_some_and(|opts| opts.include_usage),
            "Existing stream_options should be forced to include_usage=true"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_preserves_streaming_options() {
        // Streaming clients keep whatever stream_options they sent
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(true);

        assert!(
            request
                .inner
                .stream_options
                .as_ref()
                .is_some_and(|opts| !opts.include_usage),
            "Streaming request should keep include_usage=false"
        );
    }

    #[test]
    fn test_response_generator_defaults() {
        // No stream_options, no logprobs, no nvext: everything optional stays off
        let request = create_test_request();
        let generator = request.response_generator("req-1".to_string());

        assert_eq!(generator.id, "chatcmpl-req-1");
        assert!(!generator.is_usage_enabled());
        assert!(!generator.options.enable_logprobs);
        assert!(generator.tracker().is_none());
    }

    #[test]
    fn test_response_generator_reflects_request_options() {
        let mut request = create_test_request();
        request.enable_usage_for_nonstreaming(false);
        request.inner.top_logprobs = Some(3);

        let generator = request.response_generator("req-2".to_string());

        assert!(generator.is_usage_enabled());
        assert!(
            generator.options.enable_logprobs,
            "A positive top_logprobs should enable logprobs"
        );
    }
}