Unverified Commit 5ae5ecaa authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Implement OpenAI Responses API specification (#9367)

parent 5fbad308
...@@ -5,3 +5,4 @@ pub mod chat; ...@@ -5,3 +5,4 @@ pub mod chat;
pub mod common; pub mod common;
pub mod completions; pub mod completions;
pub mod errors; pub mod errors;
pub mod responses;
// Responses API module
pub mod request;
pub mod response;
pub mod types;
// Re-export main types for convenience
pub use request::ResponsesRequest;
pub use response::ResponsesResponse;
pub use types::*;
// Responses API request types
use crate::protocols::common::{GenerationRequest, StringOrArray};
use crate::protocols::openai::responses::types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
fn generate_request_id() -> String {
format!("resp_{}", uuid::Uuid::new_v4().simple())
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
// ============= Core OpenAI API fields =============
/// Run the request in the background
#[serde(default)]
pub background: bool,
/// Fields to include in the response
#[serde(skip_serializing_if = "Option::is_none")]
pub include: Option<Vec<IncludeField>>,
/// Input content - can be string or structured items
pub input: ResponseInput,
/// System instructions for the model
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Maximum number of output tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
/// Maximum number of tool calls
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tool_calls: Option<u32>,
/// Additional metadata
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Model to use (optional to match vLLM)
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Whether to enable parallel tool calls
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
/// ID of previous response to continue from
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Reasoning configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<ResponseReasoningParam>,
/// Service tier
#[serde(default)]
pub service_tier: ServiceTier,
/// Whether to store the response
#[serde(default = "default_true")]
pub store: bool,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Temperature for sampling
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Tool choice behavior
#[serde(default)]
pub tool_choice: ToolChoice,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
/// Number of top logprobs to return
#[serde(default)]
pub top_logprobs: u32,
/// Top-p sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Truncation behavior
#[serde(default)]
pub truncation: Truncation,
/// User identifier
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
// ============= SGLang Extensions =============
/// Request ID
#[serde(default = "generate_request_id")]
pub request_id: String,
/// Request priority
#[serde(default)]
pub priority: i32,
/// Frequency penalty
#[serde(default)]
pub frequency_penalty: f32,
/// Presence penalty
#[serde(default)]
pub presence_penalty: f32,
/// Stop sequences
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Top-k sampling parameter
#[serde(default = "default_top_k")]
pub top_k: i32,
/// Min-p sampling parameter
#[serde(default)]
pub min_p: f32,
/// Repetition penalty
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ResponseInput {
Text(String),
Items(Vec<ResponseInputOutputItem>),
}
fn default_top_k() -> i32 {
-1
}
fn default_repetition_penalty() -> f32 {
1.0
}
fn default_true() -> bool {
true
}
impl ResponsesRequest {
/// Default sampling parameters
const DEFAULT_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 1.0;
/// Convert to sampling parameters for generation
pub fn to_sampling_params(
&self,
default_max_tokens: u32,
default_params: Option<HashMap<String, serde_json::Value>>,
) -> HashMap<String, serde_json::Value> {
let mut params = HashMap::new();
// Use max_output_tokens if available
let max_tokens = if let Some(max_output) = self.max_output_tokens {
std::cmp::min(max_output, default_max_tokens)
} else {
default_max_tokens
};
// Avoid exceeding context length by minus 1 token
let max_tokens = max_tokens.saturating_sub(1);
// Temperature
let temperature = self.temperature.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("temperature"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TEMPERATURE)
});
// Top-p
let top_p = self.top_p.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("top_p"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TOP_P)
});
params.insert(
"max_new_tokens".to_string(),
serde_json::Value::Number(serde_json::Number::from(max_tokens)),
);
params.insert(
"temperature".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
);
params.insert(
"top_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
);
params.insert(
"frequency_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
),
);
params.insert(
"presence_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
),
);
params.insert(
"top_k".to_string(),
serde_json::Value::Number(serde_json::Number::from(self.top_k)),
);
params.insert(
"min_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
);
params.insert(
"repetition_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
),
);
if let Some(ref stop) = self.stop {
match serde_json::to_value(stop) {
Ok(value) => params.insert("stop".to_string(), value),
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
};
}
// Apply any additional default parameters
if let Some(default_params) = default_params {
for (key, value) in default_params {
params.entry(key).or_insert(value);
}
}
params
}
}
impl GenerationRequest for ResponsesRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
self.model.as_deref()
}
fn extract_text_for_routing(&self) -> String {
match &self.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(items) => items
.iter()
.filter_map(|item| match item {
ResponseInputOutputItem::Message { content, .. } => {
let texts: Vec<String> = content
.iter()
.map(|part| match part {
ResponseContentPart::OutputText { text, .. } => text.clone(),
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(" "))
}
}
ResponseInputOutputItem::Reasoning { content, .. } => {
let texts: Vec<String> = content
.iter()
.map(|part| match part {
ResponseReasoningContent::ReasoningText { text } => text.clone(),
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(" "))
}
}
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
Some(arguments.clone())
}
})
.collect::<Vec<String>>()
.join(" "),
}
}
}
// Responses API response types
use crate::protocols::openai::responses::request::ResponsesRequest;
use crate::protocols::openai::responses::types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
fn generate_response_id() -> String {
format!("resp_{}", uuid::Uuid::new_v4().simple())
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs() as i64
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesResponse {
/// Response ID
#[serde(default = "generate_response_id")]
pub id: String,
/// Object type
#[serde(default = "default_object_type")]
pub object: String,
/// Creation timestamp
#[serde(default = "current_timestamp")]
pub created_at: i64,
/// Model name
pub model: String,
/// Output items
#[serde(default)]
pub output: Vec<ResponseOutputItem>,
/// Response status
pub status: ResponseStatus,
/// Usage information
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<UsageInfo>,
/// Whether parallel tool calls are enabled
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
/// Tool choice setting
#[serde(default = "default_tool_choice")]
pub tool_choice: String,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
}
fn default_object_type() -> String {
"response".to_string()
}
fn default_true() -> bool {
true
}
fn default_tool_choice() -> String {
"auto".to_string()
}
impl ResponsesResponse {
/// Create a response from a request
#[allow(clippy::too_many_arguments)]
pub fn from_request(
request: &ResponsesRequest,
_sampling_params: &HashMap<String, serde_json::Value>,
model_name: String,
created_time: i64,
output: Vec<ResponseOutputItem>,
status: ResponseStatus,
usage: Option<UsageInfo>,
) -> Self {
Self {
id: request.request_id.clone(),
object: "response".to_string(),
created_at: created_time,
model: model_name,
output,
status,
usage,
parallel_tool_calls: request.parallel_tool_calls,
tool_choice: match request.tool_choice {
ToolChoice::Auto => "auto".to_string(),
ToolChoice::Required => "required".to_string(),
ToolChoice::None => "none".to_string(),
},
tools: request.tools.clone(),
}
}
/// Create a new response with default values
pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self {
Self {
id: request_id,
object: "response".to_string(),
created_at: current_timestamp(),
model,
output: Vec::new(),
status,
usage: None,
parallel_tool_calls: true,
tool_choice: "auto".to_string(),
tools: Vec::new(),
}
}
/// Add an output item to the response
pub fn add_output(&mut self, item: ResponseOutputItem) {
self.output.push(item);
}
/// Set the usage information
pub fn set_usage(&mut self, usage: UsageInfo) {
self.usage = Some(usage);
}
/// Update the status
pub fn set_status(&mut self, status: ResponseStatus) {
self.status = status;
}
/// Check if the response is complete
pub fn is_complete(&self) -> bool {
matches!(self.status, ResponseStatus::Completed)
}
/// Check if the response is in progress
pub fn is_in_progress(&self) -> bool {
matches!(self.status, ResponseStatus::InProgress)
}
/// Check if the response failed
pub fn is_failed(&self) -> bool {
matches!(self.status, ResponseStatus::Failed)
}
/// Check if the response was cancelled
pub fn is_cancelled(&self) -> bool {
matches!(self.status, ResponseStatus::Cancelled)
}
/// Check if the response is queued
pub fn is_queued(&self) -> bool {
matches!(self.status, ResponseStatus::Queued)
}
/// Convert usage to OpenAI Responses API format
pub fn usage_in_response_format(
&self,
) -> Option<crate::protocols::openai::responses::types::ResponseUsage> {
self.usage.as_ref().map(|usage| usage.to_response_usage())
}
/// Get the response as a JSON value with usage in response format
pub fn to_response_format(&self) -> serde_json::Value {
let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
// Convert usage to response format if present
if let Some(usage) = &self.usage {
if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
response["usage"] = usage_value;
}
}
response
}
}
// ============= Helper Functions =============
impl ResponseOutputItem {
/// Create a new message output item
pub fn new_message(
id: String,
role: String,
content: Vec<ResponseContentPart>,
status: String,
) -> Self {
Self::Message {
id,
role,
content,
status,
}
}
/// Create a new reasoning output item
pub fn new_reasoning(
id: String,
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
status: Option<String>,
) -> Self {
Self::Reasoning {
id,
summary,
content,
status,
}
}
/// Create a new function tool call output item
pub fn new_function_tool_call(
id: String,
name: String,
arguments: String,
output: Option<String>,
status: String,
) -> Self {
Self::FunctionToolCall {
id,
name,
arguments,
output,
status,
}
}
}
impl ResponseContentPart {
/// Create a new text content part
pub fn new_text(
text: String,
annotations: Vec<String>,
logprobs: Option<crate::protocols::openai::common::ChatLogProbs>,
) -> Self {
Self::OutputText {
text,
annotations,
logprobs,
}
}
}
impl ResponseReasoningContent {
/// Create a new reasoning text content
pub fn new_reasoning_text(text: String) -> Self {
Self::ReasoningText { text }
}
}
impl UsageInfo {
/// Create a new usage info with token counts
pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
reasoning_tokens,
prompt_tokens_details: None,
}
}
/// Create usage info with cached token details
pub fn new_with_cached(
prompt_tokens: u32,
completion_tokens: u32,
reasoning_tokens: Option<u32>,
cached_tokens: u32,
) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
reasoning_tokens,
prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }),
}
}
}
// Supporting types for Responses API
use crate::protocols::openai::common::ChatLogProbs;
use serde::{Deserialize, Serialize};
// ============= Tool Definitions =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
WebSearchPreview,
CodeInterpreter,
}
// ============= Reasoning Configuration =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
#[serde(default = "default_reasoning_effort")]
pub effort: Option<ReasoningEffort>,
}
fn default_reasoning_effort() -> Option<ReasoningEffort> {
Some(ReasoningEffort::Medium)
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
Low,
Medium,
High,
}
// ============= Input/Output Items =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
#[serde(rename = "output_text")]
OutputText {
text: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
annotations: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
logprobs: Option<ChatLogProbs>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
#[serde(rename = "reasoning_text")]
ReasoningText { text: String },
}
// ============= Output Items for Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
status: String,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
status: String,
},
}
// ============= Service Tier =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
Auto,
Default,
Flex,
Scale,
Priority,
}
impl Default for ServiceTier {
fn default() -> Self {
Self::Auto
}
}
// ============= Tool Choice =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
Auto,
Required,
None,
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Auto
}
}
// ============= Truncation =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
Auto,
Disabled,
}
impl Default for Truncation {
fn default() -> Self {
Self::Disabled
}
}
// ============= Response Status =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
Queued,
InProgress,
Completed,
Failed,
Cancelled,
}
// ============= Include Fields =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
#[serde(rename = "code_interpreter_call.outputs")]
CodeInterpreterCallOutputs,
#[serde(rename = "computer_call_output.output.image_url")]
ComputerCallOutputImageUrl,
#[serde(rename = "file_search_call.results")]
FileSearchCallResults,
#[serde(rename = "message.input_image.image_url")]
MessageInputImageUrl,
#[serde(rename = "message.output_text.logprobs")]
MessageOutputTextLogprobs,
#[serde(rename = "reasoning.encrypted_content")]
ReasoningEncryptedContent,
}
// ============= Usage Info =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
pub cached_tokens: u32,
}
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
pub input_tokens: u32,
pub output_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_tokens_details: Option<InputTokensDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_tokens_details: Option<OutputTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
pub cached_tokens: u32,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
pub reasoning_tokens: u32,
}
impl UsageInfo {
/// Convert to OpenAI Responses API format
pub fn to_response_usage(&self) -> ResponseUsage {
ResponseUsage {
input_tokens: self.prompt_tokens,
output_tokens: self.completion_tokens,
total_tokens: self.total_tokens,
input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
InputTokensDetails {
cached_tokens: details.cached_tokens,
}
}),
output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
reasoning_tokens: tokens,
}),
}
}
}
impl From<UsageInfo> for ResponseUsage {
fn from(usage: UsageInfo) -> Self {
usage.to_response_usage()
}
}
impl ResponseUsage {
/// Convert back to standard UsageInfo format
pub fn to_usage_info(&self) -> UsageInfo {
UsageInfo {
prompt_tokens: self.input_tokens,
completion_tokens: self.output_tokens,
total_tokens: self.total_tokens,
reasoning_tokens: self
.output_tokens_details
.as_ref()
.map(|details| details.reasoning_tokens),
prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
PromptTokenUsageInfo {
cached_tokens: details.cached_tokens,
}
}),
}
}
}
// Integration test for Responses API
use sglang_router_rs::protocols::common::GenerationRequest;
use sglang_router_rs::protocols::openai::responses::request::ResponseInput;
use sglang_router_rs::protocols::openai::responses::*;
#[test]
fn test_responses_request_creation() {
let request = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("Be helpful".to_string()),
max_output_tokens: Some(100),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true,
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
}),
service_tier: ServiceTier::Auto,
store: true,
stream: false,
temperature: Some(0.7),
tool_choice: ToolChoice::Auto,
tools: vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
}],
top_logprobs: 5,
top_p: Some(0.9),
truncation: Truncation::Disabled,
user: Some("test-user".to_string()),
request_id: "resp_test123".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop: None,
top_k: -1,
min_p: 0.0,
repetition_penalty: 1.0,
};
// Test GenerationRequest trait implementation
assert!(!request.is_stream());
assert_eq!(request.get_model(), Some("test-model"));
let routing_text = request.extract_text_for_routing();
assert_eq!(routing_text, "Hello, world!");
}
#[test]
fn test_sampling_params_conversion() {
let request = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("Test".to_string()),
instructions: None,
max_output_tokens: Some(50),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true, // Use default true
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true, // Use default true
stream: false,
temperature: Some(0.8),
tool_choice: ToolChoice::Auto,
tools: vec![],
top_logprobs: 0, // Use default 0
top_p: Some(0.95),
truncation: Truncation::Auto,
user: None,
request_id: "resp_test456".to_string(),
priority: 0,
frequency_penalty: 0.1,
presence_penalty: 0.2,
stop: None,
top_k: 10,
min_p: 0.05,
repetition_penalty: 1.1,
};
let params = request.to_sampling_params(1000, None);
// Check that parameters are converted correctly
assert!(params.contains_key("temperature"));
assert!(params.contains_key("top_p"));
assert!(params.contains_key("frequency_penalty"));
assert!(params.contains_key("max_new_tokens"));
}
#[test]
fn test_responses_response_creation() {
let response = ResponsesResponse::new(
"resp_test789".to_string(),
"test-model".to_string(),
ResponseStatus::Completed,
);
assert_eq!(response.id, "resp_test789");
assert_eq!(response.model, "test-model");
assert!(response.is_complete());
assert!(!response.is_in_progress());
assert!(!response.is_failed());
}
#[test]
fn test_usage_conversion() {
let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3);
let response_usage = usage_info.to_response_usage();
assert_eq!(response_usage.input_tokens, 15);
assert_eq!(response_usage.output_tokens, 25);
assert_eq!(response_usage.total_tokens, 40);
// Check details are converted correctly
assert!(response_usage.input_tokens_details.is_some());
assert_eq!(
response_usage
.input_tokens_details
.as_ref()
.unwrap()
.cached_tokens,
3
);
assert!(response_usage.output_tokens_details.is_some());
assert_eq!(
response_usage
.output_tokens_details
.as_ref()
.unwrap()
.reasoning_tokens,
8
);
// Test reverse conversion
let back_to_usage = response_usage.to_usage_info();
assert_eq!(back_to_usage.prompt_tokens, 15);
assert_eq!(back_to_usage.completion_tokens, 25);
assert_eq!(back_to_usage.reasoning_tokens, Some(8));
}
#[test]
fn test_reasoning_param_default() {
let param = ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
};
// Test JSON serialization/deserialization preserves default
let json = serde_json::to_string(&param).unwrap();
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium)));
}
#[test]
fn test_json_serialization() {
let request = ResponsesRequest {
background: true,
include: None,
input: ResponseInput::Text("Test input".to_string()),
instructions: Some("Test instructions".to_string()),
max_output_tokens: Some(200),
max_tool_calls: Some(5),
metadata: None,
model: Some("gpt-4".to_string()),
parallel_tool_calls: false,
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::High),
}),
service_tier: ServiceTier::Priority,
store: false,
stream: true,
temperature: Some(0.9),
tool_choice: ToolChoice::Required,
tools: vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
}],
top_logprobs: 10,
top_p: Some(0.8),
truncation: Truncation::Auto,
user: Some("test_user".to_string()),
request_id: "resp_comprehensive_test".to_string(),
priority: 1,
frequency_penalty: 0.3,
presence_penalty: 0.4,
stop: None,
top_k: 50,
min_p: 0.1,
repetition_penalty: 1.2,
};
// Test that everything can be serialized to JSON and back
let json = serde_json::to_string(&request).expect("Serialization should work");
let parsed: ResponsesRequest =
serde_json::from_str(&json).expect("Deserialization should work");
assert_eq!(parsed.request_id, "resp_comprehensive_test");
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert!(parsed.background);
assert!(parsed.stream);
assert_eq!(parsed.tools.len(), 1);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment