delta.rs 17.9 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
5
use crate::{
6
    local_model::runtime_config::ModelRuntimeConfig,
Greg Clark's avatar
Greg Clark committed
7
8
9
    protocols::common::{self},
    types::TokenIdType,
};
10

11
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
12
impl NvCreateChatCompletionRequest {
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

36
37
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
38
39
40
    /// # Arguments
    /// * `request_id` - The request ID to use for the chat completion response ID.
    ///
41
42
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
43
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
44
        let options = DeltaGeneratorOptions {
45
46
47
48
49
50
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
51
52
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
53
            runtime_config: ModelRuntimeConfig::default(),
54
55
        };

56
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
57
58
59
    }
}

60
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
61
62
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
63
    /// Determines whether token usage statistics should be included in the response.
64
    pub enable_usage: bool,
65
    /// Determines whether log probabilities should be included in the response.
66
    pub enable_logprobs: bool,
67

68
    pub runtime_config: ModelRuntimeConfig,
69
70
}

71
/// Generates incremental chat completion responses in a streaming fashion.
72
#[derive(Debug)]
73
pub struct DeltaGenerator {
74
    /// Unique identifier for the chat completion session.
75
    id: String,
76
    /// Object type, representing a streamed chat completion response.
77
    object: String,
78
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
79
    created: u32,
80
    model: String,
81
    /// Optional system fingerprint for version tracking.
82
    system_fingerprint: Option<String>,
83
    /// Optional service tier information for the response.
84
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
85
    /// Tracks token usage for the completion request.
86
    usage: dynamo_async_openai::types::CompletionUsage,
87
    /// Counter tracking the number of messages issued.
88
    msg_counter: u64,
89
    /// Configuration options for response generation.
90
91
92
93
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
94
95
96
97
98
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
99
    /// * `request_id` - The request ID to use for the chat completion response.
100
101
102
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
103
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
104
105
106
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
107
108
109
110
111
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
112

113
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
114
115
116
117
118
119
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
120

121
122
        let chatcmpl_id = format!("chatcmpl-{request_id}");

123
        Self {
124
            id: chatcmpl_id,
125
126
127
128
129
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
130
            usage,
131
132
            msg_counter: 0,
            options,
133
134
135
        }
    }

136
137
138
139
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
140
    pub fn update_isl(&mut self, isl: u32) {
141
142
143
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
144
145
146
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
147
        token_ids: &[TokenIdType],
Greg Clark's avatar
Greg Clark committed
148
149
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
150
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
151
152
153
154
155
156
157
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
158
            .map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
Greg Clark's avatar
Greg Clark committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
178
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
179
180
181
182
183
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
184
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
185
186
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
187
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
188
189
190
191
192
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
193
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
194
195
196
197
198
199
200
201
202
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

203
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
204
205
206
207
208
            content,
            refusal: None,
        })
    }

209
210
211
212
213
214
215
216
217
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
    ///
    /// # Returns
218
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
219
    #[allow(deprecated)]
220
    pub fn create_choice(
221
        &mut self,
Paul Hendricks's avatar
Paul Hendricks committed
222
        index: u32,
223
        text: Option<String>,
224
225
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
226
    ) -> NvCreateChatCompletionStreamResponse {
227
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
228
            content: text,
229
230
            function_call: None,
            tool_calls: None,
231
            role: if self.msg_counter == 0 {
232
                Some(dynamo_async_openai::types::Role::Assistant)
233
234
235
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
236
            refusal: None,
237
            reasoning_content: None,
238
239
        };

240
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
241
242
243
244
245
246
247
248
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

249
250
251
252
253
254
255
256
257
258
259
260
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices,
            usage: None, // Always None for chunks with content/choices
            service_tier: self.service_tier.clone(),
261
            nvext: None, // Will be populated by router layer if needed
262
        }
263
264
265
266
267
268
269
270
271
272
    }

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
273

274
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
275
276
277
278
279
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
280
281
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
282
            service_tier: self.service_tier.clone(),
283
            nvext: None,
284
285
        }
    }
