completions.rs 13.8 KB
Newer Older
1
2
3
4
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use derive_builder::Builder;
5
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
6
7
8
use serde::{Deserialize, Serialize};
use validator::Validate;

9
10
use crate::engines::ValidateRequest;

11
use super::{
12
13
    ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
    OpenAIStopConditionsProvider,
Greg Clark's avatar
Greg Clark committed
14
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
15
16
17
    common_ext::{
        CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
    },
18
    nvext::{NvExt, NvExtProvider},
19
    validate,
20
21
};

22
23
24
25
26
mod aggregator;
mod delta;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
Biswa Panda's avatar
Biswa Panda committed
27

28
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
29
pub struct NvCreateCompletionRequest {
30
    #[serde(flatten)]
31
    pub inner: dynamo_async_openai::types::CreateCompletionRequest,
32

33
34
35
    #[serde(flatten)]
    pub common: CommonExt,

36
37
38
39
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option<NvExt>,
}

40
41
42
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
    #[serde(flatten)]
43
    pub inner: dynamo_async_openai::types::CreateCompletionResponse,
44
45
}

46
/// Exposes the generated text of a completion [`Choice`](dynamo_async_openai::types::Choice)
/// through the `ContentProvider` trait.
impl ContentProvider for dynamo_async_openai::types::Choice {
    /// Returns an owned copy of the choice's `text` field.
    fn content(&self) -> String {
        self.text.to_owned()
    }
}

52
/// Flattens an OpenAI completion prompt into a single string.
///
/// * `String` - returned as an owned copy.
/// * `StringArray` - entries joined with a single space.
/// * `IntegerArray` - each integer rendered with `to_string`, joined with a
///   single space.
/// * `ArrayOfIntegerArray` - each inner array rendered like `IntegerArray`,
///   then the rendered arrays are joined with `" | "`.
pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
    use dynamo_async_openai::types::Prompt;

    /// Renders every item with `to_string` and joins them with one space.
    fn space_joined<T: ToString>(items: &[T]) -> String {
        items
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ")
    }

    match prompt {
        Prompt::String(text) => text.clone(),
        Prompt::StringArray(parts) => parts.join(" "),
        Prompt::IntegerArray(tokens) => space_joined(tokens),
        Prompt::ArrayOfIntegerArray(rows) => rows
            .iter()
            .map(|row| space_joined(row))
            .collect::<Vec<_>>()
            .join(" | "),
    }
}

75
impl NvExtProvider for NvCreateCompletionRequest {
76
77
78
79
80
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    fn raw_prompt(&self) -> Option<String> {
81
82
83
84
85
        if let Some(nvext) = self.nvext.as_ref()
            && let Some(use_raw_prompt) = nvext.use_raw_prompt
            && use_raw_prompt
        {
            return Some(prompt_to_string(&self.inner.prompt));
86
87
88
89
90
        }
        None
    }
}

91
/// Reads annotations from the request's optional `nvext` payload.
impl AnnotationsProvider for NvCreateCompletionRequest {
    /// Returns a clone of `nvext.annotations`, or `None` when either `nvext`
    /// or its `annotations` field is absent.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Returns `true` when `nvext.annotations` is present and contains an
    /// entry equal to `annotation`; `false` in every other case.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_deref())
            // Compare as `&str` so no temporary `String` is allocated per lookup.
            .is_some_and(|annotations| {
                annotations
                    .iter()
                    .any(|candidate| candidate.as_str() == annotation)
            })
    }
}
106

107
impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
108
    fn get_temperature(&self) -> Option<f32> {
109
        self.inner.temperature
110
111
112
    }

    fn get_top_p(&self) -> Option<f32> {
113
        self.inner.top_p
114
115
116
    }

    fn get_frequency_penalty(&self) -> Option<f32> {
117
        self.inner.frequency_penalty
118
119
120
    }

    fn get_presence_penalty(&self) -> Option<f32> {
121
        self.inner.presence_penalty
122
123
124
125
126
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
127
128
129
130
131
132
133
134
135
136
137
138

    fn get_seed(&self) -> Option<i64> {
        self.inner.seed
    }

    fn get_n(&self) -> Option<u8> {
        self.inner.n
    }

    fn get_best_of(&self) -> Option<u8> {
        self.inner.best_of
    }
139
140
}

141
142
143
144
145
146
147
impl CommonExtProvider for NvCreateCompletionRequest {
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Guided Decoding Options
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
148
149
150
151
152
153
        // Note: This one needs special handling since it returns a reference
        if let Some(nvext) = &self.nvext
            && nvext.guided_json.is_some()
        {
            emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
        }
154
155
156
157
158
159
160
        self.common
            .guided_json
            .as_ref()
            .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
    }

    fn get_guided_regex(&self) -> Option<String> {
161
162
163
164
165
        choose_with_deprecation(
            "guided_regex",
            self.common.guided_regex.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
        )
166
167
168
    }

    fn get_guided_grammar(&self) -> Option<String> {
169
170
171
172
173
174
175
        choose_with_deprecation(
            "guided_grammar",
            self.common.guided_grammar.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.guided_grammar.as_ref()),
        )
