oai.rs 11.6 KB
Newer Older
Biswa Panda's avatar
Biswa Panda committed
1
2
3
4
5
6
7
8
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::*;

use minijinja::{context, value::Value};

use crate::protocols::openai::{
9
    chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
Biswa Panda's avatar
Biswa Panda committed
10
11
12
};
use tracing;

13
14
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
/// Normalize the `parameters` schema of every tool so chat templates can rely on it.
///
/// No other schema checks are needed here: the basic named-function schema is already
/// validated while creating the request. OpenAI accepts a function whose `parameters` is an
/// empty object (or an object lacking `type` / `properties`), but chat templates may
/// dereference those keys. This inserts `"type": "object"` and `"properties": {}` wherever
/// they are missing from a `function.parameters` object.
///
/// Tools whose `function.parameters` is absent or not a JSON object are left untouched. If
/// `tools` itself is not a JSON array, it is returned unchanged rather than being discarded.
///
/// Always returns `Some`.
fn may_be_fix_tool_schema(mut tools: serde_json::Value) -> Option<Value> {
    // We own `tools`, so patch it in place instead of cloning every tool.
    if let Some(arr) = tools.as_array_mut() {
        for tool in arr.iter_mut() {
            let params = tool
                .get_mut("function")
                .and_then(|function| function.get_mut("parameters"))
                .and_then(|parameters| parameters.as_object_mut());
            if let Some(params) = params {
                // `entry` only inserts when the key is missing, which covers both the
                // empty-object case and partially specified schemas.
                params
                    .entry("type")
                    .or_insert_with(|| serde_json::Value::String("object".to_string()));
                params
                    .entry("properties")
                    .or_insert_with(|| serde_json::Value::Object(Default::default()));
            }
        }
    }
    Some(Value::from_serialize(&tools))
}

75
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
76
77
78
79
    fn model(&self) -> String {
        self.inner.model.clone()
    }

Biswa Panda's avatar
Biswa Panda committed
80
    fn messages(&self) -> Value {
Paul Hendricks's avatar
Paul Hendricks committed
81
        Value::from_serialize(&self.inner.messages)
Biswa Panda's avatar
Biswa Panda committed
82
83
84
    }

    fn tools(&self) -> Option<Value> {
Paul Hendricks's avatar
Paul Hendricks committed
85
        if self.inner.tools.is_none() {
Biswa Panda's avatar
Biswa Panda committed
86
87
            None
        } else {
88
89
90
91
            // Try to fix the tool schema if it is missing type and properties
            Some(may_be_fix_tool_schema(
                serde_json::to_value(&self.inner.tools).unwrap(),
            )?)
Biswa Panda's avatar
Biswa Panda committed
92
93
94
95
        }
    }

    fn tool_choice(&self) -> Option<Value> {
Paul Hendricks's avatar
Paul Hendricks committed
96
        if self.inner.tool_choice.is_none() {
Biswa Panda's avatar
Biswa Panda committed
97
98
            None
        } else {
Paul Hendricks's avatar
Paul Hendricks committed
99
            Some(Value::from_serialize(&self.inner.tool_choice))
Biswa Panda's avatar
Biswa Panda committed
100
101
102
103
        }
    }

    fn should_add_generation_prompt(&self) -> bool {
Paul Hendricks's avatar
Paul Hendricks committed
104
105
106
        if let Some(last) = self.inner.messages.last() {
            matches!(
                last,
107
                dynamo_async_openai::types::ChatCompletionRequestMessage::User(_)
Paul Hendricks's avatar
Paul Hendricks committed
108
            )
Biswa Panda's avatar
Biswa Panda committed
109
110
111
112
        } else {
            true
        }
    }
113
114
115
116

    fn extract_text(&self) -> Option<TextInput> {
        Some(TextInput::Single(String::new()))
    }
117
118
119
120

    fn chat_template_args(&self) -> Option<&std::collections::HashMap<String, serde_json::Value>> {
        self.chat_template_args.as_ref()
    }
Biswa Panda's avatar
Biswa Panda committed
121
122
}

123
impl OAIChatLikeRequest for NvCreateCompletionRequest {
124
125
126
    fn model(&self) -> String {
        self.inner.model.clone()
    }
Paul Hendricks's avatar
Paul Hendricks committed
127
    fn messages(&self) -> minijinja::value::Value {
128
129
130
        let message = dynamo_async_openai::types::ChatCompletionRequestMessage::User(
            dynamo_async_openai::types::ChatCompletionRequestUserMessage {
                content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
131
                    crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
Paul Hendricks's avatar
Paul Hendricks committed
132
133
134
135
136
                ),
                name: None,
            },
        );

137
        minijinja::value::Value::from_serialize(vec![message])
Biswa Panda's avatar
Biswa Panda committed
138
139
140
141
142
    }

    fn should_add_generation_prompt(&self) -> bool {
        true
    }
143
144
145

    fn prompt_input_type(&self) -> PromptInput {
        match &self.inner.prompt {
146
            dynamo_async_openai::types::Prompt::IntegerArray(_) => {
147
148
                PromptInput::Tokens(TokenInput::Single(vec![]))
            }
149
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
150
151
                PromptInput::Tokens(TokenInput::Batch(vec![]))
            }
152
            dynamo_async_openai::types::Prompt::String(_) => {
153
154
                PromptInput::Text(TextInput::Single(String::new()))
            }
155
            dynamo_async_openai::types::Prompt::StringArray(_) => {
156
157
158
159
160
161
162
                PromptInput::Text(TextInput::Batch(vec![]))
            }
        }
    }

    fn extract_tokens(&self) -> Option<TokenInput> {
        match &self.inner.prompt {
163
            dynamo_async_openai::types::Prompt::IntegerArray(tokens) => {
164
165
                Some(TokenInput::Single(tokens.clone()))
            }
166
            dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => {
167
168
                Some(TokenInput::Batch(arrays.clone()))
            }
169
170
171
172
173
174
            _ => None,
        }
    }

    fn extract_text(&self) -> Option<TextInput> {
        match &self.inner.prompt {
175
176
177
178
            dynamo_async_openai::types::Prompt::String(text) => {
                Some(TextInput::Single(text.to_string()))
            }
            dynamo_async_openai::types::Prompt::StringArray(texts) => {
179
180
181
182
183
                Some(TextInput::Batch(texts.to_vec()))
            }
            _ => None,
        }
    }
Biswa Panda's avatar
Biswa Panda committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
}

impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
    /// Whether this formatter's template honours `add_generation_prompt`.
    fn supports_add_generation_prompt(&self) -> bool {
        self.supports_add_generation_prompt
    }

    /// Render `req` into a prompt string with the tokenizer's chat template.
    ///
    /// Template selection:
    /// - `"tool_use"` when the request carries tools;
    /// - `"default"` otherwise.
    ///
    /// The context exposes:
    /// - `messages`, `tools`, `add_generation_prompt`;
    /// - the tokenizer's bos/eos/unk tokens;
    /// - the formatter's mixins;
    /// - the request's `chat_template_args`, when present.
    ///
    /// # Errors
    /// Returns an error if the selected template is not registered or fails to render.
    fn render(&self, req: &dyn OAIChatLikeRequest) -> Result<String> {
        let mixins = Value::from_dyn_object(self.mixins.clone());

        let tools = req.tools();
        let uses_tools = tools.is_some();
        let template_name = if uses_tools { "tool_use" } else { "default" };
        let add_generation_prompt = req.should_add_generation_prompt();

        tracing::trace!(
            "Rendering prompt with tools: {:?}, add_generation_prompt: {}",
            uses_tools,
            add_generation_prompt
        );

        let base = context! {
            messages => req.messages(),
            tools => tools,
            bos_token => self.config.bos_tok(),
            eos_token => self.config.eos_tok(),
            unk_token => self.config.unk_tok(),
            add_generation_prompt => add_generation_prompt,
            ..mixins
        };

        // Request-specific template arguments are spread in after the base context, with the
        // intent that they take precedence. NOTE(review): confirm minijinja's `context!`
        // merge order actually lets the later spread win on duplicate keys.
        let ctx = match req.chat_template_args() {
            Some(args) => {
                let extra = Value::from_serialize(args);
                context! { ..base, ..extra }
            }
            None => base,
        };

        let template = self.env.get_template(template_name)?;
        Ok(template.render(&ctx)?)
    }
}
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a chat request and return its normalized tools as plain JSON.
    fn normalized_tools(json_str: &str) -> serde_json::Value {
        let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
        serde_json::to_value(request.tools()).unwrap()
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_type_and_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {},
                        "strict": null
                    }
                }
            ]
        }"#;

        let tools = normalized_tools(json_str);

        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::json!({})
        );
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_type() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {
                            "properties": {
                                "location": {
                                    "type": "string",
                                    "description": "City and state, e.g., 'San Francisco, CA'"
                                }
                            }
                        },
                        "strict": null
                    }
                }
            ]
        }"#;

        let tools = normalized_tools(json_str);

        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
        // Existing properties must be preserved untouched.
        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::json!({
                "location": {
                    "type": "string",
                    "description": "City and state, e.g., 'San Francisco, CA'"
                }
            })
        );
    }

    #[test]
    fn test_may_be_fix_tool_schema_missing_properties() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {"type": "object"},
                        "strict": null
                    }
                }
            ]
        }"#;

        let tools = normalized_tools(json_str);

        assert_eq!(
            tools[0]["function"]["parameters"]["properties"],
            serde_json::json!({})
        );
        assert_eq!(tools[0]["function"]["parameters"]["type"], "object");
    }

    #[test]
    fn test_may_be_fix_tool_schema_complete_schema_unchanged() {
        let json_str = r#"{
            "model": "gpt-4o",
            "messages": [],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "parameters": {
                            "type": "object",
                            "properties": {"location": {"type": "string"}},
                            "required": ["location"]
                        }
                    }
                }
            ]
        }"#;

        let tools = normalized_tools(json_str);

        assert_eq!(
            tools[0]["function"]["parameters"],
            serde_json::json!({
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"]
            })
        );
    }

    #[test]
    fn test_tools_is_none_without_tools() {
        let request: NvCreateChatCompletionRequest =
            serde_json::from_str(r#"{"model": "gpt-4o", "messages": []}"#).unwrap();

        assert!(request.tools().is_none());
    }

    #[test]
    fn test_should_add_generation_prompt_depends_on_last_message() {
        let empty: NvCreateChatCompletionRequest =
            serde_json::from_str(r#"{"model": "gpt-4o", "messages": []}"#).unwrap();
        assert!(empty.should_add_generation_prompt());

        let ends_with_user: NvCreateChatCompletionRequest = serde_json::from_str(
            r#"{"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]}"#,
        )
        .unwrap();
        assert!(ends_with_user.should_add_generation_prompt());

        let ends_with_assistant: NvCreateChatCompletionRequest = serde_json::from_str(
            r#"{"model": "gpt-4o", "messages": [
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "hello"}
            ]}"#,
        )
        .unwrap();
        assert!(!ends_with_assistant.should_add_generation_prompt());
    }
}