// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
17
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
18
19
20
use serde::{Deserialize, Serialize};
use validator::Validate;

21
22
use crate::engines::ValidateRequest;

23
use super::{
24
25
    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
    OpenAIStopConditionsProvider,
Greg Clark's avatar
Greg Clark committed
26
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
27
28
29
    common_ext::{
        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
    },
30
    nvext::{NvExt, NvExtProvider},
31
    validate,
32
33
};

34
35
36
37
38
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
39

40
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
41
pub struct NvCreateCompletionRequest {
42
    #[serde(flatten)]
43
    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
44

45
46
47
    #[serde(flatten)]
    pub common: CommonExt,

48
49
50
51
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

52
53
54
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    #[serde(flatten)]
55
    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
56
57
}

58
impl ContentProvider for dynamo_async_openai::types::Choice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        self.text.clone()
    }
}

64
/// Flattens any `Prompt` variant into a single string.
///
/// * `String` – returned as an owned copy.
/// * `StringArray` – elements joined with a single space.
/// * `IntegerArray` – every integer rendered with `to_string`, joined with a
///   single space.
/// * `ArrayOfIntegerArray` – every inner array rendered like `IntegerArray`,
///   and the resulting groups joined with `" | "`.
pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
    use dynamo_async_openai::types::Prompt;

    match prompt {
        Prompt::String(text) => text.clone(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(ids) => join_with_spaces(ids),
        Prompt::ArrayOfIntegerArray(groups) => groups
            .iter()
            // A closure (rather than a bare function path) lets `&Vec<T>`
            // coerce to `&[T]` at the call site.
            .map(|ids| join_with_spaces(ids))
            .collect::<Vec<_>>()
            .join(" | "),
    }
}

/// Renders every item with `ToString` and joins the results with one space.
fn join_with_spaces<T: ToString>(items: &[T]) -> String {
    items
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join(" ")
}

87
impl NvExtProvider for NvCreateCompletionRequest {
88
89
90
91
92
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
93
94
95
96
97
        if let Some(nvext) = self.nvext.as_ref()
            && let Some(use_raw_prompt) = nvext.use_raw_prompt
            && use_raw_prompt
        {
            return Some(prompt_to_string(&self.inner.prompt));
98
99
100
101
102
        }
        None
    }
}

103
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns an owned copy of `nvext.annotations`, or `None` when either
    /// `nvext` or its annotation list is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Returns `true` when `nvext.annotations` is present and contains an
    /// entry equal to `annotation`; `false` in every other case.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare as string slices so no temporary `String` is allocated
            // for the lookup.
            .is_some_and(|annotations| annotations.iter().any(|entry| entry == annotation))
    }
}
118

119
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
120
    fn get_temperature(&self) -> Option<f32> {
121
        self.inner.temperature
122
123
124
    }

    fn get_top_p(&self) -> Option<f32> {
125
        self.inner.top_p
126
127
128
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
129
        self.inner.frequency_penalty
130
131
132
    }

    fn get_presence_penalty(&self) -> Option<f32> {
133
        self.inner.presence_penalty
134
135
136
137
138
139
140
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

141
142
143
144
145
146
147
impl CommonExtProvider for NvCreateCompletionRequest {
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Guided Decoding Options
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
148
149
150
151
152
153
        // Note: This one needs special handling since it returns a reference
        if let Some(nvext) = &self.nvext
            && nvext.guided_json.is_some()
        {
            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
        }
154
155
156
157
158
159
160
        self.common
            .guided_json
            .as_ref()
            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
    }

    fn get_guided_regex(&self) -> Option<String> {
161
162
163
164
165
        choose_with_deprecation(
            "guided_regex",
            self.common.guided_regex.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
        )
166
167
168
    }

    fn get_guided_grammar(&self) -> Option<String> {
169
170
171
172
173
174
175
        choose_with_deprecation(
            "guided_grammar",
            self.common.guided_grammar.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.guided_grammar.as_ref()),
        )
176
177
178
    }

    fn get_guided_choice(&self) -> Option<Vec<String>> {
179
180
181
182
183
        choose_with_deprecation(
            "guided_choice",
            self.common.guided_choice.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
        )
184
185
186
    }

    fn get_guided_decoding_backend(&self) -> Option<String> {
187
188
189
        choose_with_deprecation(
            "guided_decoding_backend",
            self.common.guided_decoding_backend.as_ref(),
190
191
            self.nvext
                .as_ref()
192
193
                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
        )
194
    }
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

    fn get_top_k(&self) -> Option<i32> {
        choose_with_deprecation(
            "top_k",
            self.common.top_k.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
        )
    }

    fn get_repetition_penalty(&self) -> Option<f32> {
        choose_with_deprecation(
            "repetition_penalty",
            self.common.repetition_penalty.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.repetition_penalty.as_ref()),
        )
    }
