completions.rs 11.9 KB
Newer Older
1
2
3
4
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use derive_builder::Builder;
5
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
6
7
8
use serde::{Deserialize, Serialize};
use validator::Validate;

9
10
use crate::engines::ValidateRequest;

11
use super::{
12
13
    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
    OpenAIStopConditionsProvider,
Greg Clark's avatar
Greg Clark committed
14
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
15
    common_ext::{CommonExt, CommonExtProvider},
16
    nvext::{NvExt, NvExtProvider},
17
    validate,
18
19
};

20
21
22
23
24
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
25

26
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
27
pub struct NvCreateCompletionRequest {
28
    #[serde(flatten)]
29
    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
30

31
32
33
    #[serde(flatten)]
    pub common: CommonExt,

34
35
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
36
37
38
39

    // metadata - passthrough parameter without restrictions
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
40
41
}

42
43
44
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    #[serde(flatten)]
45
    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
46
47
}

48
/// Exposes the generated text of a completion choice as its content.
impl ContentProvider for dynamo_async_openai::types::Choice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        self.text.clone()
    }
}

54
/// Renders a completion prompt as a single string.
///
/// * `String` prompts are copied as-is.
/// * `StringArray` prompts are joined with single spaces.
/// * `IntegerArray` prompts have each element rendered with `to_string` and
///   joined with single spaces.
/// * `ArrayOfIntegerArray` prompts render every inner array as above and
///   separate the arrays with `" | "`.
pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
    use dynamo_async_openai::types::Prompt;

    /// Renders each element with `to_string` and joins them with single spaces.
    fn join_tokens<T: ToString>(tokens: &[T]) -> String {
        tokens
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ")
    }

    match prompt {
        Prompt::String(text) => text.clone(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(tokens) => join_tokens(tokens),
        Prompt::ArrayOfIntegerArray(batches) => batches
            .iter()
            .map(|tokens| join_tokens(tokens))
            .collect::<Vec<_>>()
            .join(" | "),
    }
}

77
impl NvExtProvider for NvCreateCompletionRequest {
78
79
80
81
82
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
83
84
85
86
87
        if let Some(nvext) = self.nvext.as_ref()
            && let Some(use_raw_prompt) = nvext.use_raw_prompt
            && use_raw_prompt
        {
            return Some(prompt_to_string(&self.inner.prompt));
88
89
90
91
92
        }
        None
    }
}

93
/// Reads request annotations from the optional NVIDIA extension block.
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns a copy of the annotations listed in `nvext`.
    ///
    /// Yields `None` when the request has no `nvext` or when `nvext` carries
    /// no annotation list.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Reports whether `annotation` is present in the `nvext` annotation list.
    ///
    /// Returns `false` when there is no `nvext` or no annotation list.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare as `&str` so no temporary `String` is allocated per lookup.
            .is_some_and(|annotations| {
                annotations
                    .iter()
                    .any(|candidate| candidate.as_str() == annotation)
            })
    }
}
108

109
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
110
    fn get_temperature(&self) -> Option<f32> {
111
        self.inner.temperature
112
113
114
    }

    fn get_top_p(&self) -> Option<f32> {
115
        self.inner.top_p
116
117
118
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
119
        self.inner.frequency_penalty
120
121
122
    }

    fn get_presence_penalty(&self) -> Option<f32> {
123
        self.inner.presence_penalty
124
125
126
127
128
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
129
130
131
132
133
134
135
136
137
138
139
140

    fn get_seed(&self) -> Option<i64> {
        self.inner.seed
    }

    fn get_n(&self) -> Option<u8> {
        self.inner.n
    }

    fn get_best_of(&self) -> Option<u8> {
        self.inner.best_of
    }
141
142
}

143
144
145
146
147
148
149
/// Exposes the common extension fields of a completion request.
///
/// All getters read from the flattened `common` field; owned values are cloned.
impl CommonExtProvider for NvCreateCompletionRequest {
    /// Borrows the common extension block; always `Some` for this request type.
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    // --- Guided decoding options ---

    /// Borrows the `guided_json` value, if set.
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
        self.common.guided_json.as_ref()
    }

    /// Returns a copy of `guided_regex`, if set.
    fn get_guided_regex(&self) -> Option<String> {
        self.common.guided_regex.clone()
    }

    /// Returns a copy of `guided_grammar`, if set.
    fn get_guided_grammar(&self) -> Option<String> {
        self.common.guided_grammar.clone()
    }

    /// Returns a copy of the `guided_choice` list, if set.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        self.common.guided_choice.clone()
    }

    /// Returns a copy of `guided_decoding_backend`, if set.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        self.common.guided_decoding_backend.clone()
    }

    /// Returns a copy of `guided_whitespace_pattern`, if set.
    fn get_guided_whitespace_pattern(&self) -> Option<String> {
        self.common.guided_whitespace_pattern.clone()
    }

    // --- Additional sampling options ---

    /// Returns `top_k`, if set.
    fn get_top_k(&self) -> Option<i32> {
        self.common.top_k
    }

    /// Returns `min_p`, if set.
    fn get_min_p(&self) -> Option<f32> {
        self.common.min_p
    }

    /// Returns `repetition_penalty`, if set.
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.common.repetition_penalty
    }

    /// Returns `include_stop_str_in_output`, if set.
    fn get_include_stop_str_in_output(&self) -> Option<bool> {
        self.common.include_stop_str_in_output
    }
}

190
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
191
    fn get_max_tokens(&self) -> Option<u32> {
192
        self.inner.max_tokens
193
194
    }

Paul Hendricks's avatar
Paul Hendricks committed
195
    fn get_min_tokens(&self) -> Option<u32> {
196
        self.common.min_tokens
197
198
199
    }

    fn get_stop(&self) -> Option<Vec<String>> {
200
        None
201
202
203
204
205
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
206
207
208
209

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
210

211
    /// Get the effective ignore_eos value from CommonExt.
212
    fn get_ignore_eos(&self) -> Option<bool> {
213
        self.common.ignore_eos
214
    }
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
}

/// Holds the response-level fields shared by every completion response it builds.
///
/// Constructed through the derived builder; only `model` has no default.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response; the setter accepts `Into<String>`.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint copied into each response.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; defaults to `cmpl-` followed by a freshly generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type tag; defaults to `text_completion`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation time; defaults to the current UTC Unix timestamp cast to `u32`.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
242
243
        choice: dynamo_async_openai::types::Choice,
        usage: Option<dynamo_async_openai::types::CompletionUsage>,
244
    ) -> NvCreateCompletionResponse {
245
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
246
247
            id: self.id.clone(),
            object: self.object.clone(),
248
            created: self.created,
249
250
251
252
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
253
254
        };
        NvCreateCompletionResponse { inner }
255
256
257
258
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
259
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
260
261
    type Error = anyhow::Error;

262
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

284
        if request.inner.suffix.is_some() {
285
286
287
288
289
290
291
292
293
294
295
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

Greg Clark's avatar
Greg Clark committed
296
297
298
299
        let output_options = request
            .extract_output_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;

300
        let prompt = common::PromptType::Completion(common::CompletionContext {
301
            prompt: prompt_to_string(&request.inner.prompt),
302
303
304
305
306
307
308
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
Greg Clark's avatar
Greg Clark committed
309
            output_options,
310
311
312
313
314
315
            mdc_sum: None,
            annotations: None,
        })
    }
}

316
impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
317
318
319
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
320
321
322
323
324
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

325
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
326
        // so we're fairly safe knowing we won't generate that many Choices
327
328
329
330
331
332
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
333
334
335
336

        // TODO handle aggregating logprobs
        let logprobs = None;

337
        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
338
339
            response.delta.finish_reason.map(Into::into);

340
        let choice = dynamo_async_openai::types::Choice {
341
342
343
344
            text,
            index,
            logprobs,
            finish_reason,
345
346
347
348
349
        };

        Ok(choice)
    }
}
350

Greg Clark's avatar
Greg Clark committed
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
/// Exposes the output options carried by a completion request.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Returns the wrapped request's `logprobs` value cast to `u32`.
    fn get_logprobs(&self) -> Option<u32> {
        self.inner.logprobs.map(|count| count as u32)
    }

    /// Returns `Some(1)` when `echo` is explicitly `true`; otherwise `None`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            _ => None,
        }
    }

    /// Not configurable on completion requests; always `None`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Not configurable on completion requests; always `None`.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}

371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
/// Field-by-field validation for [`NvCreateCompletionRequest`].
impl ValidateRequest for NvCreateCompletionRequest {
    /// Runs every field validator in order and stops at the first failure.
    ///
    /// `stream`, `stream_options`, `echo`, `seed` and `metadata` have no
    /// validator and are accepted as-is.
    fn validate(&self) -> Result<(), anyhow::Error> {
        let req = &self.inner;

        // Fields of the wrapped OpenAI request.
        validate::validate_model(&req.model)?;
        validate::validate_prompt(&req.prompt)?;
        validate::validate_suffix(req.suffix.as_deref())?;
        validate::validate_max_tokens(req.max_tokens)?;
        validate::validate_temperature(req.temperature)?;
        validate::validate_top_p(req.top_p)?;
        validate::validate_n(req.n)?;
        validate::validate_logprobs(req.logprobs)?;
        validate::validate_stop(&req.stop)?;
        validate::validate_presence_penalty(req.presence_penalty)?;
        validate::validate_frequency_penalty(req.frequency_penalty)?;
        // `best_of` is validated together with `n`.
        validate::validate_best_of(req.best_of, req.n)?;
        validate::validate_logit_bias(&req.logit_bias)?;
        validate::validate_user(req.user.as_deref())?;

        // Common extension fields.
        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
        validate::validate_min_p(self.get_min_p())?;
        validate::validate_top_k(self.get_top_k())?;

        Ok(())
    }
}