Unverified Commit ce67b2d5 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router]restructure protocol modules for better organization (#9321)

parent 3c2c9f6c
......@@ -3,9 +3,13 @@ use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
use sglang_router_rs::protocols::{
common::StringOrArray,
generate::{GenerateParameters, GenerateRequest, SamplingParams},
openai::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
completions::CompletionRequest,
},
};
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
......
......@@ -5,8 +5,8 @@ use std::collections::HashMap;
pub mod core;
pub mod metrics;
pub mod middleware;
pub mod openai_api_types;
pub mod policies;
pub mod protocols;
pub mod reasoning_parser;
pub mod routers;
pub mod server;
......
This diff is collapsed.
// Common types shared across all protocol implementations
use serde::{Deserialize, Serialize};
/// Helper function for serde default value
pub fn default_true() -> bool {
true
}
/// Common trait for all generation requests across different APIs
pub trait GenerationRequest: Send + Sync {
/// Check if the request is for streaming
fn is_stream(&self) -> bool;
/// Get the model name if specified
fn get_model(&self) -> Option<&str>;
/// Extract text content for routing decisions
fn extract_text_for_routing(&self) -> String;
}
/// Helper type for string or array of strings
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
String(String),
Array(Vec<String>),
}
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
Single(Option<String>),
Batch(Vec<Option<String>>),
}
// SGLang native Generate API module (/generate)
pub mod request;
pub mod types;
// Re-export main types for convenience
pub use request::GenerateRequest;
pub use types::{GenerateParameters, InputIds, SamplingParams};
// Generate API request types (/generate)
use crate::protocols::common::{GenerationRequest, LoRAPath, StringOrArray};
use crate::protocols::generate::types::{GenerateParameters, InputIds, SamplingParams};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<StringOrArray>,
/// Text input - SGLang native format
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Input IDs for tokenized input
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
/// Generation parameters
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
/// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Whether to return logprobs
#[serde(default)]
pub return_logprob: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Request ID for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
impl GenerationRequest for GenerateRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
// Generate requests typically don't have a model field
None
}
fn extract_text_for_routing(&self) -> String {
// Check fields in priority order: text, prompt, inputs
if let Some(ref text) = self.text {
return text.clone();
}
if let Some(ref prompt) = self.prompt {
return match prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
};
}
if let Some(ref input_ids) = self.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| id.to_string())
.collect::<Vec<String>>()
.join(" "),
InputIds::Batch(batches) => batches
.iter()
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
.collect::<Vec<String>>()
.join(" "),
};
}
// No text input found
String::new()
}
}
// Types for the SGLang native /generate API
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
Single(Vec<i32>),
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub decoder_input_details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_full_text: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncate: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub watermark: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<crate::protocols::common::StringOrArray>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore_eos: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_special_tokens: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_stop_trim: Option<bool>,
}
// Protocol definitions and validation for various LLM APIs
// This module provides a structured approach to handling different API protocols
pub mod common;
pub mod generate;
pub mod openai;
// Chat Completions API module
pub mod request;
pub mod response;
pub mod types;
// Re-export main types for convenience
pub use request::ChatCompletionRequest;
pub use response::{
ChatChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice,
};
pub use types::*;
// Chat Completions API request types
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
use crate::protocols::openai::chat::types::*;
use crate::protocols::openai::common::StreamOptions;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
/// ID of the model to use
pub model: String,
/// A list of messages comprising the conversation so far
pub messages: Vec<ChatMessage>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// If set, partial message deltas will be sent
#[serde(default)]
pub stream: bool,
/// Options for streaming response
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
/// Whether to return log probabilities of the output tokens
#[serde(default)]
pub logprobs: bool,
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// An object specifying the format that the model must output
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// A list of tools the model may call
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Controls which (if any) tool is called by the model
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// Deprecated: use tools instead
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>,
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Continue generating from final assistant message
#[serde(default)]
pub continue_final_message: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Separate reasoning content from final answer (O1-style models)
#[serde(default = "default_true")]
pub separate_reasoning: bool,
/// Stream reasoning tokens during generation
#[serde(default = "default_true")]
pub stream_reasoning: bool,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Extract text from messages for routing decisions
self.messages
.iter()
.filter_map(|msg| match msg {
ChatMessage::System { content, .. } => Some(content.clone()),
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => Some(text.clone()),
UserMessageContent::Parts(parts) => {
let texts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect();
Some(texts.join(" "))
}
},
ChatMessage::Assistant {
content,
reasoning_content,
..
} => {
// Combine content and reasoning content for routing decisions
let main_content = content.clone().unwrap_or_default();
let reasoning = reasoning_content.clone().unwrap_or_default();
if main_content.is_empty() && reasoning.is_empty() {
None
} else {
Some(format!("{} {}", main_content, reasoning).trim().to_string())
}
}
ChatMessage::Tool { content, .. } => Some(content.clone()),
ChatMessage::Function { content, .. } => Some(content.clone()),
})
.collect::<Vec<String>>()
.join(" ")
}
}
// Chat Completions API response types
use crate::protocols::openai::chat::types::{ChatMessage, ChatMessageDelta};
use crate::protocols::openai::common::{ChatLogProbs, Usage};
use serde::{Deserialize, Serialize};
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String, // "chat.completion"
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String, // "chat.completion.chunk"
pub created: u64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
}
// Types specific to the Chat Completions API
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============= Message Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
System {
role: String, // "system"
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
role: String, // "user"
content: UserMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
role: String, // "assistant"
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
},
Tool {
role: String, // "tool"
content: String,
tool_call_id: String,
},
Function {
role: String, // "function"
content: String,
name: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>, // "auto", "low", or "high"
}
// ============= Response Format Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_object")]
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema { json_schema: JsonSchemaFormat },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
// ============= Tool/Function Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: Function,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: Value, // JSON Schema
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
None,
Auto,
Required,
Function {
#[serde(rename = "type")]
tool_type: String, // "function"
function: FunctionChoice,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
pub name: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: FunctionCallResponse,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
pub name: String,
pub arguments: String, // JSON string
}
// ============= Streaming Delta Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")]
pub tool_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// Common types shared across OpenAI API implementations
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ============= Shared Request Components =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
// ============= Usage Tracking =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
}
// ============= Logprobs Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f32>>,
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
pub text_offset: Vec<u32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
pub content: Option<Vec<ChatLogProbsContent>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Vec<TopLogProb>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
}
// Completions API module (v1/completions)
pub mod request;
pub mod response;
// Re-export main types for convenience
pub use request::CompletionRequest;
pub use response::{
CompletionChoice, CompletionResponse, CompletionStreamChoice, CompletionStreamResponse,
};
// Completions API request types (v1/completions) - DEPRECATED but still supported
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
use crate::protocols::openai::common::StreamOptions;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionRequest {
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
pub model: String,
/// The prompt(s) to generate completions for
pub prompt: StringOrArray,
/// The suffix that comes after a completion of inserted text
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature (nucleus sampling)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many completions to generate for each prompt
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// Whether to stream back partial progress
#[serde(default)]
pub stream: bool,
/// Options for streaming response
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Include the log probabilities on the logprobs most likely tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
/// Echo back the prompt in addition to the completion
#[serde(default)]
pub echo: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Generates best_of completions server-side and returns the "best"
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// JSON schema constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
}
impl GenerationRequest for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
match &self.prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
}
}
}
// Completions API response types
use crate::protocols::openai::common::{LogProbs, Usage};
use serde::{Deserialize, Serialize};
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub choices: Vec<CompletionStreamChoice>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
}
// OpenAI API error response types
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
}
// OpenAI protocol module
// This module contains all OpenAI API-compatible types and future validation logic
pub mod chat;
pub mod common;
pub mod completions;
pub mod errors;
......@@ -9,7 +9,10 @@ use axum::{
};
use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::protocols::{
generate::GenerateRequest,
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
};
pub mod factory;
pub mod header_utils;
......
......@@ -11,8 +11,15 @@ use crate::core::{
RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
use crate::protocols::{
common::StringOrArray,
generate::GenerateRequest,
openai::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
completions::CompletionRequest,
},
};
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
......@@ -616,7 +623,7 @@ impl PDRouter {
// Helper to determine batch size from a GenerateRequest
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
// Check prompt array
if let Some(crate::openai_api_types::StringOrArray::Array(arr)) = &req.prompt {
if let Some(StringOrArray::Array(arr)) = &req.prompt {
if !arr.is_empty() {
return Some(arr.len());
}
......@@ -645,7 +652,7 @@ impl PDRouter {
// Helper to determine batch size from a CompletionRequest
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
// Check prompt array
if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt {
if let StringOrArray::Array(arr) = &req.prompt {
if !arr.is_empty() {
return Some(arr.len());
}
......@@ -1724,10 +1731,8 @@ impl RouterTrait for PDRouter {
.as_deref()
.or_else(|| {
body.prompt.as_ref().and_then(|p| match p {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => {
v.first().map(|s| s.as_str())
}
StringOrArray::String(s) => Some(s.as_str()),
StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
})
})
.map(|s| s.to_string())
......@@ -1763,13 +1768,11 @@ impl RouterTrait for PDRouter {
// Extract text for cache-aware routing
let request_text = if self.policies_need_request_text() {
body.messages.first().and_then(|msg| match msg {
crate::openai_api_types::ChatMessage::User { content, .. } => match content {
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()),
crate::openai_api_types::UserMessageContent::Parts(_) => None,
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => Some(text.clone()),
UserMessageContent::Parts(_) => None,
},
crate::openai_api_types::ChatMessage::System { content, .. } => {
Some(content.clone())
}
ChatMessage::System { content, .. } => Some(content.clone()),
_ => None,
})
} else {
......@@ -1804,10 +1807,8 @@ impl RouterTrait for PDRouter {
// Extract text for cache-aware routing
let request_text = if self.policies_need_request_text() {
match &body.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()),
crate::openai_api_types::StringOrArray::Array(v) => {
v.first().map(|s| s.to_string())
}
StringOrArray::String(s) => Some(s.clone()),
StringOrArray::Array(v) => v.first().map(|s| s.to_string()),
}
} else {
None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment