Unverified Commit 7ba3de0e authored by Chang Su, committed by GitHub


[oai serving chat] Add argument `--sampling-defaults` and fix `ChatCompletionRequest` defaults (#11304)
parent fde9b963
@@ -17,7 +17,7 @@ import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union

 import torch
 from transformers import PretrainedConfig
@@ -90,6 +90,7 @@ class ModelConfig:
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+        sampling_defaults: str = "openai",
     ) -> None:
         # Parse args
         self.model_path = model_path
@@ -98,6 +99,7 @@ class ModelConfig:
         self.modelopt_quant = modelopt_quant
         self.is_draft_model = is_draft_model
         self.model_impl = model_impl
+        self.sampling_defaults = sampling_defaults

         # Get hf config
         self._maybe_pull_model_tokenizer_from_remote()
@@ -214,6 +216,7 @@ class ModelConfig:
             modelopt_quant=server_args.modelopt_quant,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             model_impl=server_args.model_impl,
+            sampling_defaults=server_args.sampling_defaults,
             **kwargs,
         )
@@ -659,6 +662,38 @@ class ModelConfig:
             eos_ids = eos_ids | generation_eos_ids
         return eos_ids

+    def get_default_sampling_params(self) -> dict[str, Any]:
+        """
+        Get default sampling parameters from the model's generation config.
+
+        This method returns non-default sampling parameters from the model's
+        generation_config.json when sampling_defaults is set to "model".
+
+        Returns:
+            A dictionary containing the non-default sampling parameters.
+        """
+        if self.sampling_defaults != "model":
+            return {}
+
+        if self.hf_generation_config is None:
+            return {}
+
+        config = self.hf_generation_config.to_dict()
+
+        available_params = [
+            "repetition_penalty",
+            "temperature",
+            "top_k",
+            "top_p",
+            "min_p",
+        ]
+
+        default_sampling_params = {
+            p: config.get(p) for p in available_params if config.get(p) is not None
+        }
+
+        return default_sampling_params
+
     def _maybe_pull_model_tokenizer_from_remote(self) -> None:
         """
         Pull the model config files to a temporary
...
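For illustration, the filtering that `get_default_sampling_params` performs can be sketched against a plain generation-config dictionary. The snippet below is a standalone approximation with made-up `generation_config` values; the real method reads `self.hf_generation_config` and returns `{}` unless the server runs with `--sampling-defaults model`.

```python
from typing import Any, Dict

# Hypothetical generation_config.json contents, for illustration only.
generation_config: Dict[str, Any] = {
    "temperature": 0.6,
    "top_p": 0.95,
    "do_sample": True,    # not one of the forwarded sampling knobs, so it is ignored
    "max_length": 32768,  # likewise ignored
}

AVAILABLE_PARAMS = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]


def extract_model_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
    """Mirror of the dict-comprehension filter in ModelConfig.get_default_sampling_params."""
    return {p: config.get(p) for p in AVAILABLE_PARAMS if config.get(p) is not None}


print(extract_model_defaults(generation_config))
# {'temperature': 0.6, 'top_p': 0.95}
```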
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Pydantic models for OpenAI API protocol"""

+import logging
 import time
 import uuid
 from dataclasses import dataclass
@@ -37,6 +38,10 @@ from pydantic import (
 )
 from typing_extensions import Literal

+from sglang.utils import convert_json_schema_to_str
+
+logger = logging.getLogger(__name__)
+
 DEFAULT_MODEL_NAME = "default"
@@ -445,8 +450,8 @@ class ChatCompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: bool = False
     stream_options: Optional[StreamOptions] = None
-    temperature: float = 0.7
-    top_p: float = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     user: Optional[str] = None
     tools: Optional[List[Tool]] = Field(default=None, examples=[None])
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
@@ -461,6 +466,47 @@ class ChatCompletionRequest(BaseModel):
         "Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.",
     )
+    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    min_tokens: int = 0
+    regex: Optional[str] = None
+    ebnf: Optional[str] = None
+    repetition_penalty: Optional[float] = None
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
+    ignore_eos: bool = False
+    continue_final_message: bool = False
+    skip_special_tokens: bool = True
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None
+    separate_reasoning: bool = True
+    stream_reasoning: bool = True
+    chat_template_kwargs: Optional[Dict] = None
+
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
+    # Cache salt for request caching
+    cache_salt: Optional[Union[List[str], str]] = None
+    # Priority for the request
+    priority: Optional[int] = None
+
+    # For PD disaggregation
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
+
+    # OpenAI/SGLang default sampling parameters
+    _DEFAULT_SAMPLING_PARAMS = {
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+        "repetition_penalty": 1.0,
+    }
     @model_validator(mode="before")
     @classmethod
     def set_tool_choice_default(cls, values):
@@ -531,37 +577,81 @@ class ChatCompletionRequest(BaseModel):
         return values
-    # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    top_k: int = -1
-    min_p: float = 0.0
-    min_tokens: int = 0
-    regex: Optional[str] = None
-    ebnf: Optional[str] = None
-    repetition_penalty: float = 1.0
-    stop_token_ids: Optional[List[int]] = None
-    no_stop_trim: bool = False
-    ignore_eos: bool = False
-    continue_final_message: bool = False
-    skip_special_tokens: bool = True
-    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    session_params: Optional[Dict] = None
-    separate_reasoning: bool = True
-    stream_reasoning: bool = True
-    chat_template_kwargs: Optional[Dict] = None
-
-    # For request id
-    rid: Optional[Union[List[str], str]] = None
-    # Extra key for classifying the request (e.g. cache_salt)
-    extra_key: Optional[Union[List[str], str]] = None
-    # Cache salt for request caching
-    cache_salt: Optional[Union[List[str], str]] = None
-    # Priority for the request
-    priority: Optional[int] = None
-
-    # For PD disaggregation
-    bootstrap_host: Optional[Union[List[str], str]] = None
-    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
-    bootstrap_room: Optional[Union[List[int], int]] = None
+    def to_sampling_params(
+        self,
+        stop: List[str],
+        model_generation_config: Dict[str, Any],
+        tool_call_constraint: Optional[Any] = None,
+    ) -> Dict[str, Any]:
+        """
+        Convert request to sampling parameters.
+        Priority: user value > model generation_config > OpenAI defaults
+        """
+
+        def get_param(param_name: str):
+            value = getattr(self, param_name)
+            if value is None:
+                return model_generation_config.get(
+                    param_name, self._DEFAULT_SAMPLING_PARAMS[param_name]
+                )
+            return value
+
+        sampling_params = {
+            "temperature": get_param("temperature"),
+            "max_new_tokens": self.max_tokens or self.max_completion_tokens,
+            "min_new_tokens": self.min_tokens,
+            "stop": stop,
+            "stop_token_ids": self.stop_token_ids,
+            "top_p": get_param("top_p"),
+            "top_k": get_param("top_k"),
+            "min_p": get_param("min_p"),
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "repetition_penalty": get_param("repetition_penalty"),
+            "regex": self.regex,
+            "ebnf": self.ebnf,
+            "n": self.n,
+            "no_stop_trim": self.no_stop_trim,
+            "ignore_eos": self.ignore_eos,
+            "skip_special_tokens": self.skip_special_tokens,
+            "logit_bias": self.logit_bias,
+        }
+
+        if self.response_format and self.response_format.type == "json_schema":
+            sampling_params["json_schema"] = convert_json_schema_to_str(
+                self.response_format.json_schema.schema_
+            )
+        elif self.response_format and self.response_format.type == "json_object":
+            sampling_params["json_schema"] = '{"type": "object"}'
+        elif self.response_format and self.response_format.type == "structural_tag":
+            sampling_params["structural_tag"] = convert_json_schema_to_str(
+                self.response_format.model_dump(by_alias=True)
+            )
+
+        # Check if there are already existing output constraints
+        has_existing_constraints = (
+            sampling_params.get("regex")
+            or sampling_params.get("ebnf")
+            or sampling_params.get("structural_tag")
+            or sampling_params.get("json_schema")
+        )
+        if tool_call_constraint and has_existing_constraints:
+            logger.warning("Constrained decoding is not compatible with tool calls.")
+        elif tool_call_constraint:
+            constraint_type, constraint_value = tool_call_constraint
+            if constraint_type == "structural_tag":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value.model_dump(by_alias=True)
+                )
+            elif constraint_type == "json_schema":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value
+                )
+            else:
+                sampling_params[constraint_type] = constraint_value
+        return sampling_params
+

 class ChatMessage(BaseModel):
...
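The precedence that `to_sampling_params` implements (explicit request value, then the model's generation config, then the OpenAI-style constants in `_DEFAULT_SAMPLING_PARAMS`) can be illustrated with a minimal standalone sketch. The dictionaries below are invented examples, not values from any real model.

```python
from typing import Any, Dict, Optional

# Same constants as _DEFAULT_SAMPLING_PARAMS in the diff above.
OPENAI_DEFAULTS = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
    "repetition_penalty": 1.0,
}


def resolve(name: str, request_value: Optional[Any], model_generation_config: Dict[str, Any]) -> Any:
    """Same precedence as get_param(): user value > model generation_config > OpenAI default."""
    if request_value is not None:
        return request_value
    return model_generation_config.get(name, OPENAI_DEFAULTS[name])


model_cfg = {"temperature": 0.6}               # hypothetical generation_config.json entry
print(resolve("temperature", 0.2, model_cfg))  # 0.2 -> explicit request value wins
print(resolve("temperature", None, model_cfg)) # 0.6 -> falls back to the model's config
print(resolve("top_p", None, model_cfg))       # 1.0 -> falls back to the OpenAI-style default
```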
@@ -44,7 +44,6 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.parser.conversation import generate_chat_conv
 from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
 from sglang.srt.parser.reasoning_parser import ReasoningParser
-from sglang.utils import convert_json_schema_to_str
 if TYPE_CHECKING:
     from sglang.srt.managers.template_manager import TemplateManager
@@ -66,6 +65,15 @@ class OpenAIServingChat(OpenAIServingBase):
         self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
         self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser

+        # Get default sampling parameters from model's generation config
+        self.default_sampling_params = (
+            self.tokenizer_manager.model_config.get_default_sampling_params()
+        )
+        if self.default_sampling_params:
+            logger.info(
+                f"Using default chat sampling params from model generation config: {self.default_sampling_params}",
+            )

     def _request_id_prefix(self) -> str:
         return "chatcmpl-"
@@ -137,10 +145,10 @@ class OpenAIServingChat(OpenAIServingBase):
         processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
-        sampling_params = self._build_sampling_params(
-            request,
-            processed_messages.stop,
-            processed_messages.tool_call_constraint,
+        sampling_params = request.to_sampling_params(
+            stop=processed_messages.stop,
+            model_generation_config=self.default_sampling_params,
+            tool_call_constraint=processed_messages.tool_call_constraint,
         )

         # Handle single vs multiple requests
@@ -410,72 +418,6 @@ class OpenAIServingChat(OpenAIServingBase):
             stop=stop,
         )
-    def _build_sampling_params(
-        self,
-        request: ChatCompletionRequest,
-        stop: List[str],
-        tool_call_constraint: Optional[Any],
-    ) -> Dict[str, Any]:
-        """Build sampling parameters for the request"""
-        sampling_params = {
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens or request.max_completion_tokens,
-            "min_new_tokens": request.min_tokens,
-            "stop": stop,
-            "stop_token_ids": request.stop_token_ids,
-            "top_p": request.top_p,
-            "top_k": request.top_k,
-            "min_p": request.min_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "repetition_penalty": request.repetition_penalty,
-            "regex": request.regex,
-            "ebnf": request.ebnf,
-            "n": request.n,
-            "no_stop_trim": request.no_stop_trim,
-            "ignore_eos": request.ignore_eos,
-            "skip_special_tokens": request.skip_special_tokens,
-            "logit_bias": request.logit_bias,
-        }
-
-        if request.response_format and request.response_format.type == "json_schema":
-            sampling_params["json_schema"] = convert_json_schema_to_str(
-                request.response_format.json_schema.schema_
-            )
-        elif request.response_format and request.response_format.type == "json_object":
-            sampling_params["json_schema"] = '{"type": "object"}'
-        elif (
-            request.response_format and request.response_format.type == "structural_tag"
-        ):
-            sampling_params["structural_tag"] = convert_json_schema_to_str(
-                request.response_format.model_dump(by_alias=True)
-            )
-
-        # Check if there are already existing output constraints
-        has_existing_constraints = (
-            sampling_params.get("regex")
-            or sampling_params.get("ebnf")
-            or sampling_params.get("structural_tag")
-            or sampling_params.get("json_schema")
-        )
-        if tool_call_constraint and has_existing_constraints:
-            logger.warning("Constrained decoding is not compatible with tool calls.")
-        elif tool_call_constraint:
-            constraint_type, constraint_value = tool_call_constraint
-            if constraint_type == "structural_tag":
-                sampling_params[constraint_type] = convert_json_schema_to_str(
-                    constraint_value.model_dump(by_alias=True)
-                )
-            elif constraint_type == "json_schema":
-                sampling_params[constraint_type] = convert_json_schema_to_str(
-                    constraint_value
-                )
-            else:
-                sampling_params[constraint_type] = constraint_value
-        return sampling_params
-
     async def _handle_streaming_request(
         self,
         adapted_request: GenerateReqInput,
...
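From a client's point of view, the change only affects requests that omit sampling fields. The sketch below uses the official `openai` Python client against a locally running SGLang server; the base URL, port, and API key are placeholders, and the model name `"default"` follows the `DEFAULT_MODEL_NAME` constant in the protocol module.

```python
# Sketch of the client-side effect, assuming an SGLang server is already running locally.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="not-used")

# No temperature/top_p supplied: with --sampling-defaults model the server fills them
# from the model's generation_config.json; with --sampling-defaults openai it falls
# back to temperature=1.0, top_p=1.0, and so on.
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)

# An explicit value always wins over either source of defaults.
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.2,
)
```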
@@ -252,6 +252,7 @@ class ServerArgs:
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None
     tool_server: Optional[str] = None
+    sampling_defaults: str = "model"

     # Data parallelism
     dp_size: int = 1
@@ -1872,6 +1873,16 @@ class ServerArgs:
             default=ServerArgs.tool_call_parser,
             help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
         )
+        parser.add_argument(
+            "--sampling-defaults",
+            type=str,
+            choices=["openai", "model"],
+            default=ServerArgs.sampling_defaults,
+            help="Where to get default sampling parameters. "
+            "'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). "
+            "'model' uses the model's generation_config.json to get the recommended "
+            "sampling parameters if available. Default is 'model'.",
+        )
         parser.add_argument(
             "--tool-server",
             type=str,
...
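To illustrate the flag's semantics in isolation, the standalone argparse sketch below mirrors the definition added above; it is not the actual ServerArgs parser, just a minimal stand-in. On a real deployment the flag is passed on the server launch command line, e.g. `--sampling-defaults openai`.

```python
import argparse

# Minimal stand-in for the real ServerArgs CLI parser, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--sampling-defaults",
    type=str,
    choices=["openai", "model"],
    default="model",
    help="Where to get default sampling parameters.",
)

print(parser.parse_args([]).sampling_defaults)                                  # 'model' (the default)
print(parser.parse_args(["--sampling-defaults", "openai"]).sampling_defaults)   # 'openai'
# parser.parse_args(["--sampling-defaults", "bogus"]) would exit with a choices error.
```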
@@ -150,10 +150,26 @@ class TestChatCompletionRequest(unittest.TestCase):
         self.assertEqual(len(request.messages), 1)
         self.assertEqual(request.messages[0].role, "user")
         self.assertEqual(request.messages[0].content, "Hello")
-        self.assertEqual(request.temperature, 0.7)  # default
+        self.assertEqual(request.temperature, None)  # default
         self.assertFalse(request.stream)  # default
         self.assertEqual(request.tool_choice, "none")  # default when no tools
+    def test_sampling_param_build(self):
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[{"role": "user", "content": "Hi"}],
+            temperature=0.8,
+            max_tokens=150,
+            min_tokens=5,
+            top_p=0.9,
+            stop=["</s>"],
+        )
+        params = req.to_sampling_params(["</s>"], {}, None)
+        self.assertEqual(params["temperature"], 0.8)
+        self.assertEqual(params["max_new_tokens"], 150)
+        self.assertEqual(params["min_new_tokens"], 5)
+        self.assertEqual(params["stop"], ["</s>"])
+
     def test_chat_completion_tool_choice_validation(self):
         """Test tool choice validation logic"""
         messages = [{"role": "user", "content": "Hello"}]
...
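The new test checks explicit values; a companion check of the fallback chain could look like the sketch below. It is written in the style of the surrounding test class but is not part of this commit, and the import path is an assumption that should match whatever the existing test file already uses.

```python
# Hedged sketch of an additional fallback-chain test (not part of this commit).
import unittest

# Assumed import path; adjust to match the existing test file.
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest


class TestSamplingDefaultsFallback(unittest.TestCase):
    def test_defaults_fallback_chain(self):
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
        )
        # No model config supplied: the OpenAI-style constants apply.
        params = req.to_sampling_params([], {}, None)
        self.assertEqual(params["temperature"], 1.0)
        self.assertEqual(params["top_p"], 1.0)

        # A model generation config fills in anything the user left unset.
        params = req.to_sampling_params([], {"temperature": 0.6}, None)
        self.assertEqual(params["temperature"], 0.6)


if __name__ == "__main__":
    unittest.main()
```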
@@ -177,28 +177,6 @@ class ServingChatTestCase(unittest.TestCase):
         self.assertNotIn("CUSTOM_STOP", result2.stop)
         self.assertEqual(conv_ins.stop_str, initial_stop_str)

-    # ------------- sampling-params -------------
-    def test_sampling_param_build(self):
-        req = ChatCompletionRequest(
-            model="x",
-            messages=[{"role": "user", "content": "Hi"}],
-            temperature=0.8,
-            max_tokens=150,
-            min_tokens=5,
-            top_p=0.9,
-            stop=["</s>"],
-        )
-        with patch.object(
-            self.chat,
-            "_process_messages",
-            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
-        ):
-            params = self.chat._build_sampling_params(req, ["</s>"], None)
-        self.assertEqual(params["temperature"], 0.8)
-        self.assertEqual(params["max_new_tokens"], 150)
-        self.assertEqual(params["min_new_tokens"], 5)
-        self.assertEqual(params["stop"], ["</s>"])
-
     async def test_unstreamed_tool_args_completion(self):
         """Test that remaining tool call arguments are sent when generation finishes."""
...