delta.rs 19.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
5
use crate::{
6
    local_model::runtime_config::ModelRuntimeConfig,
7
    protocols::{
8
9
        common::{self, timing::RequestTimingTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
10
    },
Greg Clark's avatar
Greg Clark committed
11
12
    types::TokenIdType,
};
13

14
/// Provides methods for forcing usage reporting on a chat completion request and for
/// generating a [`DeltaGenerator`] from it.
15
impl NvCreateChatCompletionRequest {
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

39
40
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
41
42
43
    /// # Arguments
    /// * `request_id` - The request ID to use for the chat completion response ID.
    ///
44
45
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
46
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
47
48
49
50
51
52
        // Check if client requested timing in extra_fields
        let enable_timing = self
            .nvext()
            .and_then(|nv| nv.extra_fields.as_ref())
            .is_some_and(|fields| fields.iter().any(|f| f == "timing"));

53
        let options = DeltaGeneratorOptions {
54
55
56
57
58
59
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
60
61
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
62
            enable_timing,
63
            runtime_config: ModelRuntimeConfig::default(),
64
65
        };

66
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
67
68
69
    }
}

70
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
71
72
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
73
    /// Determines whether token usage statistics should be included in the response.
74
    pub enable_usage: bool,
75
    /// Determines whether log probabilities should be included in the response.
76
    pub enable_logprobs: bool,
77
78
    /// Determines whether timing information should be included in the response's nvext.
    pub enable_timing: bool,
79

80
    pub runtime_config: ModelRuntimeConfig,
81
82
}

83
/// Generates incremental chat completion responses in a streaming fashion.
84
pub struct DeltaGenerator {
85
    /// Unique identifier for the chat completion session.
86
    id: String,
87
    /// Object type, representing a streamed chat completion response.
88
    object: String,
89
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
90
    created: u32,
91
    model: String,
92
    /// Optional system fingerprint for version tracking.
93
    system_fingerprint: Option<String>,
94
    /// Optional service tier information for the response.
95
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
96
    /// Tracks token usage for the completion request.
97
    usage: dynamo_async_openai::types::CompletionUsage,
98
    /// Counter tracking the number of messages issued.
99
    msg_counter: u64,
100
    /// Configuration options for response generation.
101
    options: DeltaGeneratorOptions,
102
103
    /// Optional timing tracker for per-request timing metrics.
    timing_tracker: Option<RequestTimingTracker>,
104
105
106
}

impl DeltaGenerator {
107
108
109
110
111
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
112
    /// * `request_id` - The request ID to use for the chat completion response.
113
114
115
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
116
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
117
118
119
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
120
121
122
123
124
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
125

126
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
127
128
129
130
131
132
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
133

134
135
        let chatcmpl_id = format!("chatcmpl-{request_id}");

136
137
138
139
140
141
142
        // Create timing tracker if timing is enabled
        let timing_tracker = if options.enable_timing {
            Some(RequestTimingTracker::new())
        } else {
            None
        };

143
        Self {
144
            id: chatcmpl_id,
145
146
147
148
149
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
150
            usage,
151
152
            msg_counter: 0,
            options,
153
            timing_tracker,
154
155
156
        }
    }

157
158
159
160
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
161
    pub fn update_isl(&mut self, isl: u32) {
162
163
164
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
165
166
167
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
168
        token_ids: &[TokenIdType],
Greg Clark's avatar
Greg Clark committed
169
170
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
171
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
172
173
174
175
176
177
178
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
179
            .map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
Greg Clark's avatar
Greg Clark committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
199
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
200
201
202
203
204
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
205
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
206
207
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
208
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
209
210
211
212
213
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
214
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
215
216
217
218
219
220
221
222
223
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

224
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
225
226
227
228
229
            content,
            refusal: None,
        })
    }

230
231
232
233
234
235
236
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
237
    /// * `stop_reason` - Optional stop string or token that triggered the stop.
238
239
    ///
    /// # Returns
240
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
241
    #[allow(deprecated)]
242
    pub fn create_choice(
243
        &mut self,
Paul Hendricks's avatar
Paul Hendricks committed
244
        index: u32,
245
        text: Option<String>,
246
247
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
248
        stop_reason: Option<dynamo_async_openai::types::StopReason>,
249
    ) -> NvCreateChatCompletionStreamResponse {
250
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
251
            content: text,
252
253
            function_call: None,
            tool_calls: None,
254
            role: if self.msg_counter == 0 {
255
                Some(dynamo_async_openai::types::Role::Assistant)
256
257
258
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
259
            refusal: None,
260
            reasoning_content: None,
261
262
        };

263
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
264
265
266
            index,
            delta,
            finish_reason,
267
            stop_reason,
Paul Hendricks's avatar
Paul Hendricks committed
268
269
270
271
272
            logprobs,
        };

        let choices = vec![choice];

