delta.rs 8.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
17
18
use crate::protocols::common;

19
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
20
impl NvCreateChatCompletionRequest {
21
22
23
24
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
25
26
27
    pub fn response_generator(&self) -> DeltaGenerator {
        let options = DeltaGeneratorOptions {
            enable_usage: true,
Paul Hendricks's avatar
Paul Hendricks committed
28
            enable_logprobs: self.inner.logprobs.unwrap_or(false),
29
30
        };

Paul Hendricks's avatar
Paul Hendricks committed
31
        DeltaGenerator::new(self.inner.model.clone(), options)
32
33
34
    }
}

35
/// Feature switches that shape what a [`DeltaGenerator`] puts into each streamed chunk.
///
/// Both switches default to `false`.
#[derive(Clone, Debug, Default)]
pub struct DeltaGeneratorOptions {
    /// When `true`, token usage statistics are attached to the generated responses.
    pub enable_usage: bool,
    /// When `true`, log probabilities are meant to be included in the generated responses.
    pub enable_logprobs: bool,
}

44
/// Generates incremental chat completion responses in a streaming fashion.
45
46
#[derive(Debug, Clone)]
pub struct DeltaGenerator {
47
    /// Unique identifier for the chat completion session.
48
    id: String,
49
    /// Object type, representing a streamed chat completion response.
50
    object: String,
51
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
52
    created: u32,
53
    /// Model name used for generating responses.
54
    model: String,
55
    /// Optional system fingerprint for version tracking.
56
    system_fingerprint: Option<String>,
57
    /// Optional service tier information for the response.
Paul Hendricks's avatar
Paul Hendricks committed
58
    service_tier: Option<async_openai::types::ServiceTierResponse>,
59
    /// Tracks token usage for the completion request.
Paul Hendricks's avatar
Paul Hendricks committed
60
    usage: async_openai::types::CompletionUsage,
61
    /// Counter tracking the number of messages issued.
62
    msg_counter: u64,
63
    /// Configuration options for response generation.
64
65
66
67
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
68
69
70
71
72
73
74
75
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
76
77
78
79
    pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
80
81
82
83
84
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
85
86
87
88
89
90
91
92

        let usage = async_openai::types::CompletionUsage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
93
94
95
96
97
98
99
100

        Self {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
101
            usage,
102
103
104
105
106
            msg_counter: 0,
            options,
        }
    }

107
108
109
110
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
111
    pub fn update_isl(&mut self, isl: u32) {
112
113
114
        self.usage.prompt_tokens = isl;
    }

115
116
117
118
119
120
121
122
123
124
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
    ///
    /// # Returns
    /// * An [`async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
125
    #[allow(deprecated)]
126
127
    pub fn create_choice(
        &self,
Paul Hendricks's avatar
Paul Hendricks committed
128
        index: u32,
129
        text: Option<String>,
Paul Hendricks's avatar
Paul Hendricks committed
130
131
132
133
134
        finish_reason: Option<async_openai::types::FinishReason>,
        logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
    ) -> async_openai::types::CreateChatCompletionStreamResponse {
        // TODO: Update for tool calling
        let delta = async_openai::types::ChatCompletionStreamResponseDelta {
135
            role: if self.msg_counter == 0 {
Paul Hendricks's avatar
Paul Hendricks committed
136
                Some(async_openai::types::Role::Assistant)
137
138
139
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
140
            content: text,
141
            tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
142
143
            function_call: None,
            refusal: None,
144
145
        };

Paul Hendricks's avatar
Paul Hendricks committed
146
147
148
149
150
151
152
153
154
        let choice = async_openai::types::ChatChoiceStream {
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

155
156
157
158
159
        let mut usage = self.usage.clone();
        if self.options.enable_usage {
            usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
        }

Paul Hendricks's avatar
Paul Hendricks committed
160
        async_openai::types::CreateChatCompletionStreamResponse {
161
162
163
164
165
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
Paul Hendricks's avatar
Paul Hendricks committed
166
            choices,
167
            usage: if self.options.enable_usage {
168
                Some(usage)
169
170
171
172
173
174
175
176
            } else {
                None
            },
            service_tier: self.service_tier.clone(),
        }
    }
}

177
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
178
/// it to transform backend responses into OpenAI-style streaming responses.
179
180
181
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
182
183
184
185
186
187
188
189
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
190
191
192
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
193
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
194
        // Aggregate token usage if enabled.
195
        if self.options.enable_usage {
196
197
198
199
200
201
202
203
204
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
205
206
        }

207
        // TODO: Implement log probabilities aggregation.
208
209
        let logprobs = None;

210
        // Map backend finish reasons to OpenAI's finish reasons.
211
        let finish_reason = match delta.finish_reason {
Paul Hendricks's avatar
Paul Hendricks committed
212
213
214
215
            Some(common::FinishReason::EoS) => Some(async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
            Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
216
217
218
            Some(common::FinishReason::ContentFilter) => {
                Some(async_openai::types::FinishReason::ContentFilter)
            }
219
220
221
222
223
224
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

225
        // Create the streaming response.
226
        let index = 0;
Paul Hendricks's avatar
Paul Hendricks committed
227
228
        let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

229
        Ok(NvCreateChatCompletionStreamResponse {
Paul Hendricks's avatar
Paul Hendricks committed
230
231
            inner: stream_response,
        })
232
    }
233
234
235
236

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
237
}