completions.rs 13.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
17
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
18
19
20
use serde::{Deserialize, Serialize};
use validator::Validate;

21
22
use crate::engines::ValidateRequest;

23
use super::{
24
25
    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
    OpenAIStopConditionsProvider,
Greg Clark's avatar
Greg Clark committed
26
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
27
28
29
    common_ext::{
        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
    },
30
    nvext::{NvExt, NvExtProvider},
31
    validate,
32
33
};

34
35
36
37
38
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
39

40
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
41
pub struct NvCreateCompletionRequest {
42
    #[serde(flatten)]
43
    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
44

45
46
47
    #[serde(flatten)]
    pub common: CommonExt,

48
49
50
51
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

52
53
54
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    #[serde(flatten)]
55
    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
56
57
}

58
impl ContentProvider for dynamo_async_openai::types::Choice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        self.text.to_owned()
    }
}

64
/// Renders any `Prompt` variant as a single `String`.
///
/// * `String` — returned as an owned copy.
/// * `StringArray` — elements joined with a single space.
/// * `IntegerArray` — each integer formatted with `to_string` and joined with
///   a single space.
/// * `ArrayOfIntegerArray` — each inner array rendered like `IntegerArray`,
///   then the rendered arrays joined with `" | "`.
pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
    use dynamo_async_openai::types::Prompt;

    /// Formats every id with `to_string` and joins them with single spaces.
    fn join_ids<T: ToString>(ids: &[T]) -> String {
        let rendered: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
        rendered.join(" ")
    }

    match prompt {
        Prompt::String(text) => text.clone(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(ids) => join_ids(ids),
        Prompt::ArrayOfIntegerArray(rows) => {
            let rendered: Vec<String> = rows.iter().map(|row| join_ids(row)).collect();
            // Inner arrays are kept apart by a visible delimiter.
            rendered.join(" | ")
        }
    }
}

87
impl NvExtProvider for NvCreateCompletionRequest {
88
89
90
91
92
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
93
94
95
96
97
        if let Some(nvext) = self.nvext.as_ref()
            && let Some(use_raw_prompt) = nvext.use_raw_prompt
            && use_raw_prompt
        {
            return Some(prompt_to_string(&self.inner.prompt));
98
99
100
101
102
        }
        None
    }
}

103
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns a clone of `nvext.annotations`, or `None` when either `nvext`
    /// or its `annotations` field is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Returns `true` when `nvext.annotations` is present and contains an
    /// entry equal to `annotation`; `false` otherwise.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare in place against the borrowed `&str`; building a
            // temporary `String` just to call `Vec::contains` is unnecessary.
            .is_some_and(|annotations| annotations.iter().any(|entry| entry == annotation))
    }
}
118

119
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
120
    fn get_temperature(&self) -> Option<f32> {
121
        self.inner.temperature
122
123
124
    }

    fn get_top_p(&self) -> Option<f32> {
125
        self.inner.top_p
126
127
128
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
129
        self.inner.frequency_penalty
130
131
132
    }

    fn get_presence_penalty(&self) -> Option<f32> {
133
        self.inner.presence_penalty
134
135
136
137
138
139
140
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

141
142
143
144
145
146
147
impl CommonExtProvider for NvCreateCompletionRequest {
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Guided Decoding Options
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
148
149
150
151
152
153
        // Note: This one needs special handling since it returns a reference
        if let Some(nvext) = &self.nvext
            && nvext.guided_json.is_some()
        {
            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
        }
154
155
156
157
158
159
160
        self.common
            .guided_json
            .as_ref()
            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
    }

    fn get_guided_regex(&self) -> Option<String> {
161
162
163
164
165
        choose_with_deprecation(
            "guided_regex",
            self.common.guided_regex.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
        )
166
167
168
    }

    fn get_guided_grammar(&self) -> Option<String> {
169
170
171
172
173
174
175
        choose_with_deprecation(
            "guided_grammar",
            self.common.guided_grammar.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.guided_grammar.as_ref()),
        )
176
177
178
    }

    fn get_guided_choice(&self) -> Option<Vec<String>> {
179
180
181
182
183
        choose_with_deprecation(
            "guided_choice",
            self.common.guided_choice.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
        )
184
185
186
    }

    fn get_guided_decoding_backend(&self) -> Option<String> {
187
188
189
        choose_with_deprecation(
            "guided_decoding_backend",
            self.common.guided_decoding_backend.as_ref(),
190
191
            self.nvext
                .as_ref()
192
193
                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
        )
194
    }
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

    fn get_top_k(&self) -> Option<i32> {
        choose_with_deprecation(
            "top_k",
            self.common.top_k.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
        )
    }

    fn get_repetition_penalty(&self) -> Option<f32> {
        choose_with_deprecation(
            "repetition_penalty",
            self.common.repetition_penalty.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.repetition_penalty.as_ref()),
        )
    }
