"examples/dreambooth/test_dreambooth_sd3.py" did not exist on "30e5e81d58eb9c3979c07e6626bae89c1df8c0e1"
chat_template.rs 51.8 KB
Newer Older
1
2
use std::collections::HashSet;

Nicolas Patry's avatar
Nicolas Patry committed
3
4
5
use crate::infer::InferError;
use crate::{
    ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
6
};
Nicolas Patry's avatar
Nicolas Patry committed
7
8
9
10
11
12
13
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;

/// Raise a exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
14

Nicolas Patry's avatar
Nicolas Patry committed
15
16
17
18
19
20
#[derive(Clone)]
pub(crate) struct ChatTemplate {
    template: Template<'static, 'static>,
    bos_token: Option<String>,
    eos_token: Option<String>,
    use_default_tool_template: bool,
21
    variables: HashSet<String>,
22
23
}

Nicolas Patry's avatar
Nicolas Patry committed
24
impl ChatTemplate {
25
    pub(crate) fn new(
Nicolas Patry's avatar
Nicolas Patry committed
26
27
28
        template: String,
        bos_token: Option<TokenizerConfigToken>,
        eos_token: Option<TokenizerConfigToken>,
29
    ) -> Self {
Nicolas Patry's avatar
Nicolas Patry committed
30
31
32
33
34
35
36
37
38
39
        let mut env = Box::new(Environment::new());
        // enable things like .strip() or .capitalize()
        env.set_unknown_method_callback(pycompat::unknown_method_callback);
        let template_str = template.into_boxed_str();
        env.add_function("raise_exception", raise_exception);

        // leaking env and template_str as read-only, static resources for performance.
        let template = Box::leak(env)
            .template_from_str(Box::leak(template_str))
            .unwrap();
40

41
42
43
44
45
        // get the list of variables that are used in the template
        let variables = template.undeclared_variables(true);
        // check if the `tools` variable is used in the template
        let use_default_tool_template = !variables.contains("tools");

46
        Self {
Nicolas Patry's avatar
Nicolas Patry committed
47
48
49
50
            template,
            bos_token: bos_token.map(|token| token.as_str().to_string()),
            eos_token: eos_token.map(|token| token.as_str().to_string()),
            use_default_tool_template,
51
            variables,
52
53
54
        }
    }

Nicolas Patry's avatar
Nicolas Patry committed
55
    pub(crate) fn apply(
56
        &self,
57
        guideline: Option<&str>,
Nicolas Patry's avatar
Nicolas Patry committed
58
59
60
61
62
63
64
65
        mut messages: Vec<Message>,
        grammar_with_prompt: Option<(GrammarType, String)>,
    ) -> Result<String, InferError> {
        if self.use_default_tool_template {
            if let Some(last_message) = messages.last_mut() {
                if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
                    last_message.content.push(MessageChunk::Text {
                        text: format!("\n---\n{}\n{}", tool_prompt, tools),
66
                    });
67
68
                }
            }
69
        }
70

Nicolas Patry's avatar
Nicolas Patry committed
71
72
        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();

73
74
75
76
77
        // check if guideline is expected but not provided
        if self.variables.contains("guideline") && guideline.is_none() {
            return Err(InferError::MissingTemplateVariable("guideline".to_string()));
        }

Nicolas Patry's avatar
Nicolas Patry committed
78
79
        self.template
            .render(ChatTemplateInputs {
80
                guideline,
Nicolas Patry's avatar
Nicolas Patry committed
81
82
83
84
85
86
87
88
                messages,
                bos_token: self.bos_token.as_deref(),
                eos_token: self.eos_token.as_deref(),
                add_generation_prompt: true,
                tools: None,
                tools_prompt: None,
            })
            .map_err(InferError::TemplateError)
89
90
    }
}
91
92
93
94

// tests
#[cfg(test)]
mod tests {
Nicolas Patry's avatar
Nicolas Patry committed
95
    use crate::infer::chat_template::raise_exception;
96
    use crate::infer::ChatTemplate;
drbh's avatar
drbh committed
97
98
99
    use crate::{
        ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
    };
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    use minijinja::Environment;

    // Renders a role-based template over a four-message conversation and checks
    // that the generation prompt is appended after the last message.
    #[test]
    fn test_chat_template() {
        let env = Environment::new();

        let source = r#"
        {% for message in messages %}
            {% if message['role'] == 'system' %}
                {% if message['content']%}
                    {{'### System:\n' + message['content']+'\n\n'}}
                {% endif %}
            {% elif message['role'] == 'user' %}
                {{'### User:\n' + message['content']+'\n\n'}}
            {% elif message['role'] == 'assistant' %}
                {{'### Assistant:\n'  + message['content']}}
            {% endif %}
            {% if loop.last and add_generation_prompt %}
                {{ '### Assistant:\n' }}
            {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source.lines().map(str::trim).collect::<String>();

        let tmpl = env.template_from_str(&source).unwrap();

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                TextMessage {
                    role: "user".to_string(),
                    content: "Hi!".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "Hello how can I help?".to_string(),
                },
                TextMessage {
                    role: "user".to_string(),
                    content: "What is Deep Learning?".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "magic!".to_string(),
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
            ..Default::default()
        };

        let result = tmpl.render(chat_template_inputs).unwrap();

        assert_eq!(
            result,
            "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
        );
    }

    // Two consecutive `user` messages break the alternation rule enforced by the
    // template, so rendering must fail with the message given to `raise_exception`.
    #[test]
    fn test_chat_template_invalid_with_raise() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {{ bos_token }}
        {% for message in messages %}
        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
        {% endif %}
        {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
        {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token}}
        {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
        {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source.lines().map(str::trim).collect::<String>();

        let tmpl = env.template_from_str(&source).unwrap();

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                TextMessage {
                    role: "user".to_string(),
                    content: "Hi!".to_string(),
                },
                TextMessage {
                    role: "user".to_string(),
                    content: "Hi again!".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "Hello how can I help?".to_string(),
                },
                TextMessage {
                    role: "user".to_string(),
                    content: "What is Deep Learning?".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "magic!".to_string(),
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
            ..Default::default()
        };

        let error = tmpl
            .render(chat_template_inputs)
            .expect_err("Should have failed");

        assert_eq!(
            error.detail().unwrap(),
            "Conversation roles must alternate user/assistant/user/assistant/..."
        );
    }

    // A properly alternating conversation renders without hitting either
    // `raise_exception` call in the template.
    #[test]
    fn test_chat_template_valid_with_raise() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {{ bos_token }}
        {% for message in messages %}
        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
        {% endif %}
        {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
        {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token}}
        {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
        {% endif %}
        {% endfor %}"#;

        // trim all the whitespace
        let source = source.lines().map(str::trim).collect::<String>();

        let tmpl = env.template_from_str(&source).unwrap();

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                TextMessage {
                    role: "user".to_string(),
                    content: "Hi!".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "Hello how can I help?".to_string(),
                },
                TextMessage {
                    role: "user".to_string(),
                    content: "What is Deep Learning?".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "magic!".to_string(),
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
            ..Default::default()
        };

        let result = tmpl.render(chat_template_inputs).unwrap();
        assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
    }
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316

    // With `add_generation_prompt` set, the template must finish with an opening
    // assistant turn after all messages have been rendered.
    #[test]
    fn test_chat_template_valid_with_add_generation_prompt() {
        let mut env = Environment::new();
        env.add_function("raise_exception", raise_exception);

        let source = r#"
        {% for message in messages %}
        {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
        {% endfor %}
        {% if add_generation_prompt %}
            {{ '<|im_start|>assistant\n' }}
        {% endif %}"#;

        // trim all the whitespace
        let source = source.lines().map(str::trim).collect::<String>();

        let tmpl = env.template_from_str(&source).unwrap();

        let chat_template_inputs = ChatTemplateInputs {
            messages: vec![
                TextMessage {
                    role: "user".to_string(),
                    content: "Hi!".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "Hello how can I help?".to_string(),
                },
                TextMessage {
                    role: "user".to_string(),
                    content: "What is Deep Learning?".to_string(),
                },
                TextMessage {
                    role: "assistant".to_string(),
                    content: "magic!".to_string(),
                },
            ],
            bos_token: Some("[BOS]"),
            eos_token: Some("[EOS]"),
            add_generation_prompt: true,
            ..Default::default()
        };

        let result = tmpl.render(chat_template_inputs).unwrap();
        assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
    }
343
344
345
346
347
348
349
350
351
352
353

    // One table-driven case: a template, the inputs to render it with, and the
    // output the rendering is expected to produce.
    struct ChatTemplateTestItem {
        // Label identifying the case (e.g. the model the template comes from).
        name: &'static str,
        // Template source compiled with `Environment::template_from_str`.
        chat_template: &'static str,
        // Inputs handed to `render`.
        input: ChatTemplateInputs<'static>,
        // Expected rendered output, compared with `assert_eq!`.
        target: &'static str,
    }

    #[test]
    fn test_many_chat_templates() {
        let example_chat = vec![
Nicolas Patry's avatar
Nicolas Patry committed
354
            TextMessage {
355
                role: "user".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
356
                content: "Hello, how are you?".to_string(),
357
            },
Nicolas Patry's avatar
Nicolas Patry committed
358
            TextMessage {
359
                role: "assistant".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
360
                content: "I'm doing great. How can I help you today?".to_string(),
361
            },
Nicolas Patry's avatar
Nicolas Patry committed
362
            TextMessage {
363
                role: "user".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
364
                content: "I'd like to show off how chat templating works!".to_string(),
365
366
367
            },
        ];

Nicolas Patry's avatar
Nicolas Patry committed
368
        let example_chat_with_system = [TextMessage {
369
            role: "system".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
370
371
            content: "You are a friendly chatbot who always responds in the style of a pirate"
                .to_string(),
372
373
374
375
376
377
378
379
380
381
        }]
        .iter()
        .chain(&example_chat)
        .cloned()
        .collect::<Vec<_>>();

        let test_default_templates = vec![
            ChatTemplateTestItem {
                name: "_base",
                chat_template: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
382
                input: ChatTemplateInputs {
383
384
385
386
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some(""),
387
                    ..Default::default()
388
389
390
391
392
393
                },
                target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "blenderbot",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}",
394
                input: ChatTemplateInputs {
395
396
397
398
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
399
                    ..Default::default()
400
401
402
403
404
405
                },
                target: " Hello, how are you?  I'm doing great. How can I help you today?   I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "blenderbot_small",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}",
406
                input: ChatTemplateInputs {
407
408
409
410
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
411
                    ..Default::default()
412
413
414
415
416
417
                },
                target: " Hello, how are you?  I'm doing great. How can I help you today?   I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "bloom",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
418
                input: ChatTemplateInputs {
419
420
421
422
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
423
                    ..Default::default()
424
425
426
427
428
429
                },
                target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "gpt_neox",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
430
                input: ChatTemplateInputs {
431
432
433
434
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
435
                    ..Default::default()
436
437
438
439
440
441
                },
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
            },
            ChatTemplateTestItem {
                name: "gpt2",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
442
443
444
445
446
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
447
                    ..Default::default()
448
                },
449
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
450
451
452
453
454
            },
            ChatTemplateTestItem {
                name: "llama",
                // NOTE: the `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content | trim + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
455
456
457
458
459
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: true,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
460
                    ..Default::default()
461
                },
462
                target: "<s>[INST] <<SYS>>\nYou are a friendly chatbot who always responds in the style of a pirate\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
463
464
465
466
            },
            ChatTemplateTestItem {
                name: "whisper",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
467
468
469
470
471
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: true,
                    bos_token: Some(""),
                    eos_token: Some("<|endoftext|>"),
472
                    ..Default::default()
473
                },
474
475
                target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>",
            },
476
477
478
479
480
481
482
483
484
485
486
487
        ];

        #[allow(unused_variables)] // name is unused
        for ChatTemplateTestItem {
            name,
            chat_template,
            input,
            target,
        } in test_default_templates
        {
            let mut env = Environment::new();
            env.add_function("raise_exception", raise_exception);
488
            let tmpl = env.template_from_str(chat_template);
489
490
491
492
493
494
495
496
497
498
499
500
            let result = tmpl.unwrap().render(input).unwrap();
            assert_eq!(result, target);
        }

        let test_custom_templates = vec![
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: false,
                    bos_token: Some(""),
501
                    eos_token: Some("</s>"),
502
                    ..Default::default()
503
504
505
506
507
508
509
510
                },
                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHello, how are you?</s><|assistant|>\nI'm doing great. How can I help you today?</s><|user|>\nI'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: vec![
OlivierDehaene's avatar
OlivierDehaene committed
511
                        TextMessage {
512
                            role: "system".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
513
                            content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
514
                        },
OlivierDehaene's avatar
OlivierDehaene committed
515
                        TextMessage {
516
                            role: "user".to_string(),
Nicolas Patry's avatar
Nicolas Patry committed
517
                            content: "How many helicopters can a human eat in one sitting?".to_string(),
518
519
520
521
522
                        },
                    ],
                    add_generation_prompt: true,
                    bos_token: Some(""),
                    eos_token: Some("</s>"),
523
                    ..Default::default()
524
                },
525
                target: "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\nHow many helicopters can a human eat in one sitting?</s><|assistant|>",
526
527
528
529
530
531
532
533
534
            },
            ChatTemplateTestItem {
                name: "HuggingFaceH4/zephyr-7b-gemma-v0.1",
                chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<bos>"),
                    eos_token: Some("<eos>"),
535
                    ..Default::default()
536
537
538
539
540
541
542
543
544
545
546
                },
                target: "<bos><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "mistralai/Mistral-7B-Instruct-v0.1",
                chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
547
                    ..Default::default()
548
                },
549
                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]",
550
551
552
553
554
555
556
557
558
            },
            ChatTemplateTestItem {
                name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
                chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
559
                    ..Default::default()
560
561
562
563
564
565
566
                },
                target: "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b",
                chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                input: ChatTemplateInputs {
567
                    messages: example_chat.clone(),
568
569
570
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
571
                    ..Default::default()
572
573
574
575
576
577
578
579
580
581
582
583
                },
                target: "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "openchat/openchat-3.5-0106",
                // `.title()` has been replaced with `| title` in the following template
                chat_template: "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
584
                    ..Default::default()
585
586
587
588
589
590
591
592
593
594
595
                },
                target: "<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>",
            },
            ChatTemplateTestItem {
                name: "upstage/SOLAR-10.7B-Instruct-v1.0",
                chat_template: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
596
                    ..Default::default()
597
598
599
600
601
602
603
604
605
606
607
608
                },
                target: "Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>",
            },
            ChatTemplateTestItem {
                name: "codellama/CodeLlama-70b-Instruct-hf",
                // NOTE: `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\n\\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\\nDestination: user\\n\\n '}}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
609
                    ..Default::default()
610
611
612
613
614
615
616
617
618
619
620
                },
                target: "<s>Source: user\n\n Hello, how are you? <step> Source: assistant\n\n I'm doing great. How can I help you today? <step> Source: user\n\n I'd like to show off how chat templating works! <step> Source: assistant\nDestination: user\n\n ",
            },
            ChatTemplateTestItem {
                name: "Deci/DeciLM-7B-instruct",
                chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\\n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
621
                    ..Default::default()
622
623
624
625
626
627
628
629
630
631
632
                },
                target: "### User:\nHello, how are you?### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "Qwen/Qwen1.5-72B-Chat",
                chat_template: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
633
                    ..Default::default()
634
635
636
637
638
639
640
641
642
643
644
                },
                target: "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "deepseek-ai/deepseek-llm-7b-chat",
                chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\n\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("<|end▁of▁sentence|>"),
645
                    ..Default::default()
646
647
648
649
650
651
652
653
654
655
656
                },
                target: "<|begin▁of▁sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end▁of▁sentence|>User: I'd like to show off how chat templating works!\n\n",
            },
            ChatTemplateTestItem {
                name: "h2oai/h2o-danube-1.8b-chat",
                chat_template: "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>'  + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
657
                    ..Default::default()
658
                },
659
                target: "<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>",
660
661
662
663
664
665
666
667
668
            },
            ChatTemplateTestItem {
                name: "internlm/internlm2-chat-7b",
                chat_template: "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
669
                    ..Default::default()
670
671
672
673
674
675
676
677
678
679
680
                },
                target: "<s><|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n",
            },
            ChatTemplateTestItem {
                name: "TheBloke/deepseek-coder-33B-instruct-AWQ",
                chat_template: "{%- set found_item = false -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set found_item = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("<|EOT|>"),
681
                    ..Default::default()
682
683
684
685
686
687
688
689
690
691
692
693
                },
                target: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n",
            },
            ChatTemplateTestItem {
                name: "ericzzz/falcon-rw-1b-chat",
                // `.strip()` has been replaced with `| trim` in the following template
                chat_template: "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] '  + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|endoftext|>"),
                    eos_token: Some("<|endoftext|>"),
694
                    ..Default::default()
695
696
697
698
699
700
701
702
703
704
705
                },
                target: "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "abacusai/Smaug-34B-v0.1",
                chat_template: "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
706
                    ..Default::default()
707
708
709
710
711
712
713
714
715
716
717
                },
                target: "Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]",
            },
            ChatTemplateTestItem {
                name: "maywell/Synatra-Mixtral-8x7B",
                chat_template: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
718
                    ..Default::default()
719
720
721
722
723
724
725
726
727
728
729
                },
                target: "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!",
            },
            ChatTemplateTestItem {
                name: "deepseek-ai/deepseek-coder-33b-instruct",
                chat_template: "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<|begin▁of▁sentence|>"),
                    eos_token: Some("</EOT>"),
730
                    ..Default::default()
731
732
733
734
                },
                target: "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n",
            },
            // NOT INCLUDED
OlivierDehaene's avatar
OlivierDehaene committed
735
            // - meetkai/functionary-medium-v3.2
736
737
738
739
740
741
742
743
744
745
            // - fireworks-ai/firefunction-v1
            // https://github
            ChatTemplateTestItem {
                name: "maywell/PiVoT-MoE",
                chat_template: "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
746
                    ..Default::default()
747
748
                },
                target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!",
749
            },
