completions.rs 12.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use derive_builder::Builder;
17
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
18
19
20
use serde::{Deserialize, Serialize};
use validator::Validate;

21
22
use crate::engines::ValidateRequest;

23
use super::{
24
25
    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
    OpenAIStopConditionsProvider,
Greg Clark's avatar
Greg Clark committed
26
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
27
28
29
    common_ext::{
        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
    },
30
    nvext::{NvExt, NvExtProvider},
31
    validate,
32
33
};

34
35
36
37
38
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
39

40
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
41
pub struct NvCreateCompletionRequest {
42
    #[serde(flatten)]
43
    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
44

45
46
47
    #[serde(flatten)]
    pub common: CommonExt,

48
49
50
51
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

52
53
54
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    #[serde(flatten)]
55
    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
56
57
}

58
/// Lets a completion `Choice` be used wherever a `ContentProvider` is expected.
impl ContentProvider for dynamo_async_openai::types::Choice {
    /// Returns an owned copy of the choice's `text`.
    fn content(&self) -> String {
        String::from(self.text.as_str())
    }
}

64
/// Flattens any `Prompt` variant into a single `String`.
///
/// * `String` - returned as an owned copy.
/// * `StringArray` - entries joined with a single space.
/// * `IntegerArray` - each integer rendered in decimal, joined with a single space.
/// * `ArrayOfIntegerArray` - each inner array rendered like `IntegerArray`,
///   with the rendered arrays separated by `" | "`.
pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
    use dynamo_async_openai::types::Prompt;

    /// Renders every id with `to_string` and joins the results with one space.
    fn join_ids<T: ToString>(ids: &[T]) -> String {
        let rendered: Vec<String> = ids.iter().map(ToString::to_string).collect();
        rendered.join(" ")
    }

    match prompt {
        Prompt::String(text) => text.clone(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(ids) => join_ids(ids),
        Prompt::ArrayOfIntegerArray(rows) => {
            let rendered_rows: Vec<String> = rows.iter().map(|row| join_ids(row)).collect();
            // A distinct delimiter keeps the boundaries between inner arrays visible.
            rendered_rows.join(" | ")
        }
    }
}

87
impl NvExtProvider for NvCreateCompletionRequest {
88
89
90
91
92
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
93
94
95
96
97
        if let Some(nvext) = self.nvext.as_ref()
            && let Some(use_raw_prompt) = nvext.use_raw_prompt
            && use_raw_prompt
        {
            return Some(prompt_to_string(&self.inner.prompt));
98
99
100
101
102
        }
        None
    }
}

103
/// Exposes the annotations stored in the request's `nvext` block.
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns an owned copy of `nvext.annotations`, or `None` when either
    /// `nvext` or its `annotations` list is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Returns `true` when `annotation` is one of the entries in
    /// `nvext.annotations`, and `false` when it is not listed or when no
    /// annotations are present.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare as `&str` so the lookup does not allocate a temporary `String`.
            .is_some_and(|annotations| {
                annotations
                    .iter()
                    .any(|entry| entry.as_str() == annotation)
            })
    }
}
118

119
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
120
    fn get_temperature(&self) -> Option<f32> {
121
        self.inner.temperature
122
123
124
    }

    fn get_top_p(&self) -> Option<f32> {
125
        self.inner.top_p
126
127
128
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
129
        self.inner.frequency_penalty
130
131
132
    }

    fn get_presence_penalty(&self) -> Option<f32> {
133
        self.inner.presence_penalty
134
135
136
137
138
139
140
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
}

141
142
143
144
145
146
147
impl CommonExtProvider for NvCreateCompletionRequest {
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Guided Decoding Options
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
148
149
150
151
152
153
        // Note: This one needs special handling since it returns a reference
        if let Some(nvext) = &self.nvext
            && nvext.guided_json.is_some()
        {
            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
        }
154
155
156
157
158
159
160
        self.common
            .guided_json
            .as_ref()
            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
    }

    fn get_guided_regex(&self) -> Option<String> {
161
162
163
164
165
        choose_with_deprecation(
            "guided_regex",
            self.common.guided_regex.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
        )
166
167
168
    }

    fn get_guided_grammar(&self) -> Option<String> {
169
170
171
172
173
174
175
        choose_with_deprecation(
            "guided_grammar",
            self.common.guided_grammar.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.guided_grammar.as_ref()),
        )
176
177
178
    }

    fn get_guided_choice(&self) -> Option<Vec<String>> {
179
180
181
182
183
        choose_with_deprecation(
            "guided_choice",
            self.common.guided_choice.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
        )
184
185
186
    }

    fn get_guided_decoding_backend(&self) -> Option<String> {
187
188
189
        choose_with_deprecation(
            "guided_decoding_backend",
            self.common.guided_decoding_backend.as_ref(),
190
191
            self.nvext
                .as_ref()
192
193
                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
        )
194
195
196
    }
}

197
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
198
    fn get_max_tokens(&self) -> Option<u32> {
199
        self.inner.max_tokens
200
201
    }

Paul Hendricks's avatar
Paul Hendricks committed
202
    fn get_min_tokens(&self) -> Option<u32> {
203
        self.common.min_tokens
204
205
206
    }

    fn get_stop(&self) -> Option<Vec<String>> {
207
        None
208
209
210
211
212
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
213
214
215
216

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
217
218
219
220
221
222
223
224
225
226

    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
    /// CommonExt (root-level) takes precedence over NvExt.
    fn get_ignore_eos(&self) -> Option<bool> {
        choose_with_deprecation(
            "ignore_eos",
            self.get_common_ignore_eos().as_ref(),
            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
        )
    }
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
}

/// Holds the fields shared by the responses built through `make_response`.
///
/// Constructed through the builder generated by `derive_builder`. Every field
/// except `model` has a builder default.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response; the setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional fingerprint copied into each response; uses the type's default when unset.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object tag; defaults to `"text_completion"`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation time; defaults to the current UTC Unix timestamp cast to `u32`.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
254
255
        choice: dynamo_async_openai::types::Choice,
        usage: Option<dynamo_async_openai::types::CompletionUsage>,
256
    ) -> NvCreateCompletionResponse {
257
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
258
259
            id: self.id.clone(),
            object: self.object.clone(),
260
            created: self.created,
261
262
263
264
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
265
266
        };
        NvCreateCompletionResponse { inner }
267
268
269
270
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
271
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
272
273
    type Error = anyhow::Error;

274
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

296
        if request.inner.suffix.is_some() {
297
298
299
300
301
302
303
304
305
306
307
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

Greg Clark's avatar
Greg Clark committed
308
309
310
311
        let output_options = request
            .extract_output_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;

312
        let prompt = common::PromptType::Completion(common::CompletionContext {
313
            prompt: prompt_to_string(&request.inner.prompt),
314
315
316
317
318
319
320
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
Greg Clark's avatar
Greg Clark committed
321
            output_options,
322
323
324
325
326
327
            mdc_sum: None,
            annotations: None,
        })
    }
}

328
impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
329
330
331
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
332
333
334
335
336
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

337
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
338
        // so we're fairly safe knowing we won't generate that many Choices
339
340
341
342
343
344
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
345
346
347
348

        // TODO handle aggregating logprobs
        let logprobs = None;

349
        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
350
351
            response.delta.finish_reason.map(Into::into);

352
        let choice = dynamo_async_openai::types::Choice {
353
354
355
356
            text,
            index,
            logprobs,
            finish_reason,
357
358
359
360
361
        };

        Ok(choice)
    }
}
362

Greg Clark's avatar
Greg Clark committed
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
/// Output-option accessors for a completion request.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Value of `inner.logprobs`, cast to `u32`.
    fn get_logprobs(&self) -> Option<u32> {
        self.inner.logprobs.map(|count| count as u32)
    }

    /// `Some(1)` when `inner.echo` is `Some(true)`; `None` when it is unset or `false`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            _ => None,
        }
    }

    /// Always `None`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None`.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}

383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data.
///
/// Only fields of `inner` are checked here; `common` and `nvext` are not
/// touched by this impl.
impl ValidateRequest for NvCreateCompletionRequest {
    /// Runs the per-field validators from the `validate` module in the order
    /// written below.
    ///
    /// # Errors
    ///
    /// Each call is followed by `?`, so the first validator that fails ends
    /// validation and its error is returned; later fields are not checked.
    fn validate(&self) -> Result<(), anyhow::Error> {
        validate::validate_model(&self.inner.model)?;
        validate::validate_prompt(&self.inner.prompt)?;
        validate::validate_suffix(self.inner.suffix.as_deref())?;
        validate::validate_max_tokens(self.inner.max_tokens)?;
        validate::validate_temperature(self.inner.temperature)?;
        validate::validate_top_p(self.inner.top_p)?;
        validate::validate_n(self.inner.n)?;
        // no validator is called for `stream`
        // no validator is called for `stream_options`
        validate::validate_logprobs(self.inner.logprobs)?;
        // no validator is called for `echo`
        validate::validate_stop(&self.inner.stop)?;
        validate::validate_presence_penalty(self.inner.presence_penalty)?;
        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
        // `best_of` is validated together with `n`
        validate::validate_best_of(self.inner.best_of, self.inner.n)?;
        validate::validate_logit_bias(&self.inner.logit_bias)?;
        validate::validate_user(self.inner.user.as_deref())?;
        // no validator is called for `seed`

        Ok(())
    }
}