286
287
288
289
290

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
291
292
}

293
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
294
/// it to transform backend responses into OpenAI-style streaming responses.
295
296
297
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
298
299
300
301
302
303
304
305
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
306
307
308
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
309
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
310
        // Aggregate token usage if enabled.
311
        if self.options.enable_usage {
312
313
314
315
316
317
318
319
320
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
321
322
323
324
325
326
327
328
329
330

            // If backend provides completion_usage with prompt token details,
            // propagate the entire details struct to usage tracking
            if let Some(prompt_details) = delta
                .completion_usage
                .as_ref()
                .and_then(|usage| usage.prompt_tokens_details.as_ref())
            {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
331
332
        }

Greg Clark's avatar
Greg Clark committed
333
334
        let logprobs = self.create_logprobs(
            delta.tokens,
335
            &delta.token_ids,
Greg Clark's avatar
Greg Clark committed
336
337
338
            delta.log_probs,
            delta.top_logprobs,
        );
339

340
        // Map backend finish reasons to OpenAI's finish reasons.
341
        let finish_reason = match delta.finish_reason {
342
343
344
345
346
347
348
349
350
351
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
352
            Some(common::FinishReason::ContentFilter) => {
353
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
354
            }
355
356
357
358
359
360
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

361
        // Create the streaming response.
362
        let index = 0;
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
        let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

        // Extract worker_id from disaggregated_params and inject into nvext if present
        if let Some(worker_id_json) = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
        {
            use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};

            let prefill_worker_id = worker_id_json
                .get("prefill_worker_id")
                .and_then(|v| v.as_u64());
            let decode_worker_id = worker_id_json
                .get("decode_worker_id")
                .and_then(|v| v.as_u64());

            let worker_id_info = WorkerIdInfo {
                prefill_worker_id,
                decode_worker_id,
            };

            let nvext_response = NvExtResponse {
                worker_id: Some(worker_id_info),
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                stream_response.nvext = Some(nvext_json);
                tracing::debug!(
                    "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
                    prefill_worker_id,
                    decode_worker_id
                );
            }
        }
Paul Hendricks's avatar
Paul Hendricks committed
398

399
        Ok(stream_response)
400
    }
401
402
403
404

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
405
406
407
408
409
410
411
412

    fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
413
}
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
        CreateChatCompletionRequest,
    };

    /// Builds a minimal non-streaming request with no `stream_options`.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let messages = vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
                name: None,
            },
        )];

        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages,
                stream: Some(false),
                stream_options: None,
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            unsupported_fields: Default::default(),
        }
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        // Test that non-streaming requests get usage enabled
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert!(
            request.inner.stream_options.is_some(),
            "Non-streaming request should have stream_options created"
        );
        assert!(
            request.inner.stream_options.unwrap().include_usage,
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_overrides_existing_options() {
        // An explicit include_usage=false must not suppress usage for non-streaming requests.
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(false);

        assert_eq!(
            request.inner.stream_options.map(|opts| opts.include_usage),
            Some(true),
            "Existing stream_options should be forced to include_usage=true when non-streaming"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        // Test that streaming requests are not modified
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request.inner.stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_keeps_streaming_opt_out() {
        // A streaming client's explicit include_usage=false must be preserved.
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(true);

        assert_eq!(
            request.inner.stream_options.map(|opts| opts.include_usage),
            Some(false),
            "Streaming request should keep its own include_usage setting"
        );
    }

    #[test]
    fn test_response_generator_usage_follows_stream_options() {
        let mut request = create_test_request();
        assert!(
            !request
                .response_generator("req-1".to_string())
                .is_usage_enabled(),
            "Usage should be off when stream_options is absent"
        );

        request.enable_usage_for_nonstreaming(false);
        let generator = request.response_generator("req-1".to_string());
        assert!(generator.is_usage_enabled());

        let chunk = generator.create_usage_chunk();
        assert_eq!(chunk.id, "chatcmpl-req-1");
        assert!(chunk.choices.is_empty());
        assert!(chunk.usage.is_some());
    }
}