// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
17
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
18
19
20
use serde::{Deserialize, Serialize};
use validator::Validate;

21
22
use crate::engines::ValidateRequest;

23
24
25
use super::{
    common::{self, SamplingOptionsProvider, StopConditionsProvider},
    nvext::{NvExt, NvExtProvider},
26
    validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
27
28
};

29
30
31
32
33
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
34

35
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
36
pub struct NvCreateCompletionRequest {
37
38
    #[serde(flatten)]
    pub inner: async_openai::types::CreateCompletionRequest,
39
40
41
42
43

    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

44
45
46
47
/// Completion response wrapper around the `async_openai` response type.
///
/// The wrapped response is serde-flattened, so its fields are (de)serialized
/// directly at the top level rather than under an `inner` key.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    /// The standard OpenAI completion response fields (flattened).
    #[serde(flatten)]
    pub inner: async_openai::types::CreateCompletionResponse,
}

50
impl ContentProvider for async_openai::types::Choice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        String::from(self.text.as_str())
    }
}

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/// Flattens an OpenAI completion prompt into a single string.
///
/// * `String` - returned as an owned copy.
/// * `StringArray` - entries joined with a single space.
/// * `IntegerArray` - each integer stringified, joined with a single space.
/// * `ArrayOfIntegerArray` - each inner array rendered like `IntegerArray`,
///   and the rendered arrays joined with `" | "`.
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
    use async_openai::types::Prompt;

    /// Renders a run of numbers as space-separated text.
    fn join_numbers<T: ToString>(numbers: &[T]) -> String {
        let rendered: Vec<String> = numbers.iter().map(ToString::to_string).collect();
        rendered.join(" ")
    }

    match prompt {
        Prompt::String(text) => text.to_owned(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(ids) => join_numbers(ids.as_slice()),
        Prompt::ArrayOfIntegerArray(groups) => {
            let rendered: Vec<String> = groups
                .iter()
                .map(|group| join_numbers(group.as_slice()))
                .collect();
            // Inner arrays are kept apart with a visible delimiter.
            rendered.join(" | ")
        }
    }
}

79
impl NvExtProvider for NvCreateCompletionRequest {
80
81
82
83
84
85
86
87
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
        if let Some(nvext) = self.nvext.as_ref() {
            if let Some(use_raw_prompt) = nvext.use_raw_prompt {
                if use_raw_prompt {
88
                    return Some(prompt_to_string(&self.inner.prompt));
89
90
91
92
93
94
95
                }
            }
        }
        None
    }
}

96
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns a copy of the annotations listed in `nvext`, or `None` when
    /// `nvext` is absent or carries no annotations.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Reports whether `annotation` is among the annotations listed in `nvext`.
    ///
    /// Returns `false` when `nvext` or its annotation list is absent.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare as `&str` so no temporary `String` is allocated per lookup.
            .is_some_and(|annotations| annotations.iter().any(|a| a.as_str() == annotation))
    }
}
111

112
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
113
    fn get_temperature(&self) -> Option<f32> {
114
        self.inner.temperature
115
116
117
    }

    fn get_top_p(&self) -> Option<f32> {
118
        self.inner.top_p
119
120
121
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
122
        self.inner.frequency_penalty
123
124
125
    }

    fn get_presence_penalty(&self) -> Option<f32> {
126
        self.inner.presence_penalty
127
128
129
130
131
132
133
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

134
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
135
    fn get_max_tokens(&self) -> Option<u32> {
136
        self.inner.max_tokens
137
138
    }

Paul Hendricks's avatar
Paul Hendricks committed
139
    fn get_min_tokens(&self) -> Option<u32> {
140
        None
141
142
143
    }

    fn get_stop(&self) -> Option<Vec<String>> {
144
        None
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

/// Holds the response fields that `make_response` copies into every
/// `NvCreateCompletionResponse` it builds.
///
/// Constructed through the `derive_builder`-generated `ResponseFactoryBuilder`;
/// the per-field defaults are the expressions in the `#[builder(default = ...)]`
/// attributes below.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response; the setter accepts `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint copied into each response.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type tag; defaults to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation timestamp; defaults to the current UTC timestamp.
    // NOTE(review): `as u32` truncates the `i64` timestamp; values outside the
    // `u32` range would wrap rather than fail.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
177
        choice: async_openai::types::Choice,
178
        usage: Option<async_openai::types::CompletionUsage>,
179
180
    ) -> NvCreateCompletionResponse {
        let inner = async_openai::types::CreateCompletionResponse {
181
182
            id: self.id.clone(),
            object: self.object.clone(),
183
            created: self.created,
184
185
186
187
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
188
189
        };
        NvCreateCompletionResponse { inner }
190
191
192
193
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
194
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
195
196
    type Error = anyhow::Error;

197
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

219
        if request.inner.suffix.is_some() {
220
221
222
223
224
225
226
227
228
229
230
231
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

        let prompt = common::PromptType::Completion(common::CompletionContext {
232
            prompt: prompt_to_string(&request.inner.prompt),
233
234
235
236
237
238
239
240
241
242
243
244
245
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
            mdc_sum: None,
            annotations: None,
        })
    }
}

246
impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choice {
247
248
249
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
250
251
252
253
254
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

255
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
256
        // so we're fairly safe knowing we won't generate that many Choices
257
258
259
260
261
262
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
263
264
265
266
267
268
269
270
271
272
273
274

        // TODO handle aggregating logprobs
        let logprobs = None;

        let finish_reason: Option<async_openai::types::CompletionFinishReason> =
            response.delta.finish_reason.map(Into::into);

        let choice = async_openai::types::Choice {
            text,
            index,
            logprobs,
            finish_reason,
275
276
277
278
279
        };

        Ok(choice)
    }
}
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306

/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing the fields of the wrapped OpenAI request to be validated.
impl ValidateRequest for NvCreateCompletionRequest {
    /// Runs the field validators from the `validate` module one after another.
    ///
    /// Each call is followed by `?`, so the first validator that fails ends the
    /// method and its error is returned; later fields are not checked.
    /// Returns `Ok(())` when every validator passes.
    fn validate(&self) -> Result<(), anyhow::Error> {
        validate::validate_model(&self.inner.model)?;
        validate::validate_prompt(&self.inner.prompt)?;
        validate::validate_suffix(self.inner.suffix.as_deref())?;
        validate::validate_max_tokens(self.inner.max_tokens)?;
        validate::validate_temperature(self.inner.temperature)?;
        validate::validate_top_p(self.inner.top_p)?;
        validate::validate_n(self.inner.n)?;
        // none for stream
        // none for stream_options
        validate::validate_logprobs(self.inner.logprobs)?;
        // none for echo
        validate::validate_stop(&self.inner.stop)?;
        validate::validate_presence_penalty(self.inner.presence_penalty)?;
        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
        // `best_of` is checked together with `n`.
        validate::validate_best_of(self.inner.best_of, self.inner.n)?;
        validate::validate_logit_bias(&self.inner.logit_bias)?;
        validate::validate_user(self.inner.user.as_deref())?;
        // none for seed

        Ok(())
    }
}