delta.rs 11.2 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
5
6
7
8
use crate::{
    protocols::common::{self},
    types::TokenIdType,
};
9

10
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
11
impl NvCreateChatCompletionRequest {
12
13
14
15
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
16
17
18
    pub fn response_generator(&self) -> DeltaGenerator {
        let options = DeltaGeneratorOptions {
            enable_usage: true,
Greg Clark's avatar
Greg Clark committed
19
20
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
21
22
        };

Paul Hendricks's avatar
Paul Hendricks committed
23
        DeltaGenerator::new(self.inner.model.clone(), options)
24
25
26
    }
}

/// Switches that control which optional data a [`DeltaGenerator`] attaches to its output.
#[derive(Clone, Debug, Default)]
pub struct DeltaGeneratorOptions {
    /// When `true`, each generated chunk carries a token-usage summary.
    pub enable_usage: bool,
    /// When `true`, per-token log probabilities are attached to generated choices.
    pub enable_logprobs: bool,
}

36
/// Generates incremental chat completion responses in a streaming fashion.
37
#[derive(Debug)]
38
pub struct DeltaGenerator {
39
    /// Unique identifier for the chat completion session.
40
    id: String,
41
    /// Object type, representing a streamed chat completion response.
42
    object: String,
43
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
44
    created: u32,
45
    /// Model name used for generating responses.
46
    model: String,
47
    /// Optional system fingerprint for version tracking.
48
    system_fingerprint: Option<String>,
49
    /// Optional service tier information for the response.
50
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
51
    /// Tracks token usage for the completion request.
52
    usage: dynamo_async_openai::types::CompletionUsage,
53
    /// Counter tracking the number of messages issued.
54
    msg_counter: u64,
55
    /// Configuration options for response generation.
56
57
58
59
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
60
61
62
63
64
65
66
67
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
68
69
70
71
    pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
72
73
74
75
76
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
77

78
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
79
80
81
82
83
84
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
85
86
87
88
89
90
91
92

        Self {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
93
            usage,
94
95
96
97
98
            msg_counter: 0,
            options,
        }
    }

99
100
101
102
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
103
    pub fn update_isl(&mut self, isl: u32) {
104
105
106
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
107
108
109
110
111
112
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
113
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
141
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
142
143
144
145
146
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
147
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
148
149
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
150
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
151
152
153
154
155
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
156
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
157
158
159
160
161
162
163
164
165
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

166
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
167
168
169
170
171
            content,
            refusal: None,
        })
    }

172
173
174
175
176
177
178
179
180
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
    ///
    /// # Returns
181
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
182
    #[allow(deprecated)]
183
184
    pub fn create_choice(
        &self,
Paul Hendricks's avatar
Paul Hendricks committed
185
        index: u32,
186
        text: Option<String>,
187
188
189
190
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
    ) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse {
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
191
192
193
            content: text,
            function_call: None,
            tool_calls: None,
194
            role: if self.msg_counter == 0 {
195
                Some(dynamo_async_openai::types::Role::Assistant)
196
197
198
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
199
            refusal: None,
200
            reasoning_content: None,
201
202
        };

203
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
204
205
206
207
208
209
210
211
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

212
213
214
215
216
        let mut usage = self.usage.clone();
        if self.options.enable_usage {
            usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
        }

217
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
218
219
220
221
222
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
Paul Hendricks's avatar
Paul Hendricks committed
223
            choices,
224
            usage: if self.options.enable_usage {
225
                Some(usage)
226
227
228
229
230
231
232
233
            } else {
                None
            },
            service_tier: self.service_tier.clone(),
        }
    }
}

234
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
235
/// it to transform backend responses into OpenAI-style streaming responses.
236
237
238
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
239
240
241
242
243
244
245
246
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
247
248
249
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
250
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
251
        // Aggregate token usage if enabled.
252
        if self.options.enable_usage {
253
254
255
256
257
258
259
260
261
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
262
263
        }

Greg Clark's avatar
Greg Clark committed
264
265
266
267
268
269
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
270

271
        // Map backend finish reasons to OpenAI's finish reasons.
272
        let finish_reason = match delta.finish_reason {
273
274
275
276
277
278
279
280
281
282
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
283
            Some(common::FinishReason::ContentFilter) => {
284
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
285
            }
286
287
288
289
290
291
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

292
        // Create the streaming response.
293
        let index = 0;
Paul Hendricks's avatar
Paul Hendricks committed
294
295
        let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

296
        Ok(stream_response)
297
    }
298
299
300
301

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
302
}