test_chat_template.rs 5.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#[cfg(test)]
mod tests {
    use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};

    #[test]
    fn test_chat_message_helpers() {
        let system_msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, "system");
        assert_eq!(system_msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");
        assert_eq!(user_msg.content, "Hello!");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
        assert_eq!(assistant_msg.content, "Hi there!");
    }

    #[test]
    fn test_llama_style_template() {
        // Test a Llama-style chat template
        let template = r#"
{%- if messages[0]['role'] == 'system' -%}
    {%- set system_message = messages[0]['content'] -%}
    {%- set messages = messages[1:] -%}
{%- else -%}
    {%- set system_message = '' -%}
{%- endif -%}

{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}

{%- for message in messages %}
    {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}

{%- if add_generation_prompt %}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;

        let processor = ChatTemplateProcessor::new(
            template.to_string(),
            Some("<|begin_of_text|>".to_string()),
            Some("<|end_of_text|>".to_string()),
        );

        let messages = vec![
            ChatMessage::system("You are a helpful assistant"),
            ChatMessage::user("What is 2+2?"),
        ];

        let result = processor.apply_chat_template(&messages, true).unwrap();

        // Check that the result contains expected markers
        assert!(result.contains("<|begin_of_text|>"));
        assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
        assert!(result.contains("You are a helpful assistant"));
        assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
        assert!(result.contains("What is 2+2?"));
        assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
    }

    #[test]
    fn test_chatml_template() {
        // Test a ChatML-style template
        let template = r#"
{%- for message in messages %}
    {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;

        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);

        let messages = vec![
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("How are you?"),
        ];

        let result = processor.apply_chat_template(&messages, true).unwrap();

        // Check ChatML format
        assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
        assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
        assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
        assert!(result.ends_with("<|im_start|>assistant\n"));
    }

    #[test]
    fn test_template_without_generation_prompt() {
        let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;

        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);

        let messages = vec![ChatMessage::user("Test")];

        // Test without generation prompt
        let result = processor.apply_chat_template(&messages, false).unwrap();
        assert_eq!(result.trim(), "user: Test");

        // Test with generation prompt
        let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
        assert!(result_with_prompt.contains("assistant:"));
    }

    #[test]
    fn test_template_with_special_tokens() {
        let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;

        let processor = ChatTemplateProcessor::new(
            template.to_string(),
            Some("<s>".to_string()),
            Some("</s>".to_string()),
        );

        let messages = vec![ChatMessage::user("Hello")];

        let result = processor.apply_chat_template(&messages, false).unwrap();
        assert_eq!(result, "<s>Hello</s>");
    }

    #[test]
    fn test_empty_messages() {
        let template =
            r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;

        let processor = ChatTemplateProcessor::new(template.to_string(), None, None);

        let messages = vec![];
        let result = processor.apply_chat_template(&messages, false).unwrap();
        assert_eq!(result, "");
    }

    // Integration test with actual tokenizer file loading would go here
    // but requires a real tokenizer_config.json file
}