Unverified commit 112b496a, authored by Chang Su and committed by GitHub

misc: Improve serving_chat.py and add more unit tests (#7489)

Replaces the 7-tuple returned by OpenAIServingChat._process_messages with a MessageProcessingResult dataclass and adds a unit test verifying that per-request stop strings do not leak between requests.

parent 3562256b
sglang/srt/entrypoints/openai/protocol.py

@@ -14,7 +14,8 @@
 """Pydantic models for OpenAI API protocol"""

 import time
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union

 from pydantic import (
     BaseModel,
@@ -587,3 +588,30 @@ OpenAIServingRequest = Union[
     ScoringRequest,
     V1RerankReqInput,
 ]
+
+
+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all the outputs from message processing including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Args:
+        prompt: The final text prompt after applying chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from template and request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None
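To make the ergonomics of the change concrete, here is a minimal, self-contained sketch of how a caller reads the named fields instead of unpacking the old 7-tuple. The dataclass is re-declared locally so the snippet runs without sglang installed, and the field values are made-up placeholders:

from dataclasses import dataclass
from typing import Any, List, Optional, Union


@dataclass
class MessageProcessingResult:
    # Local re-declaration of the dataclass above, so this snippet runs without sglang.
    prompt: str
    prompt_ids: Union[str, List[int]]
    image_data: Optional[Any]
    audio_data: Optional[Any]
    modalities: List[str]
    stop: List[str]
    tool_call_constraint: Optional[Any] = None


# Old style: callers unpacked a positional 7-tuple and had to keep the order in sync.
# New style: fields are accessed by name, and tool_call_constraint gets a safe default.
processed = MessageProcessingResult(
    prompt="Hello!",
    prompt_ids=[1, 2, 3],  # placeholder token IDs
    image_data=None,
    audio_data=None,
    modalities=[],
    stop=["</s>"],
)

# Mirrors the text-vs-token-ids branching in serving_chat.py below.
if isinstance(processed.prompt_ids, str):
    prompt_kwargs = {"text": processed.prompt_ids}
else:
    prompt_kwargs = {"input_ids": processed.prompt_ids}

print(prompt_kwargs)                   # {'input_ids': [1, 2, 3]}
print(processed.tool_call_constraint)  # None (default when no tools are requested)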
sglang/srt/entrypoints/openai/serving_chat.py

@@ -22,6 +22,7 @@ from sglang.srt.entrypoints.openai.protocol import (
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,41 +63,33 @@ class OpenAIServingChat(OpenAIServingBase):
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal

         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )

         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ class OpenAIServingChat(OpenAIServingBase):
     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []

-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]

-                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-                parser = FunctionCallParser(request.tools, tool_call_parser)
-                tool_call_constraint = parser.get_structure_constraint(
-                    request.tool_choice
-                )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

-            # Use chat template
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
-        else:
-            # Use raw prompt
-            prompt_ids = request.messages
-            stop = request.stop or []
-            image_data = None
-            audio_data = None
-            modalities = []
-            prompt = request.messages
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
+        else:
+            result = self._apply_conversation_template(request, is_multimodal)

-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+        result.tool_call_constraint = tool_call_constraint
+        return result

     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +214,20 @@ class OpenAIServingChat(OpenAIServingBase):
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +272,14 @@ class OpenAIServingChat(OpenAIServingBase):
         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _build_sampling_params(
         self,
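One detail in the _process_messages hunk above is worth calling out: when request.tool_choice names a specific function, only that function's schema is handed to the chat template, otherwise every tool is dumped. The sketch below mirrors that selection logic with plain dicts standing in for the pydantic Tool and ToolChoice models; select_tools_for_template and its argument shapes are illustrative, not part of sglang:

from typing import Any, Dict, List, Optional, Union


def select_tools_for_template(
    tools: Optional[List[Dict[str, Any]]],
    tool_choice: Union[str, Dict[str, Any], None],
) -> Optional[List[Dict[str, Any]]]:
    # Plain-dict mirror of the tools filtering in _process_messages (illustrative only).
    if not tools or tool_choice == "none":
        return None
    if isinstance(tool_choice, dict):
        # A specific function was requested: keep only the matching schema.
        wanted = tool_choice["function"]["name"]
        return [t["function"] for t in tools if t["function"]["name"] == wanted]
    # tool_choice is a plain string such as "auto": expose every tool to the template.
    return [t["function"] for t in tools]


tools = [
    {"function": {"name": "get_weather", "parameters": {"city": "string"}}},
    {"function": {"name": "get_time", "parameters": {}}},
]
print(select_tools_for_template(tools, "auto"))                              # both schemas
print(select_tools_for_template(tools, {"function": {"name": "get_time"}}))  # only get_time
print(select_tools_for_template(tools, "none"))                              # None

The hunks that follow extend the unit tests covering this serving path.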
test_serving_chat.py

@@ -13,7 +13,10 @@ from unittest.mock import Mock, patch

 from fastapi import Request

-from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    MessageProcessingResult,
+)
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
 from sglang.srt.managers.io_struct import GenerateReqInput
@@ -104,7 +107,7 @@ class ServingChatTestCase(unittest.TestCase):
             conv_ins.stop_str = ["</s>"]
             conv_mock.return_value = conv_ins

-            proc_mock.return_value = (
+            proc_mock.return_value = MessageProcessingResult(
                 "Test prompt",
                 [1, 2, 3],
                 None,
@@ -119,6 +122,59 @@ class ServingChatTestCase(unittest.TestCase):
             self.assertFalse(adapted.stream)
             self.assertEqual(processed, self.basic_req)

+    def test_stop_str_isolation_between_requests(self):
+        """Test that stop strings from one request don't affect subsequent requests.
+
+        This tests the fix for the bug where conv.stop_str was being mutated globally,
+        causing stop strings from one request to persist in subsequent requests.
+        """
+        # Mock conversation template with initial stop_str
+        initial_stop_str = ["\n"]
+
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
+        ) as conv_mock:
+            # Create a mock conversation object that will be returned by generate_chat_conv
+            conv_ins = Mock()
+            conv_ins.get_prompt.return_value = "Test prompt"
+            conv_ins.image_data = None
+            conv_ins.audio_data = None
+            conv_ins.modalities = []
+            conv_ins.stop_str = (
+                initial_stop_str.copy()
+            )  # Template's default stop strings
+            conv_mock.return_value = conv_ins
+
+            # First request with additional stop string
+            req1 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "First request"}],
+                stop=["CUSTOM_STOP"],
+            )
+
+            # Call the actual _apply_conversation_template method (not mocked)
+            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
+
+            # Verify first request has both stop strings
+            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
+            self.assertEqual(result1.stop, expected_stop1)
+
+            # Verify the original template's stop_str wasn't mutated after first request
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
+            # Second request without additional stop string
+            req2 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "Second request"}],
+                # No custom stop strings
+            )
+
+            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
+
+            # Verify second request only has original stop strings (no CUSTOM_STOP from req1)
+            self.assertEqual(result2.stop, initial_stop_str)
+            self.assertNotIn("CUSTOM_STOP", result2.stop)
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
     # ------------- sampling-params -------------
     def test_sampling_param_build(self):
         req = ChatCompletionRequest(
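The new test_stop_str_isolation_between_requests guards against a shared-state pitfall: if the handler extends the conversation template's stop_str list in place, a custom stop string from one request leaks into every later request that reuses the same template. Below is a minimal sketch of the hazard and of the copy-before-extend pattern the test expects; ConvTemplate and the combine_stop_* helpers are hypothetical stand-ins, not sglang code:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ConvTemplate:
    # Hypothetical stand-in for the conversation object returned by generate_chat_conv.
    stop_str: List[str] = field(default_factory=lambda: ["\n"])


def combine_stop_buggy(conv: ConvTemplate, request_stop: Optional[List[str]]) -> List[str]:
    # BUG: extends the shared template list in place, so the custom stop string
    # persists into every later request that reuses the same template.
    stop = conv.stop_str
    if request_stop:
        stop.extend(request_stop)
    return stop


def combine_stop_fixed(conv: ConvTemplate, request_stop: Optional[List[str]]) -> List[str]:
    # Fix: copy the template's stop strings before adding per-request ones.
    stop = list(conv.stop_str)
    if request_stop:
        stop.extend(request_stop)
    return stop


conv = ConvTemplate()
print(combine_stop_buggy(conv, ["CUSTOM_STOP"]))  # ['\n', 'CUSTOM_STOP']
print(conv.stop_str)                              # ['\n', 'CUSTOM_STOP']  <- leaked

conv = ConvTemplate()
print(combine_stop_fixed(conv, ["CUSTOM_STOP"]))  # ['\n', 'CUSTOM_STOP']
print(conv.stop_str)                              # ['\n']  <- template unchanged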