oai.rs 11.2 KB
Newer Older
Biswa Panda's avatar
Biswa Panda committed
1
2
3
4
5
6
7
8
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;

use minijinja::{context, value::Value};

use crate::protocols::openai::{
9
    chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
Biswa Panda's avatar
Biswa Panda committed
10
11
12
};
use tracing;

13
14
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/// Normalizes the `parameters` schema of every tool so chat templates can rely on it.
///
/// The basic named-function schema is validated when the request is created, so no
/// other checks are repeated here. OpenAI accepts empty `parameters` at the request
/// level, so this is enforced at the template level instead: for each tool whose
/// `function.parameters` is a JSON object, a missing `"type"` is set to `"object"`
/// and a missing `"properties"` is set to `{}`. Keys that are already present are
/// never overwritten, and tools without an object-valued `parameters` are passed
/// through untouched.
///
/// `tools` is expected to be a JSON array; any other value yields an empty tool list.
/// The result is always `Some`.
fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
    // `tools` is owned, so take the array as-is instead of cloning each element.
    let mut updated_tools = match tools {
        serde_json::Value::Array(arr) => arr,
        _ => Vec::new(),
    };

    for tool in &mut updated_tools {
        if let Some(function) = tool.get_mut("function")
            && let Some(parameters) = function.get_mut("parameters")
            && let Some(schema) = parameters.as_object_mut()
        {
            // An empty object is simply the case where both keys are missing,
            // so it needs no special handling.
            schema
                .entry("type")
                .or_insert_with(|| serde_json::Value::String("object".to_string()));
            schema
                .entry("properties")
                .or_insert_with(|| serde_json::Value::Object(Default::default()));
        }
    }

    Some(Value::from_serialize(&updated_tools))
}

75
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
76
77
78
79
    fn model(&self) -> String {
        self.inner.model.clone()
    }

Biswa Panda's avatar
Biswa Panda committed
80
    fn messages(&self) -> Value {
Paul Hendricks's avatar
Paul Hendricks committed
81
        Value::from_serialize(&self.inner.messages)
Biswa Panda's avatar
Biswa Panda committed
82
83
84
    }

    fn tools(&self) -> Option<Value> {
Paul Hendricks's avatar
Paul Hendricks committed
85
        if self.inner.tools.is_none() {
Biswa Panda's avatar
Biswa Panda committed
86
87
            None
        } else {
88
89
90
91
            // Try to fix the tool schema if it is missing type and properties
            Some(may_be_fix_tool_schema(
                serde_json::to_value(&self.inner.tools).unwrap(),
            )?)
Biswa Panda's avatar
Biswa Panda committed
92
93
94
95
        }
    }

    fn tool_choice(&self) -> Option<Value> {
Paul Hendricks's avatar
Paul Hendricks committed
96
        if self.inner.tool_choice.is_none() {
Biswa Panda's avatar
Biswa Panda committed
97
98
            None
        } else {
Paul Hendricks's avatar
Paul Hendricks committed
99
            Some(Value::from_serialize(&self.inner.tool_choice))
Biswa Panda's avatar
Biswa Panda committed
100
101
102
103
        }
    }

    fn should_add_generation_prompt(&self) -> bool {
Paul Hendricks's avatar
Paul Hendricks committed
104
105
106
        if let Some(last) = self.inner.messages.last() {
            matches!(
                last,
107
                dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
Paul Hendricks's avatar
Paul Hendricks committed
108
            )
Biswa Panda's avatar
Biswa Panda committed
109
110
111
112
        } else {
            true
        }
    }
113
114
115
116

    fn extract_text(&self) -> Option<TextInput> {
        Some(TextInput::Single(String::new()))
    }
Biswa Panda's avatar
Biswa Panda committed
117
118
}

119
impl OAIChatLikeRequest for NvCreateCompletionRequest {
120
121
122
    fn model(&self) -> String {
        self.inner.model.clone()
    }
Paul Hendricks's avatar
Paul Hendricks committed
123
    fn messages(&self) -> minijinja::value::Value {
124
125
126
        let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
            dynamo_async_openai::types::ChatCompletionRequestUserMessage {
                content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
127
                    crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
Paul Hendricks's avatar
Paul Hendricks committed
128
129
130
131
132
                ),
                name: None,
            },
        );

133
        minijinja::value::Value::from_serialize(vec![message])
Biswa Panda's avatar
Biswa Panda committed
134
135
136
137
138
    }

    fn should_add_generation_prompt(&self) -> bool {
        true
    }
139
140
141

    fn prompt_input_type(&self) -> PromptInput {
        match &self.inner.prompt {
142
            dynamo_async_openai::types::Prompt::IntegerArray(_) => {
143
144
                PromptInput::Tokens(TokenInput::Single(vec![]))
            }
145
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
146
147
                PromptInput::Tokens(TokenInput::Batch(vec![]))
            }
148
            dynamo_async_openai::types::Prompt::String(_) => {
149
150
                PromptInput::Text(TextInput::Single(String::new()))
            }
151
            dynamo_async_openai::types::Prompt::StringArray(_) => {
152
153
154
155
156
157
158
                PromptInput::Text(TextInput::Batch(vec![]))
            }
        }
    }

    fn extract_tokens(&self) -> Option<TokenInput> {
        match &self.inner.prompt {
159
            dynamo_async_openai::types::Prompt::IntegerArray(tokens) => {
160
161
                Some(TokenInput::Single(tokens.clone()))
            }
162
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => {
163
164
                Some(TokenInput::Batch(arrays.clone()))
            }
165
166
167
168
169
170
            _ => None,
        }
    }

    fn extract_text(&self) -> Option<TextInput> {
        match &self.inner.prompt {
171
172
173
174
            dynamo_async_openai::types::Prompt::String(text) => {
                Some(TextInput::Single(text.to_string()))
            }
            dynamo_async_openai::types::Prompt::StringArray(texts) => {
175
176
177
178
179
                Some(TextInput::Batch(texts.to_vec()))
            }
            _ => None,
        }
    }
Biswa Panda's avatar
Biswa Panda committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
}

impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    /// Reports the formatter's stored `supports_add_generation_prompt` flag.
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

    /// Renders the prompt for `req` using the formatter's template environment.
    ///
    /// The template context exposes `messages`, `tools`, the `bos_token` /
    /// `eos_token` / `unk_token` values from the tokenizer config,
    /// `add_generation_prompt`, and everything provided by the formatter's mixins.
    /// The `tool_use` template is used when the request has tools, otherwise the
    /// `default` template.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected template is not registered in the
    /// environment or if rendering fails.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        let has_tools = tools.is_some();
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            has_tools,
            add_generation_prompt
        );

        let ctx = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        // Pick the template by name first so the lookup is written once.
        let template_name = if has_tools { "tool_use" } else { "default" };
        let tmpl = self.env.get_template(template_name)?;
        Ok(tmpl.render(&ctx)?)
    }
}
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chat completion request with a single `get_weather` tool whose
    /// `parameters` is the given JSON value, then returns the request's tools
    /// (after schema fixing) converted back to JSON for inspection.
    fn rendered_tools(parameters: serde_json::Value) -> serde_json::Value {
        // Go through a JSON string so the request is parsed exactly as it would be
        // from an incoming request body.
        let request_json = serde_json::json!({
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": parameters,
                        "strict": null
                    }
                }
            ]
        })
        .to_string();

        let request: NvCreateChatCompletionRequest = serde_json::from_str(&request_json).unwrap();
        serde_json::to_value(request.tools()).unwrap()
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_type_and_properties() {
        let tools = rendered_tools(serde_json::json!({}));
        let parameters = &tools[0]["function"]["parameters"];

        // `assert_eq!` (rather than `assert!(a == b)`) prints both sides on failure.
        assert_eq!(parameters["type"], "object");
        assert_eq!(parameters["properties"], serde_json::json!({}));
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_type() {
        let tools = rendered_tools(serde_json::json!({
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and state, e.g., 'San Francisco, CA'"
                }
            }
        }));
        let parameters = &tools[0]["function"]["parameters"];

        assert_eq!(parameters["type"], "object");
        // Existing properties must be preserved, not replaced by the empty default.
        assert_eq!(
            parameters["properties"],
            serde_json::json!({
                "location": {
                    "type": "string",
                    "description": "City and state, e.g., 'San Francisco, CA'"
                }
            })
        );
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_properties() {
        let tools = rendered_tools(serde_json::json!({"type": "object"}));
        let parameters = &tools[0]["function"]["parameters"];

        assert_eq!(parameters["properties"], serde_json::json!({}));
        assert_eq!(parameters["type"], "object");
    }
}