"tests/vscode:/vscode.git/clone" did not exist on "8ecbfa57ae5b17ee551d99a98c9121302123d20b"
Unverified Commit 91678474 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

router: Fix constraint proto and `build_constraint` in grpc router (#10881)

parent d511b2d9
...@@ -438,6 +438,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -438,6 +438,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex = None regex = None
json_schema = None json_schema = None
ebnf_grammar = None ebnf_grammar = None
structural_tag = None
if grpc_params.HasField("regex"): if grpc_params.HasField("regex"):
regex = grpc_params.regex regex = grpc_params.regex
...@@ -445,6 +446,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -445,6 +446,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
json_schema = grpc_params.json_schema json_schema = grpc_params.json_schema
elif grpc_params.HasField("ebnf_grammar"): elif grpc_params.HasField("ebnf_grammar"):
ebnf_grammar = grpc_params.ebnf_grammar ebnf_grammar = grpc_params.ebnf_grammar
elif grpc_params.HasField("structural_tag"):
structural_tag = grpc_params.structural_tag
return SGLSamplingParams( return SGLSamplingParams(
temperature=grpc_params.temperature or 1.0, temperature=grpc_params.temperature or 1.0,
...@@ -465,6 +468,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -465,6 +468,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
regex=regex, regex=regex,
json_schema=json_schema, json_schema=json_schema,
ebnf=ebnf_grammar, ebnf=ebnf_grammar,
structural_tag=structural_tag,
n=grpc_params.n or 1, n=grpc_params.n or 1,
ignore_eos=grpc_params.ignore_eos, ignore_eos=grpc_params.ignore_eos,
) )
......
...@@ -47,24 +47,24 @@ message SamplingParams { ...@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13; string regex = 13;
string json_schema = 14; string json_schema = 14;
string ebnf_grammar = 15; string ebnf_grammar = 15;
string structural_tag = 16;
} }
// LoRA adapter // LoRA adapter
string lora_path = 16; string lora_path = 17;
// Speculative decoding // Speculative decoding
int32 n = 17; // Number of samples int32 n = 18; // Number of samples
// Token healing // Token healing
bool token_healing = 18; bool token_healing = 19;
// Additional parameters // Additional parameters
int32 min_new_tokens = 19; int32 min_new_tokens = 20;
bool ignore_eos = 20; bool ignore_eos = 21;
bool no_stop_trim = 21; bool no_stop_trim = 22;
int32 stream_interval = 22; int32 stream_interval = 23;
map<string, float> logit_bias = 23; map<string, float> logit_bias = 24;
string structural_tag = 24;
// Custom parameters for extensibility // Custom parameters for extensibility
google.protobuf.Struct custom_params = 25; google.protobuf.Struct custom_params = 25;
......
...@@ -12,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union ...@@ -12,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class SamplingParams(_message.Message): class SamplingParams(_message.Message):
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params") __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
class LogitBiasEntry(_message.Message): class LogitBiasEntry(_message.Message):
__slots__ = ("key", "value") __slots__ = ("key", "value")
KEY_FIELD_NUMBER: _ClassVar[int] KEY_FIELD_NUMBER: _ClassVar[int]
...@@ -35,6 +35,7 @@ class SamplingParams(_message.Message): ...@@ -35,6 +35,7 @@ class SamplingParams(_message.Message):
REGEX_FIELD_NUMBER: _ClassVar[int] REGEX_FIELD_NUMBER: _ClassVar[int]
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
LORA_PATH_FIELD_NUMBER: _ClassVar[int] LORA_PATH_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int]
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int] TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
...@@ -43,7 +44,6 @@ class SamplingParams(_message.Message): ...@@ -43,7 +44,6 @@ class SamplingParams(_message.Message):
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int] STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int] CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
temperature: float temperature: float
top_p: float top_p: float
...@@ -60,6 +60,7 @@ class SamplingParams(_message.Message): ...@@ -60,6 +60,7 @@ class SamplingParams(_message.Message):
regex: str regex: str
json_schema: str json_schema: str
ebnf_grammar: str ebnf_grammar: str
structural_tag: str
lora_path: str lora_path: str
n: int n: int
token_healing: bool token_healing: bool
...@@ -68,9 +69,8 @@ class SamplingParams(_message.Message): ...@@ -68,9 +69,8 @@ class SamplingParams(_message.Message):
no_stop_trim: bool no_stop_trim: bool
stream_interval: int stream_interval: int
logit_bias: _containers.ScalarMap[str, float] logit_bias: _containers.ScalarMap[str, float]
structural_tag: str
custom_params: _struct_pb2.Struct custom_params: _struct_pb2.Struct
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
class DisaggregatedParams(_message.Message): class DisaggregatedParams(_message.Message):
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
......
# This file is auto-generated. Do not edit manually.
# Regenerate with: python compile_proto.py
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services.""" """Client and server classes corresponding to protobuf-defined services."""
import grpc import grpc
......
...@@ -47,24 +47,24 @@ message SamplingParams { ...@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13; string regex = 13;
string json_schema = 14; string json_schema = 14;
string ebnf_grammar = 15; string ebnf_grammar = 15;
string structural_tag = 16;
} }
// LoRA adapter // LoRA adapter
string lora_path = 16; string lora_path = 17;
// Speculative decoding // Speculative decoding
int32 n = 17; // Number of samples int32 n = 18; // Number of samples
// Token healing // Token healing
bool token_healing = 18; bool token_healing = 19;
// Additional parameters // Additional parameters
int32 min_new_tokens = 19; int32 min_new_tokens = 20;
bool ignore_eos = 20; bool ignore_eos = 21;
bool no_stop_trim = 21; bool no_stop_trim = 22;
int32 stream_interval = 22; int32 stream_interval = 23;
map<string, float> logit_bias = 23; map<string, float> logit_bias = 24;
string structural_tag = 24;
// Custom parameters for extensibility // Custom parameters for extensibility
google.protobuf.Struct custom_params = 25; google.protobuf.Struct custom_params = 25;
......
...@@ -241,14 +241,14 @@ impl GrpcRouter { ...@@ -241,14 +241,14 @@ impl GrpcRouter {
debug!("Tokenized {} tokens from input", token_ids.len()); debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Build tool constraints if needed // Step 5: Build tool constraints if needed
let structural_tag = if let Some(tools) = &body.tools { let tool_call_constraint = if let Some(tools) = &body.tools {
self.generate_tool_constraints(tools, &body.tool_choice, &body.model) self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
} else { } else {
None None
}; };
// Step 6: Build SamplingParams for gRPC // Step 6: Build SamplingParams for gRPC
let sampling_params = match self.build_grpc_sampling_params(body, structural_tag) { let sampling_params = match self.build_grpc_sampling_params(body, tool_call_constraint) {
Ok(params) => params, Ok(params) => params,
Err(e) => { Err(e) => {
error!("Failed to build sampling parameters: {}", e); error!("Failed to build sampling parameters: {}", e);
...@@ -286,6 +286,41 @@ impl GrpcRouter { ...@@ -286,6 +286,41 @@ impl GrpcRouter {
} }
// ============ Helper Methods ============ // ============ Helper Methods ============
/// Select a worker for the request.
///
/// Looks up the `WorkerType::Regular` workers registered for `model_id` that use
/// the gRPC connection mode, keeps only those whose `is_available()` returns
/// true, and asks the load-balancing policy to pick one of them. The policy is
/// the one registered for `model_id`, or the registry default when `model_id`
/// is `None`.
///
/// `text` is forwarded unchanged to the policy's `select_worker`.
///
/// Returns `None` when no worker is available, when the policy declines to
/// choose, or when the policy returns an index outside the candidate list.
fn select_worker_for_request(
    &self,
    model_id: Option<&str>,
    text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
    // Get workers for the specified model, filtered by connection mode
    let workers = self.worker_registry.get_workers_filtered(
        model_id,
        Some(WorkerType::Regular),
        Some(crate::core::ConnectionMode::Grpc { port: None }),
        false, // get all workers, we'll filter by is_available() next
    );

    // Filter by availability (health + circuit breaker)
    let available: Vec<Arc<dyn crate::core::Worker>> = workers
        .iter()
        .filter(|w| w.is_available())
        .cloned()
        .collect();
    if available.is_empty() {
        return None;
    }

    // Get the appropriate policy for this model
    let policy = match model_id {
        Some(model) => self.policy_registry.get_policy_or_default(model),
        None => self.policy_registry.get_default_policy(),
    };

    // Select worker using the policy. The lookup is checked so that an
    // out-of-range index from a policy yields `None` instead of a panic.
    let idx = policy.select_worker(&available, text)?;
    available.get(idx).cloned()
}
/// Process chat messages and apply template /// Process chat messages and apply template
fn process_chat_messages( fn process_chat_messages(
...@@ -516,7 +551,7 @@ impl GrpcRouter { ...@@ -516,7 +551,7 @@ impl GrpcRouter {
fn build_grpc_sampling_params( fn build_grpc_sampling_params(
&self, &self,
request: &ChatCompletionRequest, request: &ChatCompletionRequest,
structural_tag: Option<String>, tool_call_constraint: Option<(String, String)>,
) -> Result<proto::SamplingParams, String> { ) -> Result<proto::SamplingParams, String> {
let stop_sequences = self.extract_stop_strings(request); let stop_sequences = self.extract_stop_strings(request);
...@@ -555,8 +590,7 @@ impl GrpcRouter { ...@@ -555,8 +590,7 @@ impl GrpcRouter {
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
skip_special_tokens, skip_special_tokens,
n: request.n.unwrap_or(1) as i32, n: request.n.unwrap_or(1) as i32,
structural_tag: structural_tag.unwrap_or_default(), constraint: self.build_constraint(request, tool_call_constraint)?,
constraint: self.build_constraint(request)?,
..Default::default() ..Default::default()
}) })
} }
...@@ -574,28 +608,48 @@ impl GrpcRouter { ...@@ -574,28 +608,48 @@ impl GrpcRouter {
fn build_constraint( fn build_constraint(
&self, &self,
request: &ChatCompletionRequest, request: &ChatCompletionRequest,
tool_call_constraint: Option<(String, String)>,
) -> Result<Option<proto::sampling_params::Constraint>, String> { ) -> Result<Option<proto::sampling_params::Constraint>, String> {
let mut constraints = Vec::new();
if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format { if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
let schema_str = serde_json::to_string(&json_schema.schema) let schema_str = serde_json::to_string(&json_schema.schema)
.map_err(|e| format!("Failed to serialize JSON schema: {}", e))?; .map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
return Ok(Some(proto::sampling_params::Constraint::JsonSchema( constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str));
schema_str,
)));
} }
if let Some(ebnf) = &request.ebnf { if let Some(ebnf) = &request.ebnf {
return Ok(Some(proto::sampling_params::Constraint::EbnfGrammar( constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
ebnf.clone(), ebnf.clone(),
))); ));
} }
if let Some(regex) = &request.regex { if let Some(regex) = &request.regex {
return Ok(Some(proto::sampling_params::Constraint::Regex( constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
regex.clone(), }
)));
// Handle tool call constraint
if let Some((constraint_type, constraint_value)) = tool_call_constraint {
if !constraints.is_empty() {
return Err("Constrained decoding is not compatible with tool calls.".to_string());
}
let tool_constraint = match constraint_type.as_str() {
"structural_tag" => {
proto::sampling_params::Constraint::StructuralTag(constraint_value)
}
"json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value),
"ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value),
"regex" => proto::sampling_params::Constraint::Regex(constraint_value),
_ => return Err(format!("Unknown constraint type: {}", constraint_type)),
};
constraints.push(tool_constraint);
} }
Ok(None) match constraints.len() {
0 => Ok(None),
1 => Ok(constraints.pop()),
_ => Err("Multiple constraints are not allowed.".to_string()),
}
} }
/// Generate tool constraints for structured generation /// Generate tool constraints for structured generation
...@@ -604,52 +658,19 @@ impl GrpcRouter { ...@@ -604,52 +658,19 @@ impl GrpcRouter {
_tools: &[crate::protocols::spec::Tool], _tools: &[crate::protocols::spec::Tool],
_tool_choice: &Option<crate::protocols::spec::ToolChoice>, _tool_choice: &Option<crate::protocols::spec::ToolChoice>,
model: &str, model: &str,
) -> Option<String> { ) -> Option<(String, String)> {
let _parser = self.tool_parser_registry.get_parser(model)?; let _parser = self.tool_parser_registry.get_parser(model)?;
// TODO: Implement actual constraint generation logic
// For now, return None as this is placeholder implementation
None None
} }
/// Select a worker for the request.
///
/// Looks up the `WorkerType::Regular` workers registered for `model_id` that use
/// the gRPC connection mode, keeps only those whose `is_available()` returns
/// true, and asks the load-balancing policy to pick one of them. The policy is
/// the one registered for `model_id`, or the registry default when `model_id`
/// is `None`.
///
/// `text` is forwarded unchanged to the policy's `select_worker`.
///
/// Returns `None` when no worker is available, when the policy declines to
/// choose, or when the policy returns an index outside the candidate list.
fn select_worker_for_request(
    &self,
    model_id: Option<&str>,
    text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
    // Get workers for the specified model, filtered by connection mode
    let workers = self.worker_registry.get_workers_filtered(
        model_id,
        Some(WorkerType::Regular),
        Some(crate::core::ConnectionMode::Grpc { port: None }),
        false, // get all workers, we'll filter by is_available() next
    );

    // Filter by availability (health + circuit breaker)
    let available: Vec<Arc<dyn crate::core::Worker>> = workers
        .iter()
        .filter(|w| w.is_available())
        .cloned()
        .collect();
    if available.is_empty() {
        return None;
    }

    // Get the appropriate policy for this model
    let policy = match model_id {
        Some(model) => self.policy_registry.get_policy_or_default(model),
        None => self.policy_registry.get_default_policy(),
    };

    // Select worker using the policy. The lookup is checked so that an
    // out-of-range index from a policy yields `None` instead of a panic.
    let idx = policy.select_worker(&available, text)?;
    available.get(idx).cloned()
}
/// Get or create a gRPC client for the worker /// Get or create a gRPC client for the worker
async fn get_or_create_grpc_client( async fn get_or_create_grpc_client(
&self, &self,
worker_url: &str, worker_url: &str,
) -> Result<SglangSchedulerClient, String> { ) -> Result<SglangSchedulerClient, String> {
// TODO: move to worker
debug!("Creating new gRPC client for worker: {}", worker_url); debug!("Creating new gRPC client for worker: {}", worker_url);
SglangSchedulerClient::connect(worker_url) SglangSchedulerClient::connect(worker_url)
.await .await
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment