// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{protocols::common, types::TokenIdType};

impl NvCreateCompletionRequest {
20
21
22
23
24
    // put this method on the request
    // inspect the request to extract options
    pub fn response_generator(&self) -> DeltaGenerator {
        let options = DeltaGeneratorOptions {
            enable_usage: true,
Greg Clark's avatar
Greg Clark committed
25
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
26
27
        };

28
        DeltaGenerator::new(self.inner.model.clone(), options)
29
30
31
32
33
34
35
36
37
38
39
40
41
    }
}

/// Switches controlling which optional sections a [`DeltaGenerator`] emits.
///
/// Both switches default to `false`.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// Attach a `usage` section (prompt / completion / total token counts) to responses.
    pub enable_usage: bool,
    /// Attach per-token log-probabilities to each choice.
    pub enable_logprobs: bool,
}

#[derive(Debug, Clone)]
pub struct DeltaGenerator {
    id: String,
    object: String,
42
    created: u32,
43
44
    model: String,
    system_fingerprint: Option<String>,
45
    usage: dynamo_async_openai::types::CompletionUsage,
46
47
48
49
50
51
52
53
54
55
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
    pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

56
57
58
59
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

60
61
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
62
        let usage = dynamo_async_openai::types::CompletionUsage {
63
64
65
66
67
68
69
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

70
71
72
73
74
75
        Self {
            id: format!("cmpl-{}", uuid::Uuid::new_v4()),
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
76
            usage,
77
78
79
80
            options,
        }
    }

81
    pub fn update_isl(&mut self, isl: u32) {
82
83
84
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
85
86
87
88
89
90
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
91
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
119
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
120
121
122
123
124
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
125
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
126
127
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
128
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
129
130
131
132
133
134
135
136
137
138
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

139
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
140
141
142
143
144
145
146
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

147
148
    pub fn create_choice(
        &self,
149
        index: u32,
150
        text: Option<String>,
151
152
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
153
    ) -> NvCreateCompletionResponse {
154
155
        // todo - update for tool calling

156
157
158
159
160
        let mut usage = self.usage.clone();
        if self.options.enable_usage {
            usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
        }

161
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
162
163
            id: self.id.clone(),
            object: self.object.clone(),
164
            created: self.created,
165
166
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
167
            choices: vec![dynamo_async_openai::types::Choice {
168
                text: text.unwrap_or_default(),
169
                index,
170
                finish_reason,
Greg Clark's avatar
Greg Clark committed
171
                logprobs,
172
173
            }],
            usage: if self.options.enable_usage {
174
                Some(usage)
175
176
177
            } else {
                None
            },
178
179
180
        };

        NvCreateCompletionResponse { inner }
181
182
183
    }
}

impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
185
186
187
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
188
    ) -> anyhow::Result<NvCreateCompletionResponse> {
189
190
        // aggregate usage
        if self.options.enable_usage {
191
192
193
194
195
196
197
198
199
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
200
201
        }

Greg Clark's avatar
Greg Clark committed
202
203
204
205
206
207
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
208
209

        let finish_reason = delta.finish_reason.map(Into::into);
210
211

        // create choice
212
        let index = delta.index.unwrap_or(0);
Greg Clark's avatar
Greg Clark committed
213
        let response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
214
        Ok(response)
215
    }
216
217

    fn get_isl(&self) -> Option<u32> {
218
        Some(self.usage.prompt_tokens)
219
    }
220
}