Unverified Commit 7ba3de0e authored by Chang Su, committed by GitHub

[oai serving chat] Add argument `--sampling-defaults` and fix `ChatCompletionRequest` defaults (#11304)
parent fde9b963
@@ -17,7 +17,7 @@ import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
@@ -90,6 +90,7 @@ class ModelConfig:
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
sampling_defaults: str = "openai",
) -> None:
# Parse args
self.model_path = model_path
@@ -98,6 +99,7 @@ class ModelConfig:
self.modelopt_quant = modelopt_quant
self.is_draft_model = is_draft_model
self.model_impl = model_impl
self.sampling_defaults = sampling_defaults
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
@@ -214,6 +216,7 @@ class ModelConfig:
modelopt_quant=server_args.modelopt_quant,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
sampling_defaults=server_args.sampling_defaults,
**kwargs,
)
@@ -659,6 +662,38 @@ class ModelConfig:
eos_ids = eos_ids | generation_eos_ids
return eos_ids
def get_default_sampling_params(self) -> dict[str, Any]:
"""
Get default sampling parameters from the model's generation config.
This method returns non-default sampling parameters from the model's
generation_config.json when sampling_defaults is set to "model".
Returns:
A dictionary containing the non-default sampling parameters.
"""
if self.sampling_defaults != "model":
return {}
if self.hf_generation_config is None:
return {}
config = self.hf_generation_config.to_dict()
available_params = [
"repetition_penalty",
"temperature",
"top_k",
"top_p",
"min_p",
]
default_sampling_params = {
p: config.get(p) for p in available_params if config.get(p) is not None
}
return default_sampling_params
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
"""
Pull the model config files to a temporary
......
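The new helper deliberately forwards only the values that the model's generation_config.json actually sets. A minimal standalone sketch of that selection logic (function name and sample config values are illustrative, not from the commit):

from typing import Any, Dict

def pick_model_sampling_defaults(
    generation_config: Dict[str, Any], sampling_defaults: str
) -> Dict[str, Any]:
    # Mirror ModelConfig.get_default_sampling_params: only consult the model's
    # generation config when --sampling-defaults is "model".
    if sampling_defaults != "model":
        return {}
    keys = ["repetition_penalty", "temperature", "top_k", "top_p", "min_p"]
    return {k: generation_config[k] for k in keys if generation_config.get(k) is not None}

# Example: a generation config that recommends temperature/top_p (values illustrative).
print(pick_model_sampling_defaults({"temperature": 0.6, "top_p": 0.95, "do_sample": True}, "model"))
# {'temperature': 0.6, 'top_p': 0.95}
print(pick_model_sampling_defaults({"temperature": 0.6}, "openai"))
# {}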
@@ -13,6 +13,7 @@
# ==============================================================================
"""Pydantic models for OpenAI API protocol"""
import logging
import time
import uuid
from dataclasses import dataclass
@@ -37,6 +38,10 @@ from pydantic import (
)
from typing_extensions import Literal
from sglang.utils import convert_json_schema_to_str
logger = logging.getLogger(__name__)
DEFAULT_MODEL_NAME = "default"
@@ -445,8 +450,8 @@ class ChatCompletionRequest(BaseModel):
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
stream_options: Optional[StreamOptions] = None
temperature: float = 0.7
top_p: float = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
@@ -461,6 +466,47 @@ class ChatCompletionRequest(BaseModel):
"Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.",
)
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: Optional[int] = None
min_p: Optional[float] = None
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
# OpenAI/SGLang default sampling parameters
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
"repetition_penalty": 1.0,
}
@model_validator(mode="before")
@classmethod
def set_tool_choice_default(cls, values):
@@ -531,37 +577,81 @@ class ChatCompletionRequest(BaseModel):
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
separate_reasoning: bool = True
stream_reasoning: bool = True
chat_template_kwargs: Optional[Dict] = None
def to_sampling_params(
self,
stop: List[str],
model_generation_config: Dict[str, Any],
tool_call_constraint: Optional[Any] = None,
) -> Dict[str, Any]:
"""
Convert request to sampling parameters.
Priority: user value > model generation_config > OpenAI defaults
"""
def get_param(param_name: str):
value = getattr(self, param_name)
if value is None:
return model_generation_config.get(
param_name, self._DEFAULT_SAMPLING_PARAMS[param_name]
)
return value
sampling_params = {
"temperature": get_param("temperature"),
"max_new_tokens": self.max_tokens or self.max_completion_tokens,
"min_new_tokens": self.min_tokens,
"stop": stop,
"stop_token_ids": self.stop_token_ids,
"top_p": get_param("top_p"),
"top_k": get_param("top_k"),
"min_p": get_param("min_p"),
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"repetition_penalty": get_param("repetition_penalty"),
"regex": self.regex,
"ebnf": self.ebnf,
"n": self.n,
"no_stop_trim": self.no_stop_trim,
"ignore_eos": self.ignore_eos,
"skip_special_tokens": self.skip_special_tokens,
"logit_bias": self.logit_bias,
}
# For request id
rid: Optional[Union[List[str], str]] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Cache salt for request caching
cache_salt: Optional[Union[List[str], str]] = None
# Priority for the request
priority: Optional[int] = None
if self.response_format and self.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str(
self.response_format.json_schema.schema_
)
elif self.response_format and self.response_format.type == "json_object":
sampling_params["json_schema"] = '{"type": "object"}'
elif self.response_format and self.response_format.type == "structural_tag":
sampling_params["structural_tag"] = convert_json_schema_to_str(
self.response_format.model_dump(by_alias=True)
)
# For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
# Check if there are already existing output constraints
has_existing_constraints = (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
)
if tool_call_constraint and has_existing_constraints:
logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
)
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
class ChatMessage(BaseModel):
......
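The docstring above states the per-parameter precedence: an explicit request value wins, then the model's generation_config default, then the OpenAI/SGLang fallback from _DEFAULT_SAMPLING_PARAMS. A self-contained sketch of that rule (helper name is illustrative, not from the commit):

from typing import Any, Dict, Optional

OPENAI_FALLBACKS = {
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
    "repetition_penalty": 1.0,
}

def resolve(name: str, user_value: Optional[Any], model_defaults: Dict[str, Any]) -> Any:
    # Explicit request value > model generation_config default > OpenAI/SGLang fallback.
    if user_value is not None:
        return user_value
    return model_defaults.get(name, OPENAI_FALLBACKS[name])

print(resolve("temperature", 0.2, {"temperature": 0.6}))   # 0.2: explicit request value wins
print(resolve("temperature", None, {"temperature": 0.6}))  # 0.6: model default applies
print(resolve("top_p", None, {}))                          # 1.0: OpenAI/SGLang fallback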
@@ -44,7 +44,6 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.parser.conversation import generate_chat_conv
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.utils import convert_json_schema_to_str
if TYPE_CHECKING:
from sglang.srt.managers.template_manager import TemplateManager
@@ -66,6 +65,15 @@ class OpenAIServingChat(OpenAIServingBase):
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
# Get default sampling parameters from model's generation config
self.default_sampling_params = (
self.tokenizer_manager.model_config.get_default_sampling_params()
)
if self.default_sampling_params:
logger.info(
f"Using default chat sampling params from model generation config: {self.default_sampling_params}",
)
def _request_id_prefix(self) -> str:
return "chatcmpl-"
@@ -137,10 +145,10 @@ class OpenAIServingChat(OpenAIServingBase):
processed_messages = self._process_messages(request, is_multimodal)
# Build sampling parameters
sampling_params = self._build_sampling_params(
request,
processed_messages.stop,
processed_messages.tool_call_constraint,
sampling_params = request.to_sampling_params(
stop=processed_messages.stop,
model_generation_config=self.default_sampling_params,
tool_call_constraint=processed_messages.tool_call_constraint,
)
# Handle single vs multiple requests
@@ -410,72 +418,6 @@ class OpenAIServingChat(OpenAIServingBase):
stop=stop,
)
def _build_sampling_params(
self,
request: ChatCompletionRequest,
stop: List[str],
tool_call_constraint: Optional[Any],
) -> Dict[str, Any]:
"""Build sampling parameters for the request"""
sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"stop": stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens,
"logit_bias": request.logit_bias,
}
if request.response_format and request.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_
)
elif request.response_format and request.response_format.type == "json_object":
sampling_params["json_schema"] = '{"type": "object"}'
elif (
request.response_format and request.response_format.type == "structural_tag"
):
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
# Check if there are already existing output constraints
has_existing_constraints = (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
)
if tool_call_constraint and has_existing_constraints:
logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value
)
else:
sampling_params[constraint_type] = constraint_value
return sampling_params
async def _handle_streaming_request(
self,
adapted_request: GenerateReqInput,
......
@@ -252,6 +252,7 @@ class ServerArgs:
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
tool_server: Optional[str] = None
sampling_defaults: str = "model"
# Data parallelism
dp_size: int = 1
@@ -1872,6 +1873,16 @@ class ServerArgs:
default=ServerArgs.tool_call_parser,
help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
)
parser.add_argument(
"--sampling-defaults",
type=str,
choices=["openai", "model"],
default=ServerArgs.sampling_defaults,
help="Where to get default sampling parameters. "
"'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). "
"'model' uses the model's generation_config.json to get the recommended "
"sampling parameters if available. Default is 'model'.",
)
parser.add_argument(
"--tool-server",
type=str,
......
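Usage note: at launch time this maps to, e.g., `python -m sglang.launch_server --model-path <model> --sampling-defaults openai` to restore the previous OpenAI-style defaults, while omitting the flag keeps the new "model" behavior (command shown for reference; the launcher entry point is the existing one, not part of this commit).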
@@ -150,10 +150,26 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertEqual(len(request.messages), 1)
self.assertEqual(request.messages[0].role, "user")
self.assertEqual(request.messages[0].content, "Hello")
self.assertEqual(request.temperature, 0.7) # default
self.assertEqual(request.temperature, None) # default
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
def test_sampling_param_build(self):
req = ChatCompletionRequest(
model="x",
messages=[{"role": "user", "content": "Hi"}],
temperature=0.8,
max_tokens=150,
min_tokens=5,
top_p=0.9,
stop=["</s>"],
)
params = req.to_sampling_params(["</s>"], {}, None)
self.assertEqual(params["temperature"], 0.8)
self.assertEqual(params["max_new_tokens"], 150)
self.assertEqual(params["min_new_tokens"], 5)
self.assertEqual(params["stop"], ["</s>"])
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
messages = [{"role": "user", "content": "Hello"}]
......
@@ -177,28 +177,6 @@ class ServingChatTestCase(unittest.TestCase):
self.assertNotIn("CUSTOM_STOP", result2.stop)
self.assertEqual(conv_ins.stop_str, initial_stop_str)
# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(
model="x",
messages=[{"role": "user", "content": "Hi"}],
temperature=0.8,
max_tokens=150,
min_tokens=5,
top_p=0.9,
stop=["</s>"],
)
with patch.object(
self.chat,
"_process_messages",
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
):
params = self.chat._build_sampling_params(req, ["</s>"], None)
self.assertEqual(params["temperature"], 0.8)
self.assertEqual(params["max_new_tokens"], 150)
self.assertEqual(params["min_new_tokens"], 5)
self.assertEqual(params["stop"], ["</s>"])
async def test_unstreamed_tool_args_completion(self):
"""Test that remaining tool call arguments are sent when generation finishes."""
......