// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
17
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
18
19
20
use serde::{Deserialize, Serialize};
use validator::Validate;

21
22
use crate::engines::ValidateRequest;

23
24
use super::{
    common::{self, SamplingOptionsProvider, StopConditionsProvider},
25
    common_ext::{CommonExt, CommonExtProvider},
26
    nvext::{NvExt, NvExtProvider},
27
    validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
28
29
};

// Private submodules; their contents live in separate files not shown here.
mod aggregator;
mod delta;

// Re-export `DeltaAggregator` and `DeltaGenerator` so callers can name them
// through this module without reaching into the private submodules.
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
37
pub struct NvCreateCompletionRequest {
38
39
    #[serde(flatten)]
    pub inner: async_openai::types::CreateCompletionRequest,
40

41
42
43
    #[serde(flatten)]
    pub common: CommonExt,

44
45
46
47
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

/// Completion response: a thin wrapper around the OpenAI completion response.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    /// The OpenAI-compatible response. `flatten` means the wrapper adds no
    /// extra nesting level to the (de)serialized object.
    #[serde(flatten)]
    pub inner: async_openai::types::CreateCompletionResponse,
}

/// Exposes the generated text of a completion `Choice` as its content.
impl ContentProvider for async_openai::types::Choice {
    /// Returns an owned copy of the choice's `text` field.
    fn content(&self) -> String {
        self.text.clone()
    }
}

/// Flattens any `Prompt` variant into a single `String`.
///
/// * `String` — returned as an owned copy.
/// * `StringArray` — entries joined with a single space.
/// * `IntegerArray` — each number rendered in decimal, joined with a single space.
/// * `ArrayOfIntegerArray` — each inner array rendered like `IntegerArray`,
///   and the rendered arrays joined with `" | "`.
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
    /// Renders a sequence of numbers as decimal text separated by single spaces.
    fn render_ids<T: ToString>(ids: &[T]) -> String {
        let rendered: Vec<String> = ids.iter().map(ToString::to_string).collect();
        rendered.join(" ")
    }

    match prompt {
        async_openai::types::Prompt::String(text) => text.to_owned(),
        async_openai::types::Prompt::StringArray(parts) => parts.join(" "),
        async_openai::types::Prompt::IntegerArray(ids) => render_ids(ids),
        async_openai::types::Prompt::ArrayOfIntegerArray(batches) => {
            let rendered: Vec<String> = batches.iter().map(|ids| render_ids(ids)).collect();
            rendered.join(" | ")
        }
    }
}

impl NvExtProvider for NvCreateCompletionRequest {
84
85
86
87
88
89
90
91
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
        if let Some(nvext) = self.nvext.as_ref() {
            if let Some(use_raw_prompt) = nvext.use_raw_prompt {
                if use_raw_prompt {
92
                    return Some(prompt_to_string(&self.inner.prompt));
93
94
95
96
97
98
99
                }
            }
        }
        None
    }
}

/// Reads annotations from the request's optional `nvext` block.
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns a clone of `nvext.annotations`, or `None` when either `nvext`
    /// or its annotation list is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Reports whether `annotation` appears (exact string match) in
    /// `nvext.annotations`. Returns `false` when there is no list to search.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare as `&str` so no temporary `String` is allocated per lookup.
            .map_or(false, |annotations| {
                annotations.iter().any(|entry| entry.as_str() == annotation)
            })
    }
}
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
117
    fn get_temperature(&self) -> Option<f32> {
118
        self.inner.temperature
119
120
121
    }

    fn get_top_p(&self) -> Option<f32> {
122
        self.inner.top_p
123
124
125
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
126
        self.inner.frequency_penalty
127
128
129
    }

    fn get_presence_penalty(&self) -> Option<f32> {
130
        self.inner.presence_penalty
131
132
133
134
135
136
137
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

/// Resolves guided-decoding settings for a completion request.
///
/// Every `get_guided_*` getter looks at the flattened `common` fields first and
/// falls back to the same-named field on `nvext` only when `common` has no value.
impl CommonExtProvider for NvCreateCompletionRequest {
    /// Always available: `common` is a plain (non-optional) field of the request.
    fn common_ext(&self) -> Option<&CommonExt> {
        let common = &self.common;
        Some(common)
    }

    /// Guided Decoding Options
    ///
    /// Borrowed `guided_json` value: `common` first, then `nvext`.
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
        match self.common.guided_json.as_ref() {
            Some(schema) => Some(schema),
            None => self.nvext.as_ref()?.guided_json.as_ref(),
        }
    }

    /// Cloned `guided_regex`: `common` first, then `nvext`.
    fn get_guided_regex(&self) -> Option<String> {
        if let Some(regex) = &self.common.guided_regex {
            return Some(regex.clone());
        }
        self.nvext.as_ref()?.guided_regex.clone()
    }

    /// Cloned `guided_grammar`: `common` first, then `nvext`.
    fn get_guided_grammar(&self) -> Option<String> {
        if let Some(grammar) = &self.common.guided_grammar {
            return Some(grammar.clone());
        }
        self.nvext.as_ref()?.guided_grammar.clone()
    }

    /// Cloned `guided_choice` list: `common` first, then `nvext`.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        if let Some(choices) = &self.common.guided_choice {
            return Some(choices.clone());
        }
        self.nvext.as_ref()?.guided_choice.clone()
    }

    /// Cloned `guided_decoding_backend`: `common` first, then `nvext`.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        if let Some(backend) = &self.common.guided_decoding_backend {
            return Some(backend.clone());
        }
        self.nvext.as_ref()?.guided_decoding_backend.clone()
    }
}

impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
182
    fn get_max_tokens(&self) -> Option<u32> {
183
        self.inner.max_tokens
184
185
    }

Paul Hendricks's avatar
Paul Hendricks committed
186
    fn get_min_tokens(&self) -> Option<u32> {
187
        self.common.min_tokens
188
189
190
    }

    fn get_stop(&self) -> Option<Vec<String>> {
191
        None
192
193
194
195
196
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
197
198
199
200

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
}

/// Holds the response fields that `make_response` copies into every
/// completion response it builds. Constructed through the derived builder.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name reported in responses. Has no builder default; the setter
    /// accepts anything convertible `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint; uses the type's `Default` (`None`) when not set.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type tag; defaults to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation timestamp; defaults to the current UTC time when the builder runs.
    // NOTE(review): `as u32` truncates silently if the timestamp does not fit
    // in 32 bits — confirm this is acceptable.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
228
        choice: async_openai::types::Choice,
229
        usage: Option<async_openai::types::CompletionUsage>,
230
231
    ) -> NvCreateCompletionResponse {
        let inner = async_openai::types::CreateCompletionResponse {
232
233
            id: self.id.clone(),
            object: self.object.clone(),
234
            created: self.created,
235
236
237
238
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
239
240
        };
        NvCreateCompletionResponse { inner }
241
242
243
244
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
245
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
246
247
    type Error = anyhow::Error;

248
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

270
        if request.inner.suffix.is_some() {
271
272
273
274
275
276
277
278
279
280
281
282
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

        let prompt = common::PromptType::Completion(common::CompletionContext {
283
            prompt: prompt_to_string(&request.inner.prompt),
284
285
286
287
288
289
290
291
292
293
294
295
296
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
            mdc_sum: None,
            annotations: None,
        })
    }
}

impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choice {
298
299
300
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
301
302
303
304
305
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

306
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
307
        // so we're fairly safe knowing we won't generate that many Choices
308
309
310
311
312
313
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
314
315
316
317
318
319
320
321
322
323
324
325

        // TODO handle aggregating logprobs
        let logprobs = None;

        let finish_reason: Option<async_openai::types::CompletionFinishReason> =
            response.delta.finish_reason.map(Into::into);

        let choice = async_openai::types::Choice {
            text,
            index,
            logprobs,
            finish_reason,
326
327
328
329
330
        };

        Ok(choice)
    }
}

/// Field-by-field validation of a completion request.
///
/// The checks run in the order listed and stop at the first failure, so the
/// returned error always belongs to the earliest invalid field.
impl ValidateRequest for NvCreateCompletionRequest {
    fn validate(&self) -> Result<(), anyhow::Error> {
        // All validated fields live on the wrapped OpenAI request.
        let req = &self.inner;

        validate::validate_model(&req.model)?;
        validate::validate_prompt(&req.prompt)?;
        validate::validate_suffix(req.suffix.as_deref())?;
        validate::validate_max_tokens(req.max_tokens)?;
        validate::validate_temperature(req.temperature)?;
        validate::validate_top_p(req.top_p)?;
        validate::validate_n(req.n)?;
        // `stream` and `stream_options`: nothing to validate.
        validate::validate_logprobs(req.logprobs)?;
        // `echo`: nothing to validate.
        validate::validate_stop(&req.stop)?;
        validate::validate_presence_penalty(req.presence_penalty)?;
        validate::validate_frequency_penalty(req.frequency_penalty)?;
        // `best_of` is checked together with `n`.
        validate::validate_best_of(req.best_of, req.n)?;
        validate::validate_logit_bias(&req.logit_bias)?;
        validate::validate_user(req.user.as_deref())?;
        // `seed`: nothing to validate.

        Ok(())
    }
}