213
214
}

215
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
216
    fn get_max_tokens(&self) -> Option<u32> {
217
        self.inner.max_tokens
218
219
    }

Paul Hendricks's avatar
Paul Hendricks committed
220
    fn get_min_tokens(&self) -> Option<u32> {
221
        self.common.min_tokens
222
223
224
    }

    fn get_stop(&self) -> Option<Vec<String>> {
225
        None
226
227
228
229
230
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
231
232
233
234

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
235
236
237
238
239
240
241
242
243
244

    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
    /// CommonExt (root-level) takes precedence over NvExt.
    fn get_ignore_eos(&self) -> Option<bool> {
        choose_with_deprecation(
            "ignore_eos",
            self.get_common_ignore_eos().as_ref(),
            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
        )
    }
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
}

/// Holds the response fields that are not per-choice, so that
/// `make_response` only needs a choice and optional usage.
///
/// Construct it through the derived builder (`ResponseFactory::builder()`).
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into every response. The builder setter accepts any
    /// value convertible into `String`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint copied into every response; defaults to `None`.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type string; defaults to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation timestamp; defaults to the current UTC Unix timestamp cast to `u32`.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
272
273
        choice: dynamo_async_openai::types::Choice,
        usage: Option<dynamo_async_openai::types::CompletionUsage>,
274
    ) -> NvCreateCompletionResponse {
275
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
276
277
            id: self.id.clone(),
            object: self.object.clone(),
278
            created: self.created,
279
280
281
282
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
283
284
        };
        NvCreateCompletionResponse { inner }
285
286
287
288
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
289
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
290
291
    type Error = anyhow::Error;

292
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

314
        if request.inner.suffix.is_some() {
315
316
317
318
319
320
321
322
323
324
325
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

Greg Clark's avatar
Greg Clark committed
326
327
328
329
        let output_options = request
            .extract_output_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;

330
        let prompt = common::PromptType::Completion(common::CompletionContext {
331
            prompt: prompt_to_string(&request.inner.prompt),
332
333
334
335
336
337
338
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
Greg Clark's avatar
Greg Clark committed
339
            output_options,
340
341
342
343
344
345
            mdc_sum: None,
            annotations: None,
        })
    }
}

346
impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
347
348
349
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
350
351
352
353
354
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

355
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
356
        // so we're fairly safe knowing we won't generate that many Choices
357
358
359
360
361
362
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
363
364
365
366

        // TODO handle aggregating logprobs
        let logprobs = None;

367
        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
368
369
            response.delta.finish_reason.map(Into::into);

370
        let choice = dynamo_async_openai::types::Choice {
371
372
373
374
            text,
            index,
            logprobs,
            finish_reason,
375
376
377
378
379
        };

        Ok(choice)
    }
}
380

Greg Clark's avatar
Greg Clark committed
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
/// Output-option accessors derived from the wrapped upstream request.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Returns `inner.logprobs` cast to `u32`.
    fn get_logprobs(&self) -> Option<u32> {
        let requested = self.inner.logprobs?;
        Some(requested as u32)
    }

    /// Returns `Some(1)` when `inner.echo` is `Some(true)`, otherwise `None`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            Some(false) | None => None,
        }
    }

    /// Always `None` for completion requests.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None` for completion requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
/// `ValidateRequest` implementation for `NvCreateCompletionRequest`.
///
/// Each field of the wrapped upstream request is passed to its validator in
/// the order written below; the first failing validator ends the check and
/// its error is returned.
impl ValidateRequest for NvCreateCompletionRequest {
    fn validate(&self) -> Result<(), anyhow::Error> {
        let req = &self.inner;

        validate::validate_model(&req.model)?;
        validate::validate_prompt(&req.prompt)?;
        validate::validate_suffix(req.suffix.as_deref())?;
        validate::validate_max_tokens(req.max_tokens)?;
        validate::validate_temperature(req.temperature)?;
        validate::validate_top_p(req.top_p)?;
        validate::validate_n(req.n)?;
        // `stream` and `stream_options` have no validator.
        validate::validate_logprobs(req.logprobs)?;
        // `echo` has no validator.
        validate::validate_stop(&req.stop)?;
        validate::validate_presence_penalty(req.presence_penalty)?;
        validate::validate_frequency_penalty(req.frequency_penalty)?;
        // `best_of` is validated together with `n`.
        validate::validate_best_of(req.best_of, req.n)?;
        validate::validate_logit_bias(&req.logit_bias)?;
        validate::validate_user(req.user.as_deref())?;
        // `seed` has no validator.

        Ok(())
    }
}