213
214
215
216

    fn get_include_stop_str_in_output(&self) -> Option<bool> {
        self.common.include_stop_str_in_output
    }
217
218
}

219
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
220
    fn get_max_tokens(&self) -> Option<u32> {
221
        self.inner.max_tokens
222
223
    }

Paul Hendricks's avatar
Paul Hendricks committed
224
    fn get_min_tokens(&self) -> Option<u32> {
225
        self.common.min_tokens
226
227
228
    }

    fn get_stop(&self) -> Option<Vec<String>> {
229
        None
230
231
232
233
234
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
235
236
237
238

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
239
240
241
242
243
244
245
246
247
248

    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
    /// CommonExt (root-level) takes precedence over NvExt.
    fn get_ignore_eos(&self) -> Option<bool> {
        choose_with_deprecation(
            "ignore_eos",
            self.get_common_ignore_eos().as_ref(),
            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
        )
    }
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
}

/// Holds the response-level values that `make_response` copies into every
/// `NvCreateCompletionResponse` it builds.
///
/// Constructed through the `derive_builder`-generated `ResponseFactoryBuilder`.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response; the setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint; falls back to the type's default when not set.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type string; defaults to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation timestamp; defaults to the current UTC Unix timestamp cast to `u32`.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
276
277
        choice: dynamo_async_openai::types::Choice,
        usage: Option<dynamo_async_openai::types::CompletionUsage>,
278
    ) -> NvCreateCompletionResponse {
279
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
280
281
            id: self.id.clone(),
            object: self.object.clone(),
282
            created: self.created,
283
284
285
286
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
287
288
        };
        NvCreateCompletionResponse { inner }
289
290
291
292
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
293
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
294
295
    type Error = anyhow::Error;

296
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

318
        if request.inner.suffix.is_some() {
319
320
321
322
323
324
325
326
327
328
329
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

Greg Clark's avatar
Greg Clark committed
330
331
332
333
        let output_options = request
            .extract_output_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;

334
        let prompt = common::PromptType::Completion(common::CompletionContext {
335
            prompt: prompt_to_string(&request.inner.prompt),
336
337
338
339
340
341
342
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
Greg Clark's avatar
Greg Clark committed
343
            output_options,
344
345
346
347
348
349
            mdc_sum: None,
            annotations: None,
        })
    }
}

350
impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
351
352
353
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
354
355
356
357
358
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

359
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
360
        // so we're fairly safe knowing we won't generate that many Choices
361
362
363
364
365
366
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
367
368
369
370

        // TODO handle aggregating logprobs
        let logprobs = None;

371
        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
372
373
            response.delta.finish_reason.map(Into::into);

374
        let choice = dynamo_async_openai::types::Choice {
375
376
377
378
            text,
            index,
            logprobs,
            finish_reason,
379
380
381
382
383
        };

        Ok(choice)
    }
}
384

Greg Clark's avatar
Greg Clark committed
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Returns `inner.logprobs` converted to `u32` with an `as` cast.
    fn get_logprobs(&self) -> Option<u32> {
        // NOTE(review): `as` never fails; whether it can change the value
        // depends on the declared type of `inner.logprobs` — confirm.
        self.inner.logprobs.map(|count| count as u32)
    }

    /// Returns `Some(1)` when `inner.echo` is `Some(true)`, otherwise `None`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            _ => None,
        }
    }

    /// Always `None` for this request type.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None` for this request type.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}

405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data.
impl ValidateRequest for NvCreateCompletionRequest {
    /// Runs the per-field validators over the wrapped request, in order, and
    /// returns the first error encountered.
    ///
    /// `stream`, `stream_options`, `echo` and `seed` have no validator.
    fn validate(&self) -> Result<(), anyhow::Error> {
        let req = &self.inner;

        validate::validate_model(&req.model)?;
        validate::validate_prompt(&req.prompt)?;
        validate::validate_suffix(req.suffix.as_deref())?;
        validate::validate_max_tokens(req.max_tokens)?;
        validate::validate_temperature(req.temperature)?;
        validate::validate_top_p(req.top_p)?;
        validate::validate_n(req.n)?;
        // stream, stream_options: nothing to validate
        validate::validate_logprobs(req.logprobs)?;
        // echo: nothing to validate
        validate::validate_stop(&req.stop)?;
        validate::validate_presence_penalty(req.presence_penalty)?;
        validate::validate_frequency_penalty(req.frequency_penalty)?;
        // `best_of` is validated together with `n`.
        validate::validate_best_of(req.best_of, req.n)?;
        validate::validate_logit_bias(&req.logit_bias)?;
        validate::validate_user(req.user.as_deref())?;
        // seed: nothing to validate

        Ok(())
    }
}