// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::Validate;

/// Common extensions for OpenAI API requests that are not part of the standard OpenAI spec
/// but are commonly needed across different request types.
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
pub struct CommonExt {
    /// If true, the model will ignore the end of string token and generate to max_tokens.
    /// This field can also be specified in nvext, but the root-level value takes precedence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub ignore_eos: Option<bool>,

    /// The minimum number of tokens to generate.
    /// This is a common parameter needed across different request types.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_tokens: Option<u32>,

23
24
25
26
27
    /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub top_k: Option<i32>,

28
29
30
31
32
    /// Relative probability floor
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_p: Option<f32>,

33
34
35
36
37
38
    /// How much to penalize tokens based on how frequently they occur in the text.
    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub repetition_penalty: Option<f32>,

39
40
41
42
43
    /// include_stop_str_in_output
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub include_stop_str_in_output: Option<bool>,

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
    /// Guided Decoding Options
    /// If specified, the output will be a JSON object. Can be a string, an object, or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_json: Option<serde_json::Value>,

    /// If specified, the output will follow the regex pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_regex: Option<String>,

    /// If specified, the output will follow the context-free grammar. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_grammar: Option<String>,

    /// If specified, the output will be exactly one of the choices.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_choice: Option<Vec<String>>,

    /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_decoding_backend: Option<String>,
69
70
71
72
73
74

    /// If specified, the output will follow the whitespace pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[allow(unused)] // Not used
    pub guided_whitespace_pattern: Option<String>,
75
76
77
78
79
80
81
82

    /// Whether to skip special tokens in the decoded output.
    /// When true, special tokens (like EOS, BOS, PAD) are removed from the output text.
    /// When false, special tokens are included in the output text.
    /// Defaults to false if not specified.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub skip_special_tokens: Option<bool>,
83
84
85
86
87
88
89
90
91
92
93
94
95
96
}

impl CommonExt {
    pub fn builder() -> CommonExtBuilder {
        CommonExtBuilder::default()
    }
}

/// Trait for types that provide CommonExt fields
pub trait CommonExtProvider {
    /// Get a reference to the CommonExt struct if available
    fn common_ext(&self) -> Option<&CommonExt>;

    /// Guided Decoding Options
97
    fn get_guided_json(&self) -> Option<serde_json::Value>;
98
99
100
101
    fn get_guided_regex(&self) -> Option<String>;
    fn get_guided_grammar(&self) -> Option<String>;
    fn get_guided_choice(&self) -> Option<Vec<String>>;
    fn get_guided_decoding_backend(&self) -> Option<String>;
102
103
    #[allow(unused)] // Not used
    fn get_guided_whitespace_pattern(&self) -> Option<String>;
104
105
106

    /// Other sampling Options
    fn get_top_k(&self) -> Option<i32>;
107
    fn get_min_p(&self) -> Option<f32>;
108
    fn get_repetition_penalty(&self) -> Option<f32>;
109
    fn get_include_stop_str_in_output(&self) -> Option<bool>;
110
111
112

    /// Output Options
    fn get_skip_special_tokens(&self) -> Option<bool>;
113
114
115
116
117
118
119
120
121
122
123
124
125
}

#[cfg(test)]
mod tests {
    use super::*;

    use serde_json;

    /// A builder with no setters called yields a `CommonExt` whose fields are all `None`.
    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();
        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
        assert_eq!(common_ext.guided_whitespace_pattern, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert_eq!(common_ext.skip_special_tokens, None);
    }

    /// Values passed to the builder setters come back wrapped in `Some` on the built struct.
    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .top_k(50)
            .repetition_penalty(1.2)
            .include_stop_str_in_output(true)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .skip_special_tokens(false)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(common_ext.top_k, Some(50));
        assert_eq!(common_ext.repetition_penalty, Some(1.2));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
        assert_eq!(common_ext.skip_special_tokens, Some(false));
    }

    #[test]
    fn test_common_ext_fields() {
        // Test that CommonExt fields can be set and retrieved correctly
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .include_stop_str_in_output(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
    }

    #[test]
    fn test_validation_min_tokens() {
        // Test that min_tokens with 0 is valid.
        // Every field is listed explicitly so this literal stops compiling when a field is
        // added to `CommonExt`, prompting an update here.
        let common_ext = CommonExt {
            ignore_eos: None,
            min_tokens: Some(0), // Should be valid (min = 0)
            top_k: None,
            min_p: None,
            repetition_penalty: None,
            include_stop_str_in_output: None,
            guided_json: None,
            guided_regex: None,
            guided_grammar: None,
            guided_choice: None,
            guided_decoding_backend: None,
            guided_whitespace_pattern: None,
            skip_special_tokens: None,
        };
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_neither_specified() {
        // Test that neither ignore_eos nor min_tokens specified works
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_default() {
        // Test that Default trait implementation works correctly
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_skip_special_tokens_field() {
        // Test that skip_special_tokens can be set and retrieved
        let common_ext = CommonExt::builder()
            .skip_special_tokens(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.skip_special_tokens, Some(true));

        let common_ext = CommonExt::builder()
            .skip_special_tokens(false)
            .build()
            .unwrap();

        assert_eq!(common_ext.skip_special_tokens, Some(false));
    }

    #[test]
    fn test_skip_special_tokens_serialization() {
        // Test that skip_special_tokens can be serialized and deserialized
        let common_ext = CommonExt::builder()
            .skip_special_tokens(true)
            .build()
            .unwrap();

        let json = serde_json::to_string(&common_ext).unwrap();
        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.skip_special_tokens, Some(true));

        // Test with false value
        let common_ext = CommonExt::builder()
            .skip_special_tokens(false)
            .build()
            .unwrap();

        let json = serde_json::to_string(&common_ext).unwrap();
        let deserialized: CommonExt = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.skip_special_tokens, Some(false));

        // Test that None is not serialized (skip_serializing_if = "Option::is_none")
        let common_ext = CommonExt::builder().build().unwrap();
        let json = serde_json::to_string(&common_ext).unwrap();
        assert!(!json.contains("skip_special_tokens"));
    }
}