Unverified commit a578d300, authored by Chang Su and committed by GitHub
Browse files

[router][grpc] Fix proto3 default value mismatches and cleanup unused fields (#11283)

parent 8c967037
...@@ -14,6 +14,7 @@ from concurrent import futures ...@@ -14,6 +14,7 @@ from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple from typing import AsyncIterator, Dict, Optional, Tuple
import grpc import grpc
from google.protobuf.json_format import MessageToDict
from grpc_reflection.v1alpha import reflection from grpc_reflection.v1alpha import reflection
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
...@@ -483,28 +484,52 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -483,28 +484,52 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
elif grpc_params.HasField("structural_tag"): elif grpc_params.HasField("structural_tag"):
structural_tag = grpc_params.structural_tag structural_tag = grpc_params.structural_tag
# Handle optional parameters conversion
custom_params = (
MessageToDict(grpc_params.custom_params)
if grpc_params.HasField("custom_params")
else None
)
max_new_tokens = (
grpc_params.max_new_tokens
if grpc_params.HasField("max_new_tokens")
else None
)
stream_interval = (
grpc_params.stream_interval
if grpc_params.HasField("stream_interval")
else None
)
logit_bias = dict(grpc_params.logit_bias) if grpc_params.logit_bias else None
stop = list(grpc_params.stop) if grpc_params.stop else None
stop_token_ids = (
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
)
return SGLSamplingParams( return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0, temperature=grpc_params.temperature,
top_p=grpc_params.top_p or 1.0, top_p=grpc_params.top_p,
top_k=grpc_params.top_k or -1, top_k=grpc_params.top_k,
min_p=grpc_params.min_p or 0.0, min_p=grpc_params.min_p,
frequency_penalty=grpc_params.frequency_penalty or 0.0, frequency_penalty=grpc_params.frequency_penalty,
presence_penalty=grpc_params.presence_penalty or 0.0, presence_penalty=grpc_params.presence_penalty,
repetition_penalty=grpc_params.repetition_penalty or 1.0, repetition_penalty=grpc_params.repetition_penalty,
max_new_tokens=grpc_params.max_new_tokens or 128, max_new_tokens=max_new_tokens,
min_new_tokens=grpc_params.min_new_tokens or 0, min_new_tokens=grpc_params.min_new_tokens,
stop=list(grpc_params.stop) if grpc_params.stop else [], stop=stop,
stop_token_ids=( stop_token_ids=stop_token_ids,
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
),
skip_special_tokens=grpc_params.skip_special_tokens, skip_special_tokens=grpc_params.skip_special_tokens,
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens, spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
no_stop_trim=grpc_params.no_stop_trim,
regex=regex, regex=regex,
json_schema=json_schema, json_schema=json_schema,
ebnf=ebnf_grammar, ebnf=ebnf_grammar,
structural_tag=structural_tag, structural_tag=structural_tag,
n=grpc_params.n or 1, n=grpc_params.n,
ignore_eos=grpc_params.ignore_eos, ignore_eos=grpc_params.ignore_eos,
stream_interval=stream_interval,
logit_bias=logit_bias,
custom_params=custom_params,
) )
def _convert_output_logprobs_to_proto( def _convert_output_logprobs_to_proto(
......
...@@ -27,6 +27,11 @@ service SglangScheduler { ...@@ -27,6 +27,11 @@ service SglangScheduler {
// ===================== // =====================
// Sampling parameters matching SGLang's SamplingParams // Sampling parameters matching SGLang's SamplingParams
//
// IMPORTANT: Do not use SamplingParams::default() directly!
// The proto3 defaults (0 for numeric fields) do NOT match the semantic defaults
// (temperature=1.0, top_p=1.0, top_k=-1, etc.). Always construct with explicit values
// or use the conversion functions in sglang_scheduler.rs / grpc_server.py.
message SamplingParams { message SamplingParams {
float temperature = 1; float temperature = 1;
float top_p = 2; float top_p = 2;
...@@ -50,24 +55,18 @@ message SamplingParams { ...@@ -50,24 +55,18 @@ message SamplingParams {
string structural_tag = 16; string structural_tag = 16;
} }
// LoRA adapter
string lora_path = 17;
// Speculative decoding // Speculative decoding
int32 n = 18; // Number of samples int32 n = 17; // Number of samples
// Token healing
bool token_healing = 19;
// Additional parameters // Additional parameters
int32 min_new_tokens = 20; int32 min_new_tokens = 18;
bool ignore_eos = 21; bool ignore_eos = 19;
bool no_stop_trim = 22; bool no_stop_trim = 20;
int32 stream_interval = 23; optional int32 stream_interval = 21;
map<string, float> logit_bias = 24; map<string, float> logit_bias = 22;
// Custom parameters for extensibility // Custom parameters for extensibility
google.protobuf.Struct custom_params = 25; google.protobuf.Struct custom_params = 23;
} }
......
...@@ -11,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union ...@@ -11,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message): class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params") __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "n", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
class LogitBiasEntry(_message.Message): class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value") __slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int] KEY_FIELD_NUMBER: _ClassVar[int]
...@@ -35,9 +35,7 @@ class SamplingParams(_message.Message): ...@@ -35,9 +35,7 @@ class SamplingParams(_message.Message):
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int] STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
IGNORE_EOS_FIELD_NUMBER: _ClassVar[int] IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
...@@ -60,16 +58,14 @@ class SamplingParams(_message.Message): ...@@ -60,16 +58,14 @@ class SamplingParams(_message.Message):
json_schema: str json_schema: str
ebnf_grammar: str ebnf_grammar: str
structural_tag: str structural_tag: str
lora_path: str
n: int n: int
token_healing: bool
min_new_tokens: int min_new_tokens: int
ignore_eos: bool ignore_eos: bool
no_stop_trim: bool no_stop_trim: bool
stream_interval: int stream_interval: int
logit_bias: _containers.ScalarMap[str, float] logit_bias: _containers.ScalarMap[str, float]
custom_params: _struct_pb2.Struct custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., n: _Optional[int] = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message): class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
......
...@@ -202,6 +202,7 @@ impl SglangSchedulerClient { ...@@ -202,6 +202,7 @@ impl SglangSchedulerClient {
stop: stop_sequences, stop: stop_sequences,
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
skip_special_tokens, skip_special_tokens,
spaces_between_special_tokens: true, // Default from Python SamplingParams
ignore_eos: request.ignore_eos, ignore_eos: request.ignore_eos,
no_stop_trim: request.no_stop_trim, no_stop_trim: request.no_stop_trim,
n: request.n.unwrap_or(1) as i32, n: request.n.unwrap_or(1) as i32,
...@@ -301,6 +302,8 @@ impl SglangSchedulerClient { ...@@ -301,6 +302,8 @@ impl SglangSchedulerClient {
top_k: -1, top_k: -1,
repetition_penalty: 1.0, repetition_penalty: 1.0,
n: 1, n: 1,
skip_special_tokens: true,
spaces_between_special_tokens: true,
..Default::default() ..Default::default()
}; };
...@@ -444,10 +447,24 @@ mod tests { ...@@ -444,10 +447,24 @@ mod tests {
#[test] #[test]
fn test_sampling_params_defaults() { fn test_sampling_params_defaults() {
let params = proto::SamplingParams::default(); let params = proto::SamplingParams::default();
// Numeric fields have proto defaults (0)
assert_eq!(params.temperature, 0.0); assert_eq!(params.temperature, 0.0);
assert_eq!(params.max_new_tokens, None);
assert_eq!(params.top_p, 0.0); assert_eq!(params.top_p, 0.0);
assert_eq!(params.top_k, 0); assert_eq!(params.top_k, 0);
assert_eq!(params.repetition_penalty, 0.0);
assert_eq!(params.n, 0);
// Bool fields have proto defaults (false)
assert!(!params.skip_special_tokens);
assert!(!params.spaces_between_special_tokens);
assert!(!params.ignore_eos);
assert!(!params.no_stop_trim);
// Optional int fields should be None
assert_eq!(params.max_new_tokens, None);
assert_eq!(params.stream_interval, None);
// Other non-optional fields
assert_eq!(params.min_p, 0.0);
assert_eq!(params.frequency_penalty, 0.0);
assert_eq!(params.presence_penalty, 0.0);
assert!(params.stop.is_empty()); assert!(params.stop.is_empty());
} }
......
...@@ -27,6 +27,11 @@ service SglangScheduler { ...@@ -27,6 +27,11 @@ service SglangScheduler {
// ===================== // =====================
// Sampling parameters matching SGLang's SamplingParams // Sampling parameters matching SGLang's SamplingParams
//
// IMPORTANT: Do not use SamplingParams::default() directly!
// The proto3 defaults (0 for numeric fields) do NOT match the semantic defaults
// (temperature=1.0, top_p=1.0, top_k=-1, etc.). Always construct with explicit values
// or use the conversion functions in sglang_scheduler.rs / grpc_server.py.
message SamplingParams { message SamplingParams {
float temperature = 1; float temperature = 1;
float top_p = 2; float top_p = 2;
...@@ -50,24 +55,18 @@ message SamplingParams { ...@@ -50,24 +55,18 @@ message SamplingParams {
string structural_tag = 16; string structural_tag = 16;
} }
// LoRA adapter
string lora_path = 17;
// Speculative decoding // Speculative decoding
int32 n = 18; // Number of samples int32 n = 17; // Number of samples
// Token healing
bool token_healing = 19;
// Additional parameters // Additional parameters
int32 min_new_tokens = 20; int32 min_new_tokens = 18;
bool ignore_eos = 21; bool ignore_eos = 19;
bool no_stop_trim = 22; bool no_stop_trim = 20;
int32 stream_interval = 23; optional int32 stream_interval = 21;
map<string, float> logit_bias = 24; map<string, float> logit_bias = 22;
// Custom parameters for extensibility // Custom parameters for extensibility
google.protobuf.Struct custom_params = 25; google.protobuf.Struct custom_params = 23;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment