common_ext.rs 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::Validate;

/// Common extensions for OpenAI API requests that are not part of the standard OpenAI spec
/// but are commonly needed across different request types.
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
pub struct CommonExt {
    /// If true, the model will ignore the end of string token and generate to max_tokens.
    /// This field can also be specified in nvext, but the root-level value takes precedence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub ignore_eos: Option<bool>,

    /// The minimum number of tokens to generate.
    /// This is a common parameter needed across different request types.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_tokens: Option<u32>,

23
24
25
26
27
    /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub top_k: Option<i32>,

28
29
30
31
32
    /// Relative probability floor
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub min_p: Option<f32>,

33
34
35
36
37
38
    /// How much to penalize tokens based on how frequently they occur in the text.
    /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub repetition_penalty: Option<f32>,

39
40
41
42
43
    /// include_stop_str_in_output
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub include_stop_str_in_output: Option<bool>,

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
    /// Guided Decoding Options
    /// If specified, the output will be a JSON object. Can be a string, an object, or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_json: Option<serde_json::Value>,

    /// If specified, the output will follow the regex pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_regex: Option<String>,

    /// If specified, the output will follow the context-free grammar. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_grammar: Option<String>,

    /// If specified, the output will be exactly one of the choices.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_choice: Option<Vec<String>>,

    /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub guided_decoding_backend: Option<String>,
69
70
71
72
73
74

    /// If specified, the output will follow the whitespace pattern. Can be a string or null.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    #[allow(unused)] // Not used
    pub guided_whitespace_pattern: Option<String>,
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
}

impl CommonExt {
    pub fn builder() -> CommonExtBuilder {
        CommonExtBuilder::default()
    }
}

/// Trait for types that provide CommonExt fields
///
/// Besides direct access to the embedded [`CommonExt`], implementors expose one accessor
/// per extension parameter. How an implementor resolves each value (for example against
/// a deprecated `nvext` counterpart) is up to the implementation and not defined here.
pub trait CommonExtProvider {
    /// Get a reference to the CommonExt struct if available
    fn common_ext(&self) -> Option<&CommonExt>;

    /// Guided Decoding Options
    ///
    /// JSON constraint for guided decoding, borrowed from the provider.
    fn get_guided_json(&self) -> Option<&serde_json::Value>;
    /// Regex pattern for guided decoding, returned as an owned string.
    fn get_guided_regex(&self) -> Option<String>;
    /// Grammar for guided decoding, returned as an owned string.
    fn get_guided_grammar(&self) -> Option<String>;
    /// Allowed output choices for guided decoding, returned as an owned list.
    fn get_guided_choice(&self) -> Option<Vec<String>>;
    /// Name of the guided decoding backend, returned as an owned string.
    fn get_guided_decoding_backend(&self) -> Option<String>;
    /// Whitespace pattern for guided decoding, returned as an owned string.
    #[allow(unused)] // Not used
    fn get_guided_whitespace_pattern(&self) -> Option<String>;

    /// Other sampling Options
    ///
    /// Top-k sampling value.
    fn get_top_k(&self) -> Option<i32>;
    /// Min-p sampling value.
    fn get_min_p(&self) -> Option<f32>;
    /// Repetition penalty value.
    fn get_repetition_penalty(&self) -> Option<f32>;
    /// Value of the `include_stop_str_in_output` flag.
    fn get_include_stop_str_in_output(&self) -> Option<bool>;
}

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/// Logs a deprecation warning when a parameter was supplied through `nvext`.
///
/// * `field_name` - name of the parameter, interpolated into the message.
/// * `nvext_has_value` - whether the parameter was set inside `nvext`.
/// * `common_has_value` - whether the same parameter was also set at the top level.
///
/// Nothing is logged when `nvext_has_value` is false. When both flags are true the
/// message additionally states that the top-level value takes precedence.
pub fn emit_nvext_deprecation_warning(
    field_name: &str,
    nvext_has_value: bool,
    common_has_value: bool,
) {
    match (nvext_has_value, common_has_value) {
        // Only the deprecated location was used.
        (true, false) => tracing::warn!(
            "DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Use '{field_name}' at the top level or in 'extra_body' instead."
        ),
        // Both locations were used.
        (true, true) => tracing::warn!(
            "DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Top-level '{field_name}' takes precedence. Use '{field_name}' at the top level or in 'extra_body' instead."
        ),
        // The deprecated location was not used: nothing to report.
        (false, _) => {}
    }
}

/// Picks the effective value of a parameter that may be given both at the top level
/// (`common`) and through the deprecated `nvext` object (`nv`).
///
/// Whenever `nv` holds a value a deprecation warning is emitted for `field`. The returned
/// value is a clone of `common` if present, otherwise a clone of `nv`, otherwise `None`.
pub fn choose_with_deprecation<T: Clone>(
    field: &'static str,
    common: Option<&T>,
    nv: Option<&T>,
) -> Option<T> {
    let top_level_set = common.is_some();
    if nv.is_some() {
        emit_nvext_deprecation_warning(field, true, top_level_set);
    }
    // Select the winning reference first so that only the chosen value is cloned.
    common.or(nv).cloned()
}

133
134
135
136
137
138
139
140
141
142
143
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder on which no setter was called must leave every field unset.
    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();
        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.min_p, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
        assert_eq!(common_ext.guided_whitespace_pattern, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
    }

    /// Values passed to the builder setters must come back wrapped in `Some`.
    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .top_k(50)
            .repetition_penalty(1.2)
            .include_stop_str_in_output(true)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(common_ext.top_k, Some(50));
        assert_eq!(common_ext.repetition_penalty, Some(1.2));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
    }

    #[test]
    fn test_common_ext_fields() {
        // Test that CommonExt fields can be set and retrieved correctly
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .include_stop_str_in_output(true)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
        assert_eq!(common_ext.include_stop_str_in_output, Some(true));
    }

    #[test]
    fn test_validation_min_tokens() {
        // Zero is the smallest value `min_tokens` can hold and must pass validation.
        let common_ext = CommonExt {
            min_tokens: Some(0),
            // All remaining fields stay unset.
            ..CommonExt::default()
        };
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_neither_specified() {
        // Test that neither ignore_eos nor min_tokens specified works
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_common_ext_default() {
        // Test that Default trait implementation works correctly
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.top_k, None);
        assert_eq!(common_ext.repetition_penalty, None);
        assert_eq!(common_ext.include_stop_str_in_output, None);
        assert!(common_ext.validate().is_ok());
    }

    #[test]
    fn test_choose_with_deprecation() {
        // Common takes precedence
        let result = choose_with_deprecation(
            "test_field",
            Some(&"common_value".to_string()),
            Some(&"nvext_value".to_string()),
        );
        assert_eq!(result, Some("common_value".to_string()));

        // Fallback to nvext
        let result = choose_with_deprecation("test_field", None, Some(&"nvext_value".to_string()));
        assert_eq!(result, Some("nvext_value".to_string()));

        // Both None
        let result: Option<String> = choose_with_deprecation("test_field", None, None);
        assert_eq!(result, None);
    }
}