delta.rs 11.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
17
18
19
20
use crate::{
    protocols::common::{self},
    types::TokenIdType,
};
21

22
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
23
impl NvCreateChatCompletionRequest {
24
25
26
27
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
28
29
30
    pub fn response_generator(&self) -> DeltaGenerator {
        let options = DeltaGeneratorOptions {
            enable_usage: true,
Greg Clark's avatar
Greg Clark committed
31
32
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
33
34
        };

Paul Hendricks's avatar
Paul Hendricks committed
35
        DeltaGenerator::new(self.inner.model.clone(), options)
36
37
38
    }
}

39
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
///
/// The derived [`Default`] leaves both switches off.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// When `true`, token usage statistics are attached to generated responses.
    pub enable_usage: bool,
    /// When `true`, log probabilities may be attached to generated responses.
    pub enable_logprobs: bool,
}

48
/// Generates incremental chat completion responses in a streaming fashion.
49
50
#[derive(Debug, Clone)]
pub struct DeltaGenerator {
51
    /// Unique identifier for the chat completion session.
52
    id: String,
53
    /// Object type, representing a streamed chat completion response.
54
    object: String,
55
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
56
    created: u32,
57
    /// Model name used for generating responses.
58
    model: String,
59
    /// Optional system fingerprint for version tracking.
60
    system_fingerprint: Option<String>,
61
    /// Optional service tier information for the response.
62
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
63
    /// Tracks token usage for the completion request.
64
    usage: dynamo_async_openai::types::CompletionUsage,
65
    /// Counter tracking the number of messages issued.
66
    msg_counter: u64,
67
    /// Configuration options for response generation.
68
69
70
71
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
72
73
74
75
76
77
78
79
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
80
81
82
83
    pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
84
85
86
87
88
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
89

90
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
91
92
93
94
95
96
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
97
98
99
100
101
102
103
104

        Self {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
105
            usage,
106
107
108
109
110
            msg_counter: 0,
            options,
        }
    }

111
112
113
114
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
115
    pub fn update_isl(&mut self, isl: u32) {
116
117
118
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
119
120
121
122
123
124
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
125
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
153
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
154
155
156
157
158
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
159
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
160
161
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
162
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
163
164
165
166
167
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
168
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
169
170
171
172
173
174
175
176
177
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

178
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
179
180
181
182
183
            content,
            refusal: None,
        })
    }

184
185
186
187
188
189
190
191
192
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
    ///
    /// # Returns
193
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
194
    #[allow(deprecated)]
195
196
    pub fn create_choice(
        &self,
Paul Hendricks's avatar
Paul Hendricks committed
197
        index: u32,
198
        text: Option<String>,
199
200
201
202
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
    ) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse {
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
203
204
205
            content: text,
            function_call: None,
            tool_calls: None,
206
            role: if self.msg_counter == 0 {
207
                Some(dynamo_async_openai::types::Role::Assistant)
208
209
210
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
211
            refusal: None,
212
213
        };

214
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
215
216
217
218
219
220
221
222
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

223
224
225
226
227
        let mut usage = self.usage.clone();
        if self.options.enable_usage {
            usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
        }

228
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
229
230
231
232
233
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
Paul Hendricks's avatar
Paul Hendricks committed
234
            choices,
235
            usage: if self.options.enable_usage {
236
                Some(usage)
237
238
239
240
241
242
243
244
            } else {
                None
            },
            service_tier: self.service_tier.clone(),
        }
    }
}

245
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
246
/// it to transform backend responses into OpenAI-style streaming responses.
247
248
249
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
250
251
252
253
254
255
256
257
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
258
259
260
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
261
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
262
        // Aggregate token usage if enabled.
263
        if self.options.enable_usage {
264
265
266
267
268
269
270
271
272
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
273
274
        }

Greg Clark's avatar
Greg Clark committed
275
276
277
278
279
280
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
281

282
        // Map backend finish reasons to OpenAI's finish reasons.
283
        let finish_reason = match delta.finish_reason {
284
285
286
287
288
289
290
291
292
293
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
294
            Some(common::FinishReason::ContentFilter) => {
295
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
296
            }
297
298
299
300
301
302
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

303
        // Create the streaming response.
304
        let index = 0;
Paul Hendricks's avatar
Paul Hendricks committed
305
306
        let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

307
        Ok(NvCreateChatCompletionStreamResponse {
Paul Hendricks's avatar
Paul Hendricks committed
308
309
            inner: stream_response,
        })
310
    }
311
312
313
314

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
315
}