// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;

/// Common extensions for OpenAI API requests that are not part of the standard OpenAI spec
/// but are commonly needed across different request types.
10
#[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
11
12
13
14
15
16
17
18
19
20
21
22
23
pub struct CommonExt {
    /// If true, the model will ignore the end of string token and generate to max_tokens.
    /// This field can also be specified in nvext, but the root-level value takes precedence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub ignore_eos: Option<bool>,

    /// The minimum number of tokens to generate.
    /// This is a common parameter needed across different request types.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_tokens: Option<u32>,

24
25
26
27
28
    /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub top_k: Option<i32>,

29
30
31
32
33
    /// Relative probability floor
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_p: Option<f32>,

34
35
36
37
38
39
    /// How much to penalize tokens based on how frequently they occur in the text.
    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub repetition_penalty: Option<f32>,

40
41
42
43
44
    /// include_stop_str_in_output
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub include_stop_str_in_output: Option<bool>,

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
    /// Guided Decoding Options
    /// If specified, the output will be a JSON object. Can be a string, an object, or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_json: Option<serde_json::Value>,

    /// If specified, the output will follow the regex pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_regex: Option<String>,

    /// If specified, the output will follow the context-free grammar. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_grammar: Option<String>,

    /// If specified, the output will be exactly one of the choices.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_choice: Option<Vec<String>>,

    /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_decoding_backend: Option<String>,
70
71
72
73
74
75

    /// If specified, the output will follow the whitespace pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[allow(unused)] // Not used
    pub guided_whitespace_pattern: Option<String>,
76
77
78
79
80
81
82
83

    /// Whether to skip special tokens in the decoded output.
    /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
    /// When false, special tokens are included in the output text.
    /// Defaults to false if not specified.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub skip_special_tokens: Option<bool>,
84
85
86
87
88
89
90
91
92
93
94
95
96
97
}

impl CommonExt {
    pub fn builder() -> CommonExtBuilder {
        CommonExtBuilder::default()
    }
}

/// Trait for types that provide CommonExt fields
pub trait CommonExtProvider {
    /// Get a reference to the CommonExt struct if available
    fn common_ext(&self) -> Option<&CommonExt>;

    /// Guided Decoding Options
98
    fn get_guided_json(&self) -> Option<serde_json::Value>;
99
100
101
102
    fn get_guided_regex(&self) -> Option<String>;
    fn get_guided_grammar(&self) -> Option<String>;
    fn get_guided_choice(&self) -> Option<Vec<String>>;
    fn get_guided_decoding_backend(&self) -> Option<String>;
103
104
    #[allow(unused)] // Not used
    fn get_guided_whitespace_pattern(&self) -> Option<String>;
105
106
107

    /// Other sampling Options
    fn get_top_k(&self) -> Option<i32>;
108
    fn get_min_p(&self) -> Option<f32>;
109
    fn get_repetition_penalty(&self) -> Option<f32>;
110
    fn get_include_stop_str_in_output(&self) -> Option<bool>;
111
112
113

    /// Output Options
    fn get_skip_special_tokens(&self) -> Option<bool>;
114
115
116
117
118
119
120
121
122
123
124
125
126
}

#[cfg(test)]
mod tests {
    //! Unit tests for [`CommonExt`]: builder defaults, builder setters, `Default`,
    //! validation, and serde round-tripping of `skip_special_tokens`.

    use super::*;

    /// A builder with nothing set yields a struct whose fields are all `None`.
    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();
        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
        assert_eq!(common_ext.guided_whitespace_pattern, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert_eq!(common_ext.skip_special_tokens, None);
    }

    /// Values passed to the builder setters end up wrapped in `Some` on the built struct.
    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .top_k(50)
            .repetition_penalty(1.2)
            .include_stop_str_in_output(true)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .skip_special_tokens(false)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(common_ext.top_k, Some(50));
        assert_eq!(common_ext.repetition_penalty, Some(1.2));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
        assert_eq!(common_ext.skip_special_tokens, Some(false));
    }

    #[test]
    fn test_common_ext_fields() {
        // Test that CommonExt fields can be set and retrieved correctly
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .include_stop_str_in_output(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
    }

    #[test]
    fn test_validation_min_tokens() {
        // `min_tokens` of 0 must pass validation. No `#[validate(...)]` rule is declared on
        // the field, so this guards against a stricter rule being introduced accidentally.
        let common_ext = CommonExt {
            min_tokens: Some(0),
            // All remaining fields stay at their `None` default.
            ..Default::default()
        };
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_neither_specified() {
        // Test that neither ignore_eos nor min_tokens specified works
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_default() {
        // Test that Default trait implementation works correctly
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_skip_special_tokens_field() {
        // Test that skip_special_tokens can be set and retrieved
        let common_ext = CommonExt::builder()
            .skip_special_tokens(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.skip_special_tokens, Some(true));

        let common_ext = CommonExt::builder()
            .skip_special_tokens(false)
            .build()
            .unwrap();

        assert_eq!(common_ext.skip_special_tokens, Some(false));
    }

    #[test]
    fn test_skip_special_tokens_serialization() {
        // Test that skip_special_tokens can be serialized and deserialized
        let common_ext = CommonExt::builder()
            .skip_special_tokens(true)
            .build()
            .unwrap();

        let json = serde_json::to_string(&common_ext).unwrap();
        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.skip_special_tokens, Some(true));

        // Test with false value
        let common_ext = CommonExt::builder()
            .skip_special_tokens(false)
            .build()
            .unwrap();

        let json = serde_json::to_string(&common_ext).unwrap();
        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.skip_special_tokens, Some(false));

        // Test that None is not serialized (skip_serializing_if = "Option::is_none")
        let common_ext = CommonExt::builder().build().unwrap();
        let json = serde_json::to_string(&common_ext).unwrap();
        assert!(!json.contains("skip_special_tokens"));
    }
}