delta.rs 19.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
5
use crate::{
6
    local_model::runtime_config::ModelRuntimeConfig,
7
    protocols::{
8
9
        common::{self, timing::RequestTimingTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo},
10
    },
Greg Clark's avatar
Greg Clark committed
11
12
    types::TokenIdType,
};
13

14
/// Request-side helpers: enabling usage reporting for non-streaming requests and building a [`DeltaGenerator`] from a chat completion request.
15
impl NvCreateChatCompletionRequest {
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

39
40
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
41
42
43
    /// # Arguments
    /// * `request_id` - The request ID to use for the chat completion response ID.
    ///
44
45
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
46
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
47
48
49
50
51
52
        // Check if client requested timing in extra_fields
        let enable_timing = self
            .nvext()
            .and_then(|nv| nv.extra_fields.as_ref())
            .is_some_and(|fields| fields.iter().any(|f| f == "timing"));

53
        let options = DeltaGeneratorOptions {
54
55
56
57
58
59
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
60
61
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
62
            enable_timing,
63
            runtime_config: ModelRuntimeConfig::default(),
64
65
        };

66
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
67
68
69
    }
}

70
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
71
72
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
73
    /// Determines whether token usage statistics should be included in the response.
74
    pub enable_usage: bool,
75
    /// Determines whether log probabilities should be included in the response.
76
    pub enable_logprobs: bool,
77
78
    /// Determines whether timing information should be included in the response's nvext.
    pub enable_timing: bool,
79

80
    pub runtime_config: ModelRuntimeConfig,
81
82
}

83
/// Generates incremental chat completion responses in a streaming fashion.
84
pub struct DeltaGenerator {
85
    /// Unique identifier for the chat completion session.
86
    id: String,
87
    /// Object type, representing a streamed chat completion response.
88
    object: String,
89
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
90
    created: u32,
91
    model: String,
92
    /// Optional system fingerprint for version tracking.
93
    system_fingerprint: Option<String>,
94
    /// Optional service tier information for the response.
95
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
96
    /// Tracks token usage for the completion request.
97
    usage: dynamo_async_openai::types::CompletionUsage,
98
    /// Counter tracking the number of messages issued.
99
    msg_counter: u64,
100
    /// Configuration options for response generation.
101
    options: DeltaGeneratorOptions,
102
103
    /// Optional timing tracker for per-request timing metrics.
    timing_tracker: Option<RequestTimingTracker>,
104
105
106
}

impl DeltaGenerator {
107
108
109
110
111
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
112
    /// * `request_id` - The request ID to use for the chat completion response.
113
114
115
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
116
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
117
118
119
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
120
121
122
123
124
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
125

126
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
127
128
129
130
131
132
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
133

134
135
        let chatcmpl_id = format!("chatcmpl-{request_id}");

136
137
138
139
140
141
142
        // Create timing tracker if timing is enabled
        let timing_tracker = if options.enable_timing {
            Some(RequestTimingTracker::new())
        } else {
            None
        };

143
        Self {
144
            id: chatcmpl_id,
145
146
147
148
149
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
150
            usage,
151
152
            msg_counter: 0,
            options,
153
            timing_tracker,
154
155
156
        }
    }

157
158
159
160
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
161
    pub fn update_isl(&mut self, isl: u32) {
162
163
164
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
165
166
167
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
168
        token_ids: &[TokenIdType],
Greg Clark's avatar
Greg Clark committed
169
170
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
171
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
172
173
174
175
176
177
178
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
179
            .map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
Greg Clark's avatar
Greg Clark committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
199
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
200
201
202
203
204
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
205
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
206
207
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
208
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
209
210
211
212
213
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
214
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
215
216
217
218
219
220
221
222
223
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

224
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
225
226
227
228
229
            content,
            refusal: None,
        })
    }

230
231
232
233
234
235
236
237
238
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
    ///
    /// # Returns
239
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
240
    #[allow(deprecated)]
241
    pub fn create_choice(
242
        &mut self,
Paul Hendricks's avatar
Paul Hendricks committed
243
        index: u32,
244
        text: Option<String>,
245
246
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
247
    ) -> NvCreateChatCompletionStreamResponse {
248
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
249
            content: text,
250
251
            function_call: None,
            tool_calls: None,
252
            role: if self.msg_counter == 0 {
253
                Some(dynamo_async_openai::types::Role::Assistant)
254
255
256
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
257
            refusal: None,
258
            reasoning_content: None,
259
260
        };

261
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
262
263
264
265
266
267
268
269
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

270
271
272
273
274
275
276
277
278
279
280
281
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices,
            usage: None, // Always None for chunks with content/choices
            service_tier: self.service_tier.clone(),
282
            nvext: None, // Will be populated by router layer if needed
283
        }
284
285
286
287
288
289
290
291
    }

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
292
        let usage = self.get_usage();
293

294
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
295
296
297
298
299
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
300
301
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
302
            service_tier: self.service_tier.clone(),
303
            nvext: None,
304
305
        }
    }
306
307
308
309
310

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
311
312
313
314
315
316

    /// Returns a snapshot of the accumulated usage with `total_tokens` filled in.
    ///
    /// `total_tokens` is the saturating sum of the prompt and completion token counts; the
    /// counters stored on the generator are not modified.
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let total_tokens = self
            .usage
            .prompt_tokens
            .saturating_add(self.usage.completion_tokens);

        dynamo_async_openai::types::CompletionUsage {
            total_tokens,
            ..self.usage.clone()
        }
    }
317
318
}

319
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
320
/// it to transform backend responses into OpenAI-style streaming responses.
321
322
323
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
324
325
326
327
328
329
330
331
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
332
333
334
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
335
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
355
356
        }

Greg Clark's avatar
Greg Clark committed
357
358
        let logprobs = self.create_logprobs(
            delta.tokens,
359
            &delta.token_ids,
Greg Clark's avatar
Greg Clark committed
360
361
362
            delta.log_probs,
            delta.top_logprobs,
        );
363

364
        // Map backend finish reasons to OpenAI's finish reasons.
365
        let finish_reason = match delta.finish_reason {
366
367
368
369
370
371
372
373
374
375
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
376
            Some(common::FinishReason::ContentFilter) => {
377
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
378
            }
379
380
381
382
383
384
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

385
        // Create the streaming response.
386
        let index = 0;
387
388
        let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

389
390
391
392
393
394
395
        // Record first token time (only succeeds on first call due to OnceLock)
        if let Some(ref tracker) = self.timing_tracker {
            tracker.record_first_token();
        }

        // Extract worker_id from disaggregated_params
        let worker_id_info = delta
396
397
398
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
399
400
401
402
403
404
405
406
407
408
409
410
411
412
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok());

        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
            self.timing_tracker.as_ref().map(|tracker| {
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

        // Inject nvext if we have worker_id or timing
        if worker_id_info.is_some() || timing_info.is_some() {
413
            let nvext_response = NvExtResponse {
414
415
                worker_id: worker_id_info.clone(),
                timing: timing_info,
416
417
418
419
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                stream_response.nvext = Some(nvext_json);
420
421
422
423
424
425
426
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
427
428
            }
        }
Paul Hendricks's avatar
Paul Hendricks committed
429

430
        Ok(stream_response)
431
    }
432
433
434
435

    /// Returns the prompt token count currently stored in the usage record.
    ///
    /// Always `Some`; the value is `0` until `update_isl` has been called.
    fn get_isl(&self) -> Option<u32> {
        let prompt_tokens = self.usage.prompt_tokens;
        Some(prompt_tokens)
    }
436
437
438
439
440
441
442
443

    /// Trait entry point for building the usage-only chunk.
    ///
    /// Forwards to the inherent `DeltaGenerator::create_usage_chunk`.
    fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        // The type-qualified path resolves to the inherent method, which takes precedence
        // over this same-named trait method, so this does not recurse.
        DeltaGenerator::create_usage_chunk(self)
    }

    /// Trait entry point reporting whether usage statistics are enabled.
    ///
    /// Forwards to the inherent `DeltaGenerator::is_usage_enabled`.
    fn is_usage_enabled(&self) -> bool {
        // The type-qualified path resolves to the inherent method, which takes precedence
        // over this same-named trait method, so this does not recurse.
        DeltaGenerator::is_usage_enabled(self)
    }
444
445
446
447

    /// Trait entry point returning the current usage snapshot.
    ///
    /// Forwards to the inherent `DeltaGenerator::get_usage`.
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        // The type-qualified path resolves to the inherent method, which takes precedence
        // over this same-named trait method, so this does not recurse.
        DeltaGenerator::get_usage(self)
    }
448
}
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
    };

    /// Builds a minimal non-streaming request: one user message, `stream: Some(false)` and
    /// no `stream_options`.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let user_message = ChatCompletionRequestUserMessage {
            content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
            name: None,
        };

        let inner = CreateChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(user_message)],
            stream: Some(false),
            stream_options: None,
            ..Default::default()
        };

        NvCreateChatCompletionRequest {
            inner,
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            unsupported_fields: Default::default(),
        }
    }

    /// A non-streaming request must end up with `stream_options.include_usage == true`.
    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        // `false` marks the request as originally non-streaming.
        let original_stream_flag = false;
        request.enable_usage_for_nonstreaming(original_stream_flag);

        let stream_options = request.inner.stream_options.as_ref();
        assert!(
            stream_options.is_some(),
            "Non-streaming request should have stream_options created"
        );
        assert!(
            stream_options.is_some_and(|opts| opts.include_usage),
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
    }

    /// A request that was originally streaming must be left exactly as it was.
    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        // `true` marks the request as originally streaming.
        let original_stream_flag = true;
        request.enable_usage_for_nonstreaming(original_stream_flag);

        let stream_options = request.inner.stream_options.as_ref();
        assert!(
            stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }
}