273
274
275
276
277
278
279
280
281
282
283
284
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices,
            usage: None, // Always None for chunks with content/choices
            service_tier: self.service_tier.clone(),
285
            nvext: None, // Will be populated by router layer if needed
286
        }
287
288
289
290
291
292
293
294
    }

    /// Builds the trailing usage-only chunk required for OpenAI compliance.
    ///
    /// Intended to be sent after the last content chunk when
    /// `stream_options.include_usage` is true.
    ///
    /// # Returns
    /// * A [`NvCreateChatCompletionStreamResponse`] with empty choices and the usage stats
    ///   reported by [`DeltaGenerator::get_usage`].
    pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            // A usage-only chunk carries no choices.
            choices: Vec::new(),
            usage: Some(self.get_usage()),
            service_tier: self.service_tier.clone(),
            nvext: None,
        }
    }
309
310
311
312
313

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
314
315
316
317
318
319

    /// Returns a snapshot of the token usage with `total_tokens` filled in.
    ///
    /// `total_tokens` is the saturating sum of the prompt and completion token counts; all
    /// other fields are copied from the running usage state.
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let current = &self.usage;
        dynamo_async_openai::types::CompletionUsage {
            total_tokens: current
                .prompt_tokens
                .saturating_add(current.completion_tokens),
            ..current.clone()
        }
    }
320
321
}

322
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
323
/// it to transform backend responses into OpenAI-style streaming responses.
324
325
326
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
327
328
329
330
331
332
333
334
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
335
336
337
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
338
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
358
359
        }

Greg Clark's avatar
Greg Clark committed
360
361
        let logprobs = self.create_logprobs(
            delta.tokens,
362
            &delta.token_ids,
Greg Clark's avatar
Greg Clark committed
363
364
365
            delta.log_probs,
            delta.top_logprobs,
        );
366

367
        // Map backend finish reasons to OpenAI's finish reasons.
368
        let finish_reason = match delta.finish_reason {
369
370
371
372
373
374
375
376
377
378
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
379
            Some(common::FinishReason::ContentFilter) => {
380
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
381
            }
382
383
384
385
386
387
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

388
        // Create the streaming response.
389
        let index = 0;
390
391
392
393
394
395
396
        let mut stream_response = self.create_choice(
            index,
            delta.text,
            finish_reason,
            logprobs,
            delta.stop_reason,
        );
397

398
399
400
401
402
403
404
        // Record first token time (only succeeds on first call due to OnceLock)
        if let Some(ref tracker) = self.timing_tracker {
            tracker.record_first_token();
        }

        // Extract worker_id from disaggregated_params
        let worker_id_info = delta
405
406
407
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
408
409
410
411
412
413
414
415
416
417
418
419
420
421
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());

        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
            self.timing_tracker.as_ref().map(|tracker| {
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

        // Inject nvext if we have worker_id or timing
        if worker_id_info.is_some() || timing_info.is_some() {
422
            let nvext_response = NvExtResponse {
423
424
                worker_id: worker_id_info.clone(),
                timing: timing_info,
425
426
427
428
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                stream_response.nvext = Some(nvext_json);
429
430
431
432
433
434
435
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
436
437
            }
        }
Paul Hendricks's avatar
Paul Hendricks committed
438

439
        Ok(stream_response)
440
    }
441
442
443
444

    /// Returns the prompt token count recorded so far; always `Some`.
    fn get_isl(&self) -> Option<u32> {
        let prompt_tokens = self.usage.prompt_tokens;
        Some(prompt_tokens)
    }
445
446
447
448
449
450
451
452

    fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
453
454
455
456

    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
457
}
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
        CreateChatCompletionRequest,
    };

    /// Builds a minimal non-streaming request with a single user message and no
    /// `stream_options`.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let messages = vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
                name: None,
            },
        )];

        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages,
                stream: Some(false),
                stream_options: None,
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            unsupported_fields: Default::default(),
        }
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        // Test that non-streaming requests get usage enabled
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert!(
            request.inner.stream_options.is_some(),
            "Non-streaming request should have stream_options created"
        );
        assert!(
            request.inner.stream_options.unwrap().include_usage,
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        // Test that streaming requests are not modified
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request.inner.stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_overrides_existing_options() {
        // Test that an explicit include_usage=false is overridden for non-streaming requests
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert!(
            request
                .inner
                .stream_options
                .as_ref()
                .is_some_and(|opts| opts.include_usage),
            "Non-streaming request should have existing include_usage forced to true"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_keeps_streaming_options() {
        // Test that existing stream_options on a streaming request are left as-is
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request
                .inner
                .stream_options
                .as_ref()
                .is_some_and(|opts| !opts.include_usage),
            "Streaming request should keep its original include_usage value"
        );
    }
}