Unverified Commit 14c18d25 authored by Yudi Xue's avatar Yudi Xue Committed by GitHub
Browse files

Frontend language separate reasoning support (#6031)

parent 90bd3e32
This diff is collapsed.
...@@ -68,6 +68,7 @@ The core features include: ...@@ -68,6 +68,7 @@ The core features include:
:caption: Frontend Tutorial :caption: Frontend Tutorial
frontend/frontend.ipynb frontend/frontend.ipynb
frontend/frontend_reasoning.ipynb
frontend/choices_methods.md frontend/choices_methods.md
.. toctree:: .. toctree::
......
...@@ -15,6 +15,7 @@ from sglang.api import ( ...@@ -15,6 +15,7 @@ from sglang.api import (
get_server_info, get_server_info,
image, image,
select, select,
separate_reasoning,
set_default_backend, set_default_backend,
system, system,
system_begin, system_begin,
...@@ -54,6 +55,7 @@ __all__ = [ ...@@ -54,6 +55,7 @@ __all__ = [
"get_server_info", "get_server_info",
"image", "image",
"select", "select",
"separate_reasoning",
"set_default_backend", "set_default_backend",
"system", "system",
"system_begin", "system_begin",
......
...@@ -15,6 +15,7 @@ from sglang.lang.ir import ( ...@@ -15,6 +15,7 @@ from sglang.lang.ir import (
SglRoleBegin, SglRoleBegin,
SglRoleEnd, SglRoleEnd,
SglSelect, SglSelect,
SglSeparateReasoning,
SglVideo, SglVideo,
) )
...@@ -277,3 +278,9 @@ def assistant_begin(): ...@@ -277,3 +278,9 @@ def assistant_begin():
def assistant_end(): def assistant_end():
return SglRoleEnd("assistant") return SglRoleEnd("assistant")
def separate_reasoning(
expr: Optional[SglExpr] = None, model_type: Optional[str] = None
):
return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
...@@ -26,6 +26,7 @@ from sglang.lang.ir import ( ...@@ -26,6 +26,7 @@ from sglang.lang.ir import (
SglRoleBegin, SglRoleBegin,
SglRoleEnd, SglRoleEnd,
SglSelect, SglSelect,
SglSeparateReasoning,
SglVariable, SglVariable,
SglVarScopeBegin, SglVarScopeBegin,
SglVarScopeEnd, SglVarScopeEnd,
...@@ -472,6 +473,8 @@ class StreamExecutor: ...@@ -472,6 +473,8 @@ class StreamExecutor:
self._execute_concatenate_and_append_kv_cache(other) self._execute_concatenate_and_append_kv_cache(other)
else: else:
self._execute_concatenate_and_append_text(other) self._execute_concatenate_and_append_text(other)
elif isinstance(other, SglSeparateReasoning):
self._execute_separate_reasoning(other)
else: else:
raise ValueError(f"Unknown type: {type(other)}") raise ValueError(f"Unknown type: {type(other)}")
...@@ -724,8 +727,44 @@ class StreamExecutor: ...@@ -724,8 +727,44 @@ class StreamExecutor:
src_rids = [state.stream_executor.sid for state in expr.states] src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid) self.backend.concatenate_and_append(src_rids, self.sid)
def _execute_separate_reasoning(self, expr: SglSeparateReasoning):
if self.stream:
# separate reasoning for stream is not supported
return
if (
self.cur_role == "assistant"
and self.num_api_spec_tokens is not None
and self.backend.is_chat_model
):
# Execute the stored lazy generation calls
self.backend.role_end_generate(self)
from sglang.srt.reasoning_parser import ReasoningParser
reasoning_parser = ReasoningParser(expr.model_type)
other = expr.expr
if not other:
return
elif isinstance(other, SglGen) or isinstance(other, SglSelect):
cur_text = self.get_var(other.name)
reasoning, normal_text = reasoning_parser.parse_non_stream(cur_text)
reasoning_name = expr.process_name_for_reasoning(other.name)
self.set_var(other.name, normal_text)
self.set_var(reasoning_name, reasoning)
# the variable is ready to be used
self.variable_event[reasoning_name].set()
self.text_ = self.text_[: self.cur_role_begin_pos] + normal_text
elif isinstance(other, SglExprList):
for x in other.expr_list:
self._execute_separate_reasoning(
SglSeparateReasoning(expr.model_type, x)
)
def _init_var_event(self, expr): def _init_var_event(self, expr):
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)): if isinstance(
expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning)
):
self.variable_event[expr.name] = threading.Event() self.variable_event[expr.name] = threading.Event()
if self.stream: if self.stream:
self.stream_var_event[expr.name] = threading.Event() self.stream_var_event[expr.name] = threading.Event()
......
...@@ -606,3 +606,30 @@ class SglCommitLazy(SglExpr): ...@@ -606,3 +606,30 @@ class SglCommitLazy(SglExpr):
def __repr__(self): def __repr__(self):
return "CommitLazy()" return "CommitLazy()"
class SglSeparateReasoning(SglExpr):
def __init__(self, model_type: str, expr: SglExpr):
super().__init__()
self.model_type = model_type
self.expr = expr
self.name = None
self._process_expr(expr)
def process_name_for_reasoning(self, name):
if not name:
raise ValueError("name must be provided")
return f"{name}_reasoning_content"
def _process_expr(self, expr):
if isinstance(expr, SglGen):
self.name = self.process_name_for_reasoning(expr.name)
elif isinstance(expr, SglSelect):
self.name = self.process_name_for_reasoning(expr.name)
elif isinstance(expr, SglExprList):
for x in expr.expr_list:
self._process_expr(x)
def __repr__(self):
return f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
...@@ -8,6 +8,8 @@ suites = { ...@@ -8,6 +8,8 @@ suites = {
TestFile("test_srt_backend.py"), TestFile("test_srt_backend.py"),
# Skip this due to some OPENAI_API_KEY issues # Skip this due to some OPENAI_API_KEY issues
# "test_openai_backend.py", # "test_openai_backend.py",
TestFile("test_separate_reasoning.py"),
TestFile("test_separate_reasoning_execution.py"),
], ],
} }
......
"""
Tests for the separate_reasoning functionality in sglang.
Usage:
python3 -m unittest test/lang/test_separate_reasoning.py
"""
import unittest
from sglang import assistant, gen, separate_reasoning, user
from sglang.lang.ir import SglExprList, SglSeparateReasoning
from sglang.test.test_utils import CustomTestCase
class TestSeparateReasoning(CustomTestCase):
def test_separate_reasoning_creation(self):
"""Test that SglSeparateReasoning objects are created correctly."""
# Test with valid model type and gen expression
test_gen = gen("test")
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
self.assertIsInstance(expr, SglExprList)
self.assertEqual(len(expr.expr_list), 2)
self.assertEqual(expr.expr_list[0], test_gen)
reasoning_expr = expr.expr_list[1]
self.assertIsInstance(reasoning_expr, SglSeparateReasoning)
self.assertEqual(reasoning_expr.model_type, "deepseek-r1")
self.assertEqual(reasoning_expr.name, "test_reasoning_content")
# Test with another valid model type
expr = separate_reasoning(test_gen, model_type="qwen3")
self.assertIsInstance(expr, SglExprList)
self.assertEqual(expr.expr_list[1].model_type, "qwen3")
def test_separate_reasoning_name_processing(self):
"""Test that separate_reasoning correctly processes names."""
test_gen = gen("test_var")
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
reasoning_expr = expr.expr_list[1]
self.assertEqual(reasoning_expr.name, "test_var_reasoning_content")
# Test the process_name_for_reasoning method
self.assertEqual(
reasoning_expr.process_name_for_reasoning("another_var"),
"another_var_reasoning_content",
)
def test_separate_reasoning_repr(self):
"""Test the string representation of SglSeparateReasoning."""
test_gen = gen("test_var")
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
reasoning_expr = expr.expr_list[1]
self.assertEqual(
repr(reasoning_expr),
"SeparateReasoning(model_type=deepseek-r1, name=test_var_reasoning_content)",
)
def test_separate_reasoning_with_invalid_model_type(self):
"""Test that separate_reasoning accepts any model type during creation."""
# Create with invalid model type
test_gen = gen("test")
expr = separate_reasoning(test_gen, model_type="invalid-model")
self.assertIsInstance(expr, SglExprList)
self.assertEqual(expr.expr_list[1].model_type, "invalid-model")
# The actual validation happens in the ReasoningParser constructor
if __name__ == "__main__":
unittest.main()
"""
Tests for the execution of separate_reasoning functionality in sglang.
Usage:
python3 -m unittest test/lang/test_separate_reasoning_execution.py
"""
import threading
import time
import unittest
from unittest.mock import MagicMock, patch
from sglang import assistant, gen, separate_reasoning, user
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglGen, SglSeparateReasoning
from sglang.test.test_utils import CustomTestCase
# Helper function to create events that won't block program exit
def create_daemon_event():
event = threading.Event()
return event
class MockReasoningParser:
def __init__(self, model_type):
self.model_type = model_type
self.parse_non_stream_called = False
self.parse_stream_chunk_called = False
def parse_non_stream(self, full_text):
self.parse_non_stream_called = True
# Simulate parsing by adding a prefix to indicate reasoning
reasoning = f"[REASONING from {self.model_type}]: {full_text}"
normal_text = f"[NORMAL from {self.model_type}]: {full_text}"
return reasoning, normal_text
def parse_stream_chunk(self, chunk_text):
self.parse_stream_chunk_called = True
# Simulate parsing by adding a prefix to indicate reasoning
reasoning = f"[REASONING from {self.model_type}]: {chunk_text}"
normal_text = f"[NORMAL from {self.model_type}]: {chunk_text}"
return reasoning, normal_text
class TestSeparateReasoningExecution(CustomTestCase):
def setUp(self):
"""Set up for the test."""
super().setUp()
# Store any events created during the test
self.events = []
def tearDown(self):
"""Clean up any threads that might have been created during the test."""
super().tearDown()
# Set all events to ensure any waiting threads are released
for event in self.events:
event.set()
def tearDown(self):
super().tearDown()
# wake up all threads
for ev in self.events:
ev.set()
@patch("sglang.srt.reasoning_parser.ReasoningParser")
def test_execute_separate_reasoning(self, mock_parser_class):
"""Test that _execute_separate_reasoning correctly calls the ReasoningParser."""
# Setup mock parser
mock_parser = MockReasoningParser("deepseek-r1")
mock_parser_class.return_value = mock_parser
# Create a mock backend to avoid AttributeError in __del__
mock_backend = MagicMock()
# Create a StreamExecutor with necessary setup
executor = StreamExecutor(
backend=mock_backend,
arguments={},
default_sampling_para={},
chat_template={
"role_map": {"user": "user", "assistant": "assistant"}
}, # Simple chat template
stream=False,
use_thread=False,
)
# Set up the executor with a variable and its value
var_name = "test_var"
reasoning_name = f"{var_name}_reasoning_content"
var_value = "Test content"
executor.variables = {var_name: var_value}
# Create events and track them for cleanup
var_event = create_daemon_event()
reasoning_event = create_daemon_event()
self.events.extend([var_event, reasoning_event])
executor.variable_event = {var_name: var_event, reasoning_name: reasoning_event}
executor.variable_event[var_name].set() # Mark as ready
# Set up the current role
executor.cur_role = "assistant"
executor.cur_role_begin_pos = 0
executor.text_ = var_value
# Create a gen expression and a separate_reasoning expression
gen_expr = SglGen(var_name)
expr = SglSeparateReasoning("deepseek-r1", expr=gen_expr)
# Execute separate_reasoning
executor._execute_separate_reasoning(expr)
# Verify that the parser was created with the correct model type
mock_parser_class.assert_called_once_with("deepseek-r1")
# Verify that parse_non_stream was called
self.assertTrue(mock_parser.parse_non_stream_called)
# Verify that the variables were updated correctly
reasoning_name = f"{var_name}_reasoning_content"
self.assertIn(reasoning_name, executor.variables)
self.assertEqual(
executor.variables[reasoning_name],
f"[REASONING from deepseek-r1]: {var_value}",
)
self.assertEqual(
executor.variables[var_name], f"[NORMAL from deepseek-r1]: {var_value}"
)
# Verify that the variable event was set
self.assertIn(reasoning_name, executor.variable_event)
self.assertTrue(executor.variable_event[reasoning_name].is_set())
# Verify that the text was updated
self.assertEqual(executor.text_, f"[NORMAL from deepseek-r1]: {var_value}")
@patch("sglang.srt.reasoning_parser.ReasoningParser")
def test_reasoning_parser_integration(self, mock_parser_class):
"""Test the integration between separate_reasoning and ReasoningParser."""
# Setup mock parsers for different model types
deepseek_parser = MockReasoningParser("deepseek-r1")
qwen_parser = MockReasoningParser("qwen3")
# Configure the mock to return different parsers based on model type
def get_parser(model_type):
if model_type == "deepseek-r1":
return deepseek_parser
elif model_type == "qwen3":
return qwen_parser
else:
raise ValueError(f"Unsupported model type: {model_type}")
mock_parser_class.side_effect = get_parser
# Test with DeepSeek-R1 model
test_text = "This is a test"
reasoning, normal_text = deepseek_parser.parse_non_stream(test_text)
self.assertEqual(reasoning, f"[REASONING from deepseek-r1]: {test_text}")
self.assertEqual(normal_text, f"[NORMAL from deepseek-r1]: {test_text}")
# Test with Qwen3 model
reasoning, normal_text = qwen_parser.parse_non_stream(test_text)
self.assertEqual(reasoning, f"[REASONING from qwen3]: {test_text}")
self.assertEqual(normal_text, f"[NORMAL from qwen3]: {test_text}")
@patch("sglang.srt.reasoning_parser.ReasoningParser")
def test_reasoning_parser_invalid_model(self, mock_parser_class):
"""Test that ReasoningParser raises an error for invalid model types."""
# Configure the mock to raise an error for invalid model types
def get_parser(model_type):
if model_type in ["deepseek-r1", "qwen3"]:
return MockReasoningParser(model_type)
elif model_type is None:
raise ValueError("Model type must be specified")
else:
raise ValueError(f"Unsupported model type: {model_type}")
mock_parser_class.side_effect = get_parser
with self.assertRaises(ValueError) as context:
mock_parser_class("invalid-model")
self.assertIn("Unsupported model type", str(context.exception))
with self.assertRaises(ValueError) as context:
mock_parser_class(None)
self.assertIn("Model type must be specified", str(context.exception))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment