Unverified commit 33b3c0f8, authored by Simo Lin, committed by GitHub
Browse files

[router] grpc router generate endpoint support (#11070)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent e5281f84
use std::convert::TryFrom;
use std::time::Duration; use std::time::Duration;
use tonic::{transport::Channel, Request}; use tonic::{transport::Channel, Request};
use tracing::debug; use tracing::debug;
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat}; use crate::protocols::spec::{
ChatCompletionRequest, GenerateRequest, ResponseFormat,
SamplingParams as GenerateSamplingParams, StringOrArray,
};
// Include the generated protobuf code // Include the generated protobuf code
pub mod proto { pub mod proto {
...@@ -112,6 +116,37 @@ impl SglangSchedulerClient { ...@@ -112,6 +116,37 @@ impl SglangSchedulerClient {
Ok(grpc_request) Ok(grpc_request)
} }
/// Assemble the gRPC `GenerateRequest` for a plain (non-chat) SGLang generate call.
///
/// The already-tokenized prompt is attached as a `TokenizedInput`; when
/// `original_text` is `None` the text field is left as the empty string.
/// `return_logprob`, `return_hidden_states` and `stream` are copied from `body`.
/// The remaining logprob options are fixed (`logprob_start_len = -1`,
/// `top_logprobs_num = 0`, empty `token_ids_logprob`) and `log_metrics` is
/// always enabled. Every other field keeps its `Default` value.
///
/// # Errors
///
/// Propagates the error string from `build_sampling_params_from_plain` when
/// `body.sampling_params` cannot be converted.
pub fn build_plain_generate_request(
    &self,
    request_id: String,
    body: &GenerateRequest,
    original_text: Option<String>,
    token_ids: Vec<u32>,
) -> Result<proto::GenerateRequest, String> {
    // Convert the sampling parameters first so an invalid request fails
    // before anything else is assembled.
    let converted = Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?;

    let prompt = proto::TokenizedInput {
        original_text: original_text.unwrap_or_default(),
        input_ids: token_ids,
    };

    Ok(proto::GenerateRequest {
        request_id,
        tokenized: Some(prompt),
        sampling_params: Some(converted),
        return_logprob: body.return_logprob,
        logprob_start_len: -1,
        top_logprobs_num: 0,
        token_ids_logprob: vec![],
        return_hidden_states: body.return_hidden_states,
        stream: body.stream,
        log_metrics: true,
        ..Default::default()
    })
}
/// Build gRPC SamplingParams from OpenAI request /// Build gRPC SamplingParams from OpenAI request
fn build_grpc_sampling_params( fn build_grpc_sampling_params(
&self, &self,
...@@ -165,8 +200,8 @@ impl SglangSchedulerClient { ...@@ -165,8 +200,8 @@ impl SglangSchedulerClient {
/// Extract stop strings from request /// Extract stop strings from request
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> { fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
match &request.stop { match &request.stop {
Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()], Some(StringOrArray::String(s)) => vec![s.clone()],
Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(), Some(StringOrArray::Array(arr)) => arr.clone(),
None => vec![], None => vec![],
} }
} }
...@@ -218,6 +253,100 @@ impl SglangSchedulerClient { ...@@ -218,6 +253,100 @@ impl SglangSchedulerClient {
_ => Err("Multiple constraints are not allowed.".to_string()), _ => Err("Multiple constraints are not allowed.".to_string()),
} }
} }
/// Pick the single structured-output constraint set on plain sampling params.
///
/// Inspects `json_schema`, `regex` and `ebnf` (in that order). Returns
/// `Ok(None)` when none of them is set and `Ok(Some(..))` when exactly one is.
///
/// # Errors
///
/// Returns an error string when more than one of the three fields is set.
fn build_single_constraint_from_plain(
    params: &GenerateSamplingParams,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
    let json = params
        .json_schema
        .as_ref()
        .map(|schema| proto::sampling_params::Constraint::JsonSchema(schema.clone()));
    let regex = params
        .regex
        .as_ref()
        .map(|pattern| proto::sampling_params::Constraint::Regex(pattern.clone()));
    let ebnf = params
        .ebnf
        .as_ref()
        .map(|grammar| proto::sampling_params::Constraint::EbnfGrammar(grammar.clone()));

    // Walk only the constraints that are actually present; looking at the
    // first two is enough to tell "none", "exactly one" and "too many" apart.
    let mut present = json.into_iter().chain(regex).chain(ebnf);
    match (present.next(), present.next()) {
        (None, _) => Ok(None),
        (only, None) => Ok(only),
        (Some(_), Some(_)) => Err("Multiple structured constraints are not allowed".to_string()),
    }
}
/// Translate optional plain sampling params into gRPC `SamplingParams`.
///
/// Starts from `temperature = 1.0`, `top_p = 1.0`, `top_k = -1`,
/// `repetition_penalty = 1.0` and `n = 1` (every other field at its `Default`)
/// and overrides only the fields that are present in `params`. When `params`
/// is `None` those starting values are returned unchanged.
///
/// # Errors
///
/// Returns an error string when `max_new_tokens` or `min_tokens` does not fit
/// into an `i32`, or when more than one structured constraint is set.
fn build_sampling_params_from_plain(
    params: Option<&GenerateSamplingParams>,
) -> Result<proto::SamplingParams, String> {
    let mut out = proto::SamplingParams {
        temperature: 1.0,
        top_p: 1.0,
        top_k: -1,
        repetition_penalty: 1.0,
        n: 1,
        ..Default::default()
    };

    let src = match params {
        Some(src) => src,
        None => return Ok(out),
    };

    // Copies each listed field across when the caller supplied a value; the
    // field name is the same on both structs.
    macro_rules! copy_if_set {
        ($($name:ident),+ $(,)?) => {
            $(
                if let Some(value) = src.$name {
                    out.$name = value;
                }
            )+
        };
    }
    // NOTE(review): fields left unset here (e.g. `skip_special_tokens`) keep
    // the proto `Default` value — confirm that matches the scheduler's
    // expected defaults.
    copy_if_set!(
        temperature,
        top_p,
        top_k,
        frequency_penalty,
        presence_penalty,
        repetition_penalty,
        min_p,
        ignore_eos,
        skip_special_tokens,
        no_stop_trim,
    );

    // Stop sequences: a single string or a list, appended to `out.stop`.
    match &src.stop {
        Some(StringOrArray::String(single)) => out.stop.push(single.clone()),
        Some(StringOrArray::Array(many)) => out.stop.extend(many.iter().cloned()),
        None => {}
    }

    // Stop token IDs replace the default list wholesale.
    if let Some(ids) = &src.stop_token_ids {
        out.stop_token_ids = ids.clone();
    }

    // `max_new_tokens` is optional on the gRPC side and must fit in an i32.
    if let Some(limit) = src.max_new_tokens {
        let limit = i32::try_from(limit)
            .map_err(|_| "max_new_tokens must fit into a 32-bit signed integer".to_string())?;
        out.max_new_tokens = Some(limit);
    }

    // `min_tokens` maps onto the gRPC `min_new_tokens` field.
    if let Some(floor) = src.min_tokens {
        let floor = i32::try_from(floor)
            .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
        out.min_new_tokens = floor;
    }

    // At most one structured-output constraint may be set.
    out.constraint = Self::build_single_constraint_from_plain(src)?;

    Ok(out)
}
} }
#[cfg(test)] #[cfg(test)]
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment