"""
Unit tests for Jinja chat template utils.
"""

import unittest

from sglang.srt.jinja_template_utils import (
    detect_jinja_template_content_format,
    process_content_for_template_format,
)
from sglang.test.test_utils import CustomTestCase


class TestTemplateContentFormatDetection(CustomTestCase):
    """Tests for Jinja template content-format detection and content processing.

    The first group of tests feeds template source strings to
    ``detect_jinja_template_content_format`` and checks the reported format
    ("openai" or "string"). The second group passes message dicts through
    ``process_content_for_template_format`` and checks both the returned
    message and the caller-supplied lists the function fills in.
    """

    def test_detect_llama4_openai_format(self):
        """Test detection of llama4-style template (should be 'openai' format)."""
        # The template loops over message['content'] and inspects
        # content['type'], i.e. it treats content as a list of parts.
        llama4_pattern = """
{%- for message in messages %}
    {%- if message['content'] is string %}
        {{- message['content'] }}
    {%- else %}
        {%- for content in message['content'] %}
            {%- if content['type'] == 'image' %}
                {{- '<|image|>' }}
            {%- elif content['type'] == 'text' %}
                {{- content['text'] | trim }}
            {%- endif %}
        {%- endfor %}
    {%- endif %}
{%- endfor %}
        """

        result = detect_jinja_template_content_format(llama4_pattern)
        self.assertEqual(result, "openai")

    def test_detect_deepseek_string_format(self):
        """Test detection of deepseek-style template (should be 'string' format)."""
        # The template concatenates message['content'] directly with string
        # literals and never iterates over it.
        deepseek_pattern = """
{%- for message in messages %}
    {%- if message['role'] == 'user' %}
        {{- '<|User|>' + message['content'] + '<|Assistant|>' }}
    {%- endif %}
{%- endfor %}
        """

        result = detect_jinja_template_content_format(deepseek_pattern)
        self.assertEqual(result, "string")

    def test_detect_invalid_template(self):
        """Test handling of invalid template (should default to 'string')."""
        invalid_pattern = "{{{{ invalid jinja syntax }}}}"

        result = detect_jinja_template_content_format(invalid_pattern)
        self.assertEqual(result, "string")

    def test_detect_empty_template(self):
        """Test handling of empty template (should default to 'string')."""
        result = detect_jinja_template_content_format("")
        self.assertEqual(result, "string")

    def test_process_content_openai_format(self):
        """Test content processing for openai format."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Look at this image:"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
                {"type": "text", "text": "What do you see?"},
            ],
        }

        # Output lists that the function under test is expected to append to.
        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        result = process_content_for_template_format(
            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )

        # Check that image_data was extracted
        self.assertEqual(len(image_data), 1)
        self.assertEqual(image_data[0].url, "http://example.com/image.jpg")

        # Check that content was normalized
        expected_content = [
            {"type": "text", "text": "Look at this image:"},
            {"type": "image"},  # normalized from image_url
            {"type": "text", "text": "What do you see?"},
        ]
        self.assertEqual(result["content"], expected_content)
        self.assertEqual(result["role"], "user")

    def test_process_content_string_format(self):
        """Test content processing for string format."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                },
                {"type": "text", "text": "world"},
            ],
        }

        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        result = process_content_for_template_format(
            msg_dict, "string", image_data, video_data, audio_data, modalities
        )

        # For string format, should flatten to text only
        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")

        # Image data should not be extracted for string format
        self.assertEqual(len(image_data), 0)

    def test_process_content_with_audio(self):
        """Test content processing with audio content."""
        msg_dict = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Listen to this:"},
                {
                    "type": "audio_url",
                    "audio_url": {"url": "http://example.com/audio.mp3"},
                },
            ],
        }

        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        result = process_content_for_template_format(
            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )

        # Check that audio_data was extracted
        self.assertEqual(len(audio_data), 1)
        self.assertEqual(audio_data[0], "http://example.com/audio.mp3")

        # Check that content was normalized
        expected_content = [
            {"type": "text", "text": "Listen to this:"},
            {"type": "audio"},  # normalized from audio_url
        ]
        self.assertEqual(result["content"], expected_content)

    def test_process_content_already_string(self):
        """Test processing content that's already a string."""
        msg_dict = {"role": "user", "content": "Hello world"}

        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        result = process_content_for_template_format(
            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )

        # Should pass through unchanged
        self.assertEqual(result["content"], "Hello world")
        self.assertEqual(result["role"], "user")
        self.assertEqual(len(image_data), 0)

    def test_process_content_with_modalities(self):
        """Test content processing with modalities field."""
        msg_dict = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": "http://example.com/image.jpg"},
                    "modalities": ["vision"],
                }
            ],
        }

        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        # Only the side effect on ``modalities`` is checked here, so the
        # returned message is not bound to a name.
        process_content_for_template_format(
            msg_dict, "openai", image_data, video_data, audio_data, modalities
        )

        # Check that modalities was extracted
        self.assertEqual(len(modalities), 1)
        self.assertEqual(modalities[0], ["vision"])

    def test_process_content_filter_none_values(self):
        """Test that None values are filtered out of processed messages."""
        msg_dict = {
            "role": "user",
            "content": "Hello",
            "name": None,
            "tool_call_id": None,
        }

        image_data = []
        video_data = []
        audio_data = []
        modalities = []

        result = process_content_for_template_format(
            msg_dict, "string", image_data, video_data, audio_data, modalities
        )

        # None values should be filtered out
        expected_keys = {"role", "content"}
        self.assertEqual(set(result.keys()), expected_keys)


# Run the tests via unittest's command-line runner when executed as a script.
if __name__ == "__main__":
    unittest.main()