test_separate_reasoning_execution.py 7.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Tests for the execution of separate_reasoning functionality in sglang.

Usage:
python3 -m unittest test/lang/test_separate_reasoning_execution.py
"""

import threading
import time
import unittest
from unittest.mock import MagicMock, patch

from sglang import assistant, gen, separate_reasoning, user
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglGen, SglSeparateReasoning
from sglang.test.test_utils import CustomTestCase


def create_daemon_event():
    """Return a new, unset ``threading.Event``.

    The event carries no special configuration; tests in this file keep a
    reference to each one they create and set it during teardown.
    """
    return threading.Event()


class MockReasoningParser:
    """Test double for a reasoning parser that records which method was used.

    Both parse methods return a ``(reasoning, normal_text)`` tuple in which the
    input text is prefixed with a tag naming the configured ``model_type``.
    """

    def __init__(self, model_type):
        self.parse_stream_chunk_called = False
        self.parse_non_stream_called = False
        self.model_type = model_type

    def _tagged_pair(self, text):
        """Build the ``(reasoning, normal_text)`` pair for *text*."""
        # Simulate parsing by prefixing the text rather than splitting it
        return (
            f"[REASONING from {self.model_type}]: {text}",
            f"[NORMAL from {self.model_type}]: {text}",
        )

    def parse_non_stream(self, full_text):
        """Flag the non-stream call and return the tagged pair for *full_text*."""
        self.parse_non_stream_called = True
        return self._tagged_pair(full_text)

    def parse_stream_chunk(self, chunk_text):
        """Flag the stream-chunk call and return the tagged pair for *chunk_text*."""
        self.parse_stream_chunk_called = True
        return self._tagged_pair(chunk_text)


class TestSeparateReasoningExecution(CustomTestCase):
    """Tests for executing separate_reasoning with a mocked ReasoningParser."""

    def setUp(self):
        """Set up for the test."""
        super().setUp()
        # Events created during a test are tracked here so tearDown can set them
        self.events = []

    def tearDown(self):
        """Set every tracked event, then run the base-class teardown."""
        # Set all events to ensure any waiting threads are released
        for event in self.events:
            event.set()
        super().tearDown()

    @patch("sglang.srt.reasoning_parser.ReasoningParser")
    def test_execute_separate_reasoning(self, mock_parser_class):
        """Test that _execute_separate_reasoning correctly calls the ReasoningParser."""
        # Setup mock parser
        mock_parser = MockReasoningParser("deepseek-r1")
        mock_parser_class.return_value = mock_parser

        # Create a mock backend to avoid AttributeError in __del__
        mock_backend = MagicMock()

        # Create a StreamExecutor with necessary setup
        executor = StreamExecutor(
            backend=mock_backend,
            arguments={},
            default_sampling_para={},
            chat_template={
                "role_map": {"user": "user", "assistant": "assistant"}
            },  # Simple chat template
            stream=False,
            use_thread=False,
        )

        # Set up the executor with a variable and its value
        var_name = "test_var"
        reasoning_name = f"{var_name}_reasoning_content"
        var_value = "Test content"
        executor.variables = {var_name: var_value}

        # Create events and track them for cleanup
        var_event = create_daemon_event()
        reasoning_event = create_daemon_event()
        self.events.extend([var_event, reasoning_event])

        executor.variable_event = {var_name: var_event, reasoning_name: reasoning_event}
        executor.variable_event[var_name].set()  # Mark as ready

        # Set up the current role
        executor.cur_role = "assistant"
        executor.cur_role_begin_pos = 0
        executor.text_ = var_value

        # Create a gen expression and a separate_reasoning expression
        gen_expr = SglGen(var_name)
        expr = SglSeparateReasoning("deepseek-r1", expr=gen_expr)

        # Execute separate_reasoning
        executor._execute_separate_reasoning(expr)

        # Verify that the parser was created with the correct model type
        mock_parser_class.assert_called_once_with("deepseek-r1")

        # Verify that parse_non_stream was called
        self.assertTrue(mock_parser.parse_non_stream_called)

        # Verify that the variables were updated correctly
        self.assertIn(reasoning_name, executor.variables)
        self.assertEqual(
            executor.variables[reasoning_name],
            f"[REASONING from deepseek-r1]: {var_value}",
        )
        self.assertEqual(
            executor.variables[var_name], f"[NORMAL from deepseek-r1]: {var_value}"
        )

        # Verify that the variable event was set
        self.assertIn(reasoning_name, executor.variable_event)
        self.assertTrue(executor.variable_event[reasoning_name].is_set())

        # Verify that the text was updated
        self.assertEqual(executor.text_, f"[NORMAL from deepseek-r1]: {var_value}")

    @patch("sglang.srt.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_integration(self, mock_parser_class):
        """Test that the patched ReasoningParser hands out the parser for each model."""
        # Setup mock parsers for different model types
        deepseek_parser = MockReasoningParser("deepseek-r1")
        qwen_parser = MockReasoningParser("qwen3")

        # Configure the mock to return different parsers based on model type
        def get_parser(model_type):
            if model_type == "deepseek-r1":
                return deepseek_parser
            elif model_type == "qwen3":
                return qwen_parser
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser

        test_text = "This is a test"

        # Test with DeepSeek-R1 model; go through the patched class so the
        # dispatcher configured above is actually exercised
        parser = mock_parser_class("deepseek-r1")
        self.assertIs(parser, deepseek_parser)
        reasoning, normal_text = parser.parse_non_stream(test_text)

        self.assertEqual(reasoning, f"[REASONING from deepseek-r1]: {test_text}")
        self.assertEqual(normal_text, f"[NORMAL from deepseek-r1]: {test_text}")

        # Test with Qwen3 model
        parser = mock_parser_class("qwen3")
        self.assertIs(parser, qwen_parser)
        reasoning, normal_text = parser.parse_non_stream(test_text)

        self.assertEqual(reasoning, f"[REASONING from qwen3]: {test_text}")
        self.assertEqual(normal_text, f"[NORMAL from qwen3]: {test_text}")

    @patch("sglang.srt.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_invalid_model(self, mock_parser_class):
        """Test that the patched ReasoningParser raises for invalid model types."""
        # NOTE(review): the errors asserted below come from the local
        # get_parser stand-in, not from the real ReasoningParser.

        # Configure the mock to raise an error for invalid model types
        def get_parser(model_type):
            if model_type in ["deepseek-r1", "qwen3"]:
                return MockReasoningParser(model_type)
            elif model_type is None:
                raise ValueError("Model type must be specified")
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser

        with self.assertRaises(ValueError) as context:
            mock_parser_class("invalid-model")
        self.assertIn("Unsupported model type", str(context.exception))

        with self.assertRaises(ValueError) as context:
            mock_parser_class(None)
        self.assertIn("Model type must be specified", str(context.exception))


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()