750
751
752
753
754
755
756
757
758
759
760
761
762
            ChatTemplateTestItem {
                name: "google/shieldgemma-9b",
                chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n    {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Human Question: \" + messages[-1].content }}\n    {{- \"\\n<end_of_turn>\\n\\n\" }}\n    {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n    {{- \"* \" + guideline + \"\\n\" }}\n    {{- \"\\n===\\n\\n\" }}\n    {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n    {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Human Question: \" + messages[-2].content }}\n    {{- \"\\n<end_of_turn>\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Chatbot Response: \" + messages[-1].content }}\n    {{- \"\\n<end_of_turn>\\n\\n\" }}\n    {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n    {{- \"* \" + guideline + \"\\n\" }}\n    {{- \"\\n===\\n\\n\" }}\n    {{- \"Does the Chatbot Response violate the above principle? Your answer must  start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n",
                input: ChatTemplateInputs {
                    messages: example_chat_with_system.clone(),
                    add_generation_prompt: false,
                    bos_token: Some("<s>"),
                    eos_token: Some("</s>"),
                    guideline: Some("Do not use offensive language."),
                    ..Default::default()
                },
                target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
            },
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
        ];

        #[allow(unused_variables)] // name is unused
        for ChatTemplateTestItem {
            name,
            chat_template,
            input,
            target,
        } in test_custom_templates
        {
            let mut env = Environment::new();
            env.add_function("raise_exception", raise_exception);
            // trim all the whitespace
            let chat_template = chat_template
                .lines()
                .map(|line| line.trim())
                .collect::<Vec<&str>>()
                .join("");

            let tmpl = env.template_from_str(&chat_template);
            let result = tmpl.unwrap().render(input).unwrap();
            assert_eq!(result, target);
        }
    }
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827

    /// A template that references `guideline` must fail to render when the
    /// caller supplies none, and the error must name the missing variable.
    #[test]
    fn test_chat_template_invalid_with_guideline() {
        // Template whose two branches both concatenate `guideline` into the output.
        let template_source = "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n    {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Human Question: \" + messages[-1].content }}\n    {{- \"\\n<end_of_turn>\\n\\n\" }}\n    {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n    {{- \"* \" + guideline + \"\\n\" }}\n    {{- \"\\n===\\n\\n\" }}\n    {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n    {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Human Question: \" + messages[-2].content }}\n    {{- \"\\n<end_of_turn>\\n\" }}\n    {{- \"<start_of_turn>\\n\" }}\n    {{- \"Chatbot Response: \" + messages[-1].content }}\n    {{- \"\\n<end_of_turn>\\n\\n\" }}\n    {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n    {{- \"* \" + guideline + \"\\n\" }}\n    {{- \"\\n===\\n\\n\" }}\n    {{- \"Does the Chatbot Response violate the above principle? Your answer must  start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n".to_string();
        let bos = TokenizerConfigToken::String("<s>".to_string());
        let eos = TokenizerConfigToken::String("</s>".to_string());
        let ct = ChatTemplate::new(template_source, Some(bos), Some(eos));

        // Builds an unnamed, single-text message for the given role.
        let text_message = |role: &str, text: &str| Message {
            name: None,
            role: role.to_string(),
            content: MessageContent::SingleText(text.to_string()),
        };

        // The conversation ends on a user turn, so the first template branch applies.
        let msgs: Vec<Message> = vec![
            text_message("user", "I'd like to show off how chat templating works!"),
            text_message("assistant", "I'm doing great. How can I help you today?"),
            text_message("user", "Hello, how are you?"),
        ];

        // No guideline and no grammar/tool prompt are passed.
        // NOTE(review): the first argument is presumably the guideline slot; confirm against `apply`.
        match ct.apply(None, msgs, None) {
            Err(e) => {
                // NOTE(review): "vatiable" presumably mirrors the error's Display text
                // defined elsewhere; keep both in sync if that spelling is ever fixed.
                assert_eq!(e.to_string(), "Missing template vatiable: guideline")
            }
            Ok(_) => panic!("Should have failed since no guideline is provided"),
        }
    }
drbh's avatar
drbh committed
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863

    /// When a grammar and tool prompt are supplied, the rendered output is
    /// expected to carry the tool prompt and the tools value inside the last
    /// user turn (see `expected` below).
    #[test]
    fn test_chat_template_with_default_tool_template() {
        // Alternating user/assistant template that never mentions tools itself.
        let template_source = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string();
        let bos = TokenizerConfigToken::String("<s>".to_string());
        let eos = TokenizerConfigToken::String("</s>".to_string());
        let ct = ChatTemplate::new(template_source, Some(bos), Some(eos));

        // Builds an unnamed, single-text message for the given role.
        let text_message = |role: &str, text: &str| Message {
            name: None,
            role: role.to_string(),
            content: MessageContent::SingleText(text.to_string()),
        };

        let msgs: Vec<Message> = vec![
            text_message("user", "I'd like to show off how chat templating works!"),
            text_message("assistant", "Great! How can I help you today?"),
            text_message("user", "Just testing"),
        ];

        // `json!("[]")` is the JSON *string* "[]", not an empty array, which is
        // why the expected output contains the quoted form `"[]"`.
        let tools = serde_json::json!("[]");
        let grammar_with_prompt = (
            GrammarType::Json(tools),
            "This default prompt will be used".to_string(),
        );

        let rendered = ct.apply(None, msgs, Some(grammar_with_prompt)).unwrap();
        let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
        assert_eq!(rendered, expected);
    }
864
}