176
177
178
    }

    fn get_guided_choice(&self) -> Option<Vec<String>> {
179
180
181
182
183
        choose_with_deprecation(
            "guided_choice",
            self.common.guided_choice.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
        )
184
185
186
    }

    fn get_guided_decoding_backend(&self) -> Option<String> {
187
188
189
        choose_with_deprecation(
            "guided_decoding_backend",
            self.common.guided_decoding_backend.as_ref(),
190
191
            self.nvext
                .as_ref()
192
193
                .and_then(|nv| nv.guided_decoding_backend.as_ref()),
        )
194
    }
195

196
197
198
199
200
201
202
203
204
205
    fn get_guided_whitespace_pattern(&self) -> Option<String> {
        choose_with_deprecation(
            "guided_whitespace_pattern",
            self.common.guided_whitespace_pattern.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.guided_whitespace_pattern.as_ref()),
        )
    }

206
207
208
209
210
211
212
213
    fn get_top_k(&self) -> Option<i32> {
        choose_with_deprecation(
            "top_k",
            self.common.top_k.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
        )
    }

214
215
216
217
218
219
220
221
    fn get_min_p(&self) -> Option<f32> {
        choose_with_deprecation(
            "min_p",
            self.common.min_p.as_ref(),
            self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
        )
    }

222
223
224
225
226
227
228
229
230
    fn get_repetition_penalty(&self) -> Option<f32> {
        choose_with_deprecation(
            "repetition_penalty",
            self.common.repetition_penalty.as_ref(),
            self.nvext
                .as_ref()
                .and_then(|nv| nv.repetition_penalty.as_ref()),
        )
    }
231
232
233
234

    fn get_include_stop_str_in_output(&self) -> Option<bool> {
        self.common.include_stop_str_in_output
    }
235
236
}

237
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
Paul Hendricks's avatar
Paul Hendricks committed
238
    fn get_max_tokens(&self) -> Option<u32> {
239
        self.inner.max_tokens
240
241
    }

Paul Hendricks's avatar
Paul Hendricks committed
242
    fn get_min_tokens(&self) -> Option<u32> {
243
        self.common.min_tokens
244
245
246
    }

    fn get_stop(&self) -> Option<Vec<String>> {
247
        None
248
249
250
251
252
    }

    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }
253
254
255
256

    fn get_common_ignore_eos(&self) -> Option<bool> {
        self.common.ignore_eos
    }
257
258
259
260
261
262
263
264
265
266

    /// Get the effective ignore_eos value, considering both CommonExt and NvExt.
    /// CommonExt (root-level) takes precedence over NvExt.
    fn get_ignore_eos(&self) -> Option<bool> {
        choose_with_deprecation(
            "ignore_eos",
            self.get_common_ignore_eos().as_ref(),
            NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
        )
    }
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
}

/// Holds the response-level fields that [`ResponseFactory::make_response`]
/// copies into every completion response it builds.
///
/// Constructed through the `derive_builder`-generated `ResponseFactoryBuilder`.
#[derive(Builder)]
pub struct ResponseFactory {
    /// Model name copied into each response; the setter accepts anything
    /// convertible `Into<String>`. No builder default is declared.
    #[builder(setter(into))]
    pub model: String,

    /// Optional system fingerprint copied into each response; uses the type's
    /// default when not set on the builder.
    #[builder(default)]
    pub system_fingerprint: Option<String>,

    /// Response id; when not set, defaults to `cmpl-` followed by a freshly
    /// generated v4 UUID.
    #[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
    pub id: String,

    /// Object type string; defaults to `text_completion`.
    #[builder(default = "\"text_completion\".to_string()")]
    pub object: String,

    /// Creation time; defaults to the current UTC Unix timestamp, converted
    /// with an `as u32` cast.
    #[builder(default = "chrono::Utc::now().timestamp() as u32")]
    pub created: u32,
}

impl ResponseFactory {
    pub fn builder() -> ResponseFactoryBuilder {
        ResponseFactoryBuilder::default()
    }

    pub fn make_response(
        &self,
294
295
        choice: dynamo_async_openai::types::Choice,
        usage: Option<dynamo_async_openai::types::CompletionUsage>,
296
    ) -> NvCreateCompletionResponse {
297
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
298
299
            id: self.id.clone(),
            object: self.object.clone(),
300
            created: self.created,
301
302
303
304
            model: self.model.clone(),
            choices: vec![choice],
            system_fingerprint: self.system_fingerprint.clone(),
            usage,
305
306
        };
        NvCreateCompletionResponse { inner }
307
308
309
310
    }
}

/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
311
impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
312
313
    type Error = anyhow::Error;

314
    fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
        // openai_api_rs::v1::completion::CompletionRequest {
        // NA  pub model: String,
        //     pub prompt: String,
        // **  pub suffix: Option<String>,
        //     pub max_tokens: Option<i32>,
        //     pub temperature: Option<f32>,
        //     pub top_p: Option<f32>,
        //     pub n: Option<i32>,
        //     pub stream: Option<bool>,
        //     pub logprobs: Option<i32>,
        //     pub echo: Option<bool>,
        //     pub stop: Option<Vec<String, Global>>,
        //     pub presence_penalty: Option<f32>,
        //     pub frequency_penalty: Option<f32>,
        //     pub best_of: Option<i32>,
        //     pub logit_bias: Option<HashMap<String, i32, RandomState>>,
        //     pub user: Option<String>,
        // }
        //
        // ** no supported

336
        if request.inner.suffix.is_some() {
337
338
339
340
341
342
343
344
345
346
347
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        let stop_conditions = request
            .extract_stop_conditions()
            .map_err(|e| anyhow::anyhow!("Failed to extract stop conditions: {}", e))?;

        let sampling_options = request
            .extract_sampling_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;

Greg Clark's avatar
Greg Clark committed
348
349
350
351
        let output_options = request
            .extract_output_options()
            .map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;

352
        let prompt = common::PromptType::Completion(common::CompletionContext {
353
            prompt: prompt_to_string(&request.inner.prompt),
354
355
356
357
358
359
360
            system_prompt: None,
        });

        Ok(common::CompletionRequest {
            prompt,
            stop_conditions,
            sampling_options,
Greg Clark's avatar
Greg Clark committed
361
            output_options,
362
363
364
365
366
367
            mdc_sum: None,
            annotations: None,
        })
    }
}

368
impl TryFrom<common::StreamingCompletionResponse> for dynamo_async_openai::types::Choice {
369
370
371
    type Error = anyhow::Error;

    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
372
373
374
375
376
        let text = response
            .delta
            .text
            .ok_or(anyhow::anyhow!("No text in response"))?;

377
        // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
378
        // so we're fairly safe knowing we won't generate that many Choices
379
380
381
382
383
384
        let index: u32 = response
            .delta
            .index
            .unwrap_or(0)
            .try_into()
            .expect("index exceeds u32::MAX");
385
386
387
388

        // TODO handle aggregating logprobs
        let logprobs = None;

389
        let finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason> =
390
391
            response.delta.finish_reason.map(Into::into);

392
        let choice = dynamo_async_openai::types::Choice {
393
394
395
396
            text,
            index,
            logprobs,
            finish_reason,
397
398
399
400
401
        };

        Ok(choice)
    }
}
402

Greg Clark's avatar
Greg Clark committed
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
/// Output options derived from the wrapped OpenAI request.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// `inner.logprobs` converted to `u32` with an `as` cast.
    fn get_logprobs(&self) -> Option<u32> {
        self.inner.logprobs.map(|count| count as u32)
    }

    /// `Some(1)` when `inner.echo` is `Some(true)`; `None` otherwise.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            Some(false) | None => None,
        }
    }

    /// Always `None`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Always `None`.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}

423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`.
///
/// Checks run in a fixed order and stop at the first failure, whose error is
/// returned to the caller.
impl ValidateRequest for NvCreateCompletionRequest {
    fn validate(&self) -> Result<(), anyhow::Error> {
        let req = &self.inner;

        // OpenAI request fields. `stream`, `stream_options`, `echo` and `seed`
        // have no validator.
        validate::validate_model(&req.model)?;
        validate::validate_prompt(&req.prompt)?;
        validate::validate_suffix(req.suffix.as_deref())?;
        validate::validate_max_tokens(req.max_tokens)?;
        validate::validate_temperature(req.temperature)?;
        validate::validate_top_p(req.top_p)?;
        validate::validate_n(req.n)?;
        validate::validate_logprobs(req.logprobs)?;
        validate::validate_stop(&req.stop)?;
        validate::validate_presence_penalty(req.presence_penalty)?;
        validate::validate_frequency_penalty(req.frequency_penalty)?;
        // `best_of` is checked together with `n`.
        validate::validate_best_of(req.best_of, req.n)?;
        validate::validate_logit_bias(&req.logit_bias)?;
        validate::validate_user(req.user.as_deref())?;

        // Common Ext: validated on the effective values returned by the
        // `CommonExtProvider` getters.
        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
        validate::validate_min_p(self.get_min_p())?;
        validate::validate_top_k(self.get_top_k())?;

        Ok(())
    }
}