Unverified Commit dfbd741d authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

feat: Support for Responses API (#1694)

parent edf00c5c
...@@ -31,6 +31,7 @@ dynamo-engine-llamacpp = { path = "../../lib/engines/llamacpp", optional = true ...@@ -31,6 +31,7 @@ dynamo-engine-llamacpp = { path = "../../lib/engines/llamacpp", optional = true
dynamo-engine-mistralrs = { path = "../../lib/engines/mistralrs", optional = true } dynamo-engine-mistralrs = { path = "../../lib/engines/mistralrs", optional = true }
anyhow = { workspace = true } anyhow = { workspace = true }
async-openai = { workspace = true }
async-stream = { workspace = true } async-stream = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
...@@ -44,7 +45,6 @@ tracing = { workspace = true } ...@@ -44,7 +45,6 @@ tracing = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
async-openai = { workspace = true }
clap = { version = "4.5", features = ["derive", "env"] } clap = { version = "4.5", features = ["derive", "env"] }
futures-util = { version = "0.3" } futures-util = { version = "0.3" }
regex = "1" regex = "1"
......
...@@ -42,9 +42,9 @@ dynamo-llm = { path = "../../llm" } ...@@ -42,9 +42,9 @@ dynamo-llm = { path = "../../llm" }
dynamo-runtime = { path = "../../runtime" } dynamo-runtime = { path = "../../runtime" }
anyhow = { version = "1" } anyhow = { version = "1" }
async-openai = { version = "0.29.0" }
async-stream = { version = "0.3" } async-stream = { version = "0.3" }
async-trait = { version = "0.1" } async-trait = { version = "0.1" }
async-openai = { version = "0.29.0" }
futures = { version = "0.3" } futures = { version = "0.3" }
once_cell = { version = "1.20.3" } once_cell = { version = "1.20.3" }
serde = { version = "1" } serde = { version = "1" }
......
...@@ -42,6 +42,7 @@ dynamo-runtime = { workspace = true } ...@@ -42,6 +42,7 @@ dynamo-runtime = { workspace = true }
# workspace # workspace
anyhow = { workspace = true } anyhow = { workspace = true }
async-openai = { workspace = true }
async-stream = { workspace = true } async-stream = { workspace = true }
async-trait = { workspace = true } async-trait = { workspace = true }
async-nats = { workspace = true } async-nats = { workspace = true }
...@@ -72,7 +73,6 @@ uuid = { workspace = true } ...@@ -72,7 +73,6 @@ uuid = { workspace = true }
xxhash-rust = { workspace = true } xxhash-rust = { workspace = true }
akin = "0.4.0" akin = "0.4.0"
async-openai = { workspace = true }
blake3 = "1" blake3 = "1"
bytemuck = "1.22" bytemuck = "1.22"
candle-core = { version = "0.8.0" } candle-core = { version = "0.8.0" }
......
...@@ -58,6 +58,9 @@ pub enum Endpoint { ...@@ -58,6 +58,9 @@ pub enum Endpoint {
/// OAI Embeddings /// OAI Embeddings
Embeddings, Embeddings,
/// OAI Responses
Responses,
} }
/// Metrics for the HTTP service /// Metrics for the HTTP service
...@@ -354,6 +357,7 @@ impl std::fmt::Display for Endpoint { ...@@ -354,6 +357,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::Completions => write!(f, "completions"), Endpoint::Completions => write!(f, "completions"),
Endpoint::ChatCompletions => write!(f, "chat_completions"), Endpoint::ChatCompletions => write!(f, "chat_completions"),
Endpoint::Embeddings => write!(f, "embeddings"), Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Responses => write!(f, "responses"),
} }
} }
} }
...@@ -364,6 +368,7 @@ impl Endpoint { ...@@ -364,6 +368,7 @@ impl Endpoint {
Endpoint::Completions => "completions", Endpoint::Completions => "completions",
Endpoint::ChatCompletions => "chat_completions", Endpoint::ChatCompletions => "chat_completions",
Endpoint::Embeddings => "embeddings", Endpoint::Embeddings => "embeddings",
Endpoint::Responses => "responses",
} }
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::{
collections::HashSet,
pin::Pin,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use axum::{ use axum::{
extract::State, extract::State,
http::StatusCode, http::StatusCode,
...@@ -11,14 +18,9 @@ use axum::{ ...@@ -11,14 +18,9 @@ use axum::{
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
}; };
use dynamo_runtime::pipeline::{AsyncEngineContext, Context};
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
pin::Pin,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use super::{ use super::{
...@@ -26,21 +28,15 @@ use super::{ ...@@ -26,21 +28,15 @@ use super::{
metrics::{Endpoint, InflightGuard, ResponseMetricCollector}, metrics::{Endpoint, InflightGuard, ResponseMetricCollector},
service_v2, RouteDoc, service_v2, RouteDoc,
}; };
use crate::preprocessor::LLMMetricAnnotation; use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse};
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse},
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::{ use crate::types::Annotated;
openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
},
Annotated,
};
use dynamo_runtime::pipeline::{AsyncEngineContext, Context};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub(crate) struct ErrorResponse { pub(crate) struct ErrorResponse {
...@@ -83,6 +79,19 @@ impl ErrorResponse { ...@@ -83,6 +79,19 @@ impl ErrorResponse {
) )
} }
/// Not Implemented Error
/// Builds the reply for a request that asks for a feature which is planned but not yet
/// available: the message is logged at error level and echoed back in the body of a
/// `501 Not Implemented` response.
pub fn not_implemented_error(msg: &str) -> (StatusCode, Json<ErrorResponse>) {
    tracing::error!("Not Implemented error: {msg}");
    let body = ErrorResponse {
        error: String::from(msg),
    };
    (StatusCode::NOT_IMPLEMENTED, Json(body))
}
/// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return
/// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`]. /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`].
/// If successful, it will return the [`HttpError`] as an [`ErrorResponse::internal_server_error`] /// If successful, it will return the [`HttpError`] as an [`ErrorResponse::internal_server_error`]
...@@ -384,6 +393,218 @@ async fn chat_completions( ...@@ -384,6 +393,218 @@ async fn chat_completions(
} }
} }
/// OpenAI Responses Request Handler
///
/// Handles an incoming request for the `/v1/responses` endpoint. The request is validated,
/// translated into an [`NvCreateChatCompletionRequest`], executed on the chat completions
/// engine registered for the requested model, folded from a stream into a single
/// [`NvCreateChatCompletionResponse`], and finally converted into an [`NvResponse`].
///
/// Only unary requests with plain-text input are served. Requests that use an unsupported
/// field or a non-text input are answered with a `501 Not Implemented` response, which is
/// returned through the `Ok` arm because the validators yield ready-made responses.
///
/// # Errors
/// * whatever `check_ready` reports when the service is not ready,
/// * `501 Not Implemented` when the request cannot be converted to a chat completions request,
/// * the model-not-found error when no chat completions engine exists for the model,
/// * the error produced by `ErrorResponse::from_anyhow` when the engine fails to generate,
/// * an internal server error when folding the stream or converting the response fails.
#[tracing::instrument(skip_all)]
async fn responses(
    State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
    Json(mut request): Json<NvCreateResponse>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields,
    // then a field was used that is unsupported. The validator logs the error message and we
    // early return its 501 NOT_IMPLEMENTED response. Otherwise, proceed.
    if let Some(resp) = validate_unsupported_fields(&request) {
        return Ok(resp.into_response());
    }

    // Handle non-text (image, audio, file) inputs - if Some(resp) is returned by
    // validate_input_is_text_only, then we are handling something other than Input::Text(_).
    // The validator logs the error message and we early return its 501 NOT_IMPLEMENTED
    // response. Otherwise, proceed.
    if let Some(resp) = validate_input_is_text_only(&request) {
        return Ok(resp.into_response());
    }

    // Apply template values if present
    if let Some(template) = template {
        if request.inner.model.is_empty() {
            request.inner.model = template.model.clone();
        }
        // NOTE(review): an explicit temperature of 0.0 is treated like an unset one and is
        // replaced by the template value - confirm this is intended.
        if request.inner.temperature.unwrap_or(0.0) == 0.0 {
            request.inner.temperature = Some(template.temperature);
        }
        if request.inner.max_output_tokens.unwrap_or(0) == 0 {
            request.inner.max_output_tokens = Some(template.max_completion_tokens);
        }
    }
    tracing::trace!("Received responses request: {:?}", request.inner);

    let request_id = uuid::Uuid::new_v4().to_string();

    // Convert NvCreateResponse --> NvCreateChatCompletionRequest
    let request: NvCreateChatCompletionRequest = request.try_into().map_err(|e| {
        tracing::error!(
            request_id,
            "Failed to convert NvCreateResponse to NvCreateChatCompletionRequest: {:?}",
            e
        );
        ErrorResponse::not_implemented_error(&format!(
            "Only Input::Text(_) is currently supported: {}",
            e
        ))
    })?;

    let model = &request.inner.model;

    // The Responses endpoint is served by the chat completions engine of the same model.
    tracing::trace!("Getting chat completions engine for model: {}", model);
    let engine = state
        .manager()
        .get_chat_completions_engine(model)
        .map_err(|_| ErrorResponse::model_not_found())?;

    // The guard is only marked OK at the very end; every early return below leaves it unmarked.
    let mut inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Responses, false);
    let _response_collector = state.metrics_clone().create_response_collector(model);

    let request = Context::with_id(request, request_id.clone());
    tracing::trace!("Issuing generate call for chat completions");

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
        .map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate completions"))?;

    // TODO: handle streaming, currently just unary
    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into())
        .await
        .map_err(|e| {
            tracing::error!(
                request_id,
                "Failed to fold chat completions stream for: {:?}",
                e
            );
            ErrorResponse::internal_server_error(&format!(
                "Failed to fold chat completions stream: {}",
                e
            ))
        })?;

    // Convert NvCreateChatCompletionResponse --> NvResponse
    let response: NvResponse = response.try_into().map_err(|e| {
        tracing::error!(
            request_id,
            "Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}",
            e
        );
        ErrorResponse::internal_server_error("Failed to convert internal response")
    })?;

    inflight_guard.mark_ok();
    Ok(Json(response).into_response())
}
/// Checks that the request input is plain text.
///
/// Returns `None` when the input is `Input::Text`. For every other input variant it returns
/// `Some` holding a `501 Not Implemented` error response.
pub fn validate_input_is_text_only(request: &NvCreateResponse) -> Option<impl IntoResponse> {
    // Guard clause: plain text is the only accepted input shape.
    if matches!(
        &request.inner.input,
        async_openai::types::responses::Input::Text(_)
    ) {
        return None;
    }
    Some(ErrorResponse::not_implemented_error("Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented."))
}
/// Rejects requests that use a field the endpoint does not support.
///
/// The checks are evaluated in a fixed order and only the first offending field is reported.
/// Returns `Some(response)` carrying a `501 Not Implemented` error for that field, or `None`
/// when the request uses none of them.
pub fn validate_unsupported_fields(request: &NvCreateResponse) -> Option<impl IntoResponse> {
    let inner = &request.inner;

    // (field is in use, message reported to the client) - order determines precedence.
    let checks = [
        (
            inner.background == Some(true),
            "`background: true` is not supported.",
        ),
        (inner.include.is_some(), "`include` is not supported."),
        (
            inner.instructions.is_some(),
            "`instructions` is not supported.",
        ),
        (
            inner.max_tool_calls.is_some(),
            "`max_tool_calls` is not supported.",
        ),
        (inner.metadata.is_some(), "`metadata` is not supported."),
        (
            inner.parallel_tool_calls == Some(true),
            "`parallel_tool_calls: true` is not supported.",
        ),
        (
            inner.previous_response_id.is_some(),
            "`previous_response_id` is not supported.",
        ),
        (inner.prompt.is_some(), "`prompt` is not supported."),
        (inner.reasoning.is_some(), "`reasoning` is not supported."),
        (
            inner.service_tier.is_some(),
            "`service_tier` is not supported.",
        ),
        (inner.store == Some(true), "`store: true` is not supported."),
        (
            inner.stream == Some(true),
            "`stream: true` is not supported.",
        ),
        (inner.text.is_some(), "`text` is not supported."),
        (
            inner.tool_choice.is_some(),
            "`tool_choice` is not supported.",
        ),
        (inner.tools.is_some(), "`tools` is not supported."),
        (
            inner.truncation.is_some(),
            "`truncation` is not supported.",
        ),
        (inner.user.is_some(), "`user` is not supported."),
    ];

    // Build (and therefore log) the error only for the first field that is in use.
    checks
        .iter()
        .find(|(rejected, _)| *rejected)
        .map(|(_, reason)| ErrorResponse::not_implemented_error(reason))
}
// todo - abstract this to the top level lib.rs to be reused // todo - abstract this to the top level lib.rs to be reused
// todo - move the service_observer to its own state/arc // todo - move the service_observer to its own state/arc
fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
...@@ -598,11 +819,34 @@ pub fn list_models_router( ...@@ -598,11 +819,34 @@ pub fn list_models_router(
(vec![doc_for_openai], router) (vec![doc_for_openai], router)
} }
/// Create an Axum [`Router`] for the OpenAI API Responses endpoint.
/// If no path is provided, the default path is `/v1/responses`.
///
/// Returns the route documentation (a single `POST` entry for the path) together with the
/// router, whose state is the `(state, template)` pair handed to the handler.
pub fn responses_router(
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let route_path = path.unwrap_or_else(|| String::from("/v1/responses"));
    let docs = vec![RouteDoc::new(axum::http::Method::POST, &route_path)];
    let router = Router::new()
        .route(&route_path, post(responses))
        .with_state((state, template));
    (docs, router)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::discovery::ModelManagerError; use std::collections::HashMap;
use async_openai::types::responses::{
CreateResponse, Input, InputContent, InputItem, InputMessage, PromptConfig,
Role as ResponseRole, ServiceTier, TextConfig, TextResponseFormat, ToolChoice,
ToolChoiceMode, Truncation,
};
use super::*; use super::*;
use crate::discovery::ModelManagerError;
use crate::protocols::openai::responses::NvCreateResponse;
const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions"; const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions";
...@@ -617,6 +861,37 @@ mod tests { ...@@ -617,6 +861,37 @@ mod tests {
Err(ModelManagerError::ModelNotFound("foo".to_string()))? Err(ModelManagerError::ModelNotFound("foo".to_string()))?
} }
/// Test fixture: a minimal text-input request for model "test-model" in which every optional
/// field is `None`. Tests mutate individual fields on top of this baseline.
fn make_base_request() -> NvCreateResponse {
NvCreateResponse {
inner: CreateResponse {
// The only two fields that carry a value.
input: Input::Text("hello".into()),
model: "test-model".into(),
// Every optional field is spelled out (no `..Default::default()`).
background: None,
include: None,
instructions: None,
max_output_tokens: None,
max_tool_calls: None,
metadata: None,
parallel_tool_calls: None,
previous_response_id: None,
prompt: None,
reasoning: None,
service_tier: None,
store: None,
stream: None,
text: None,
tool_choice: None,
tools: None,
truncation: None,
user: None,
temperature: None,
top_logprobs: None,
top_p: None,
},
nvext: None,
}
}
#[test] #[test]
fn test_http_error_response_from_anyhow() { fn test_http_error_response_from_anyhow() {
let err = http_error_from_engine(400).unwrap_err(); let err = http_error_from_engine(400).unwrap_err();
...@@ -657,4 +932,101 @@ mod tests { ...@@ -657,4 +932,101 @@ mod tests {
) )
); );
} }
/// The baseline fixture uses `Input::Text`, so the validator must not produce a response.
#[test]
fn test_validate_input_is_text_only_accepts_text() {
    let text_request = make_base_request();
    assert!(validate_input_is_text_only(&text_request).is_none());
}
/// Structured (`Input::Items`) input must be rejected, even when the item only holds text.
#[test]
fn test_validate_input_is_text_only_rejects_items() {
    let structured = InputItem::Message(InputMessage {
        kind: Default::default(),
        role: ResponseRole::User,
        content: InputContent::TextInput("structured".into()),
    });
    let mut request = make_base_request();
    request.inner.input = Input::Items(vec![structured]);
    assert!(validate_input_is_text_only(&request).is_some());
}
/// A request with no optional field set must pass the unsupported-field validation.
#[test]
fn test_validate_unsupported_fields_accepts_clean_request() {
    let clean_request = make_base_request();
    assert!(validate_unsupported_fields(&clean_request).is_none());
}
/// Table-driven test: each case sets exactly one unsupported field on a fresh baseline
/// request and expects `validate_unsupported_fields` to reject it.
#[test]
fn test_validate_unsupported_fields_detects_flags() {
// Each entry pairs a field name (used in the failure message) with a one-shot mutator.
#[allow(clippy::type_complexity)]
let unsupported_cases: Vec<(&str, Box<dyn FnOnce(&mut CreateResponse)>)> = vec![
("background", Box::new(|r| r.background = Some(true))),
(
"include",
Box::new(|r| r.include = Some(vec!["file_search_call.results".into()])),
),
(
"instructions",
Box::new(|r| r.instructions = Some("System prompt".into())),
),
("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))),
("metadata", Box::new(|r| r.metadata = Some(HashMap::new()))),
(
"parallel_tool_calls",
Box::new(|r| r.parallel_tool_calls = Some(true)),
),
(
"previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())),
),
(
"prompt",
Box::new(|r| {
r.prompt = Some(PromptConfig {
id: "template-id".into(),
version: None,
variables: None,
})
}),
),
(
"reasoning",
Box::new(|r| r.reasoning = Some(Default::default())),
),
(
"service_tier",
Box::new(|r| r.service_tier = Some(ServiceTier::Auto)),
),
("store", Box::new(|r| r.store = Some(true))),
("stream", Box::new(|r| r.stream = Some(true))),
(
"text",
Box::new(|r| {
r.text = Some(TextConfig {
format: TextResponseFormat::Text,
})
}),
),
(
"tool_choice",
Box::new(|r| r.tool_choice = Some(ToolChoice::Mode(ToolChoiceMode::Required))),
),
// An empty tool list still counts as "set" because the validator checks `is_some()`.
("tools", Box::new(|r| r.tools = Some(vec![]))),
(
"truncation",
Box::new(|r| r.truncation = Some(Truncation::Auto)),
),
("user", Box::new(|r| r.user = Some("user-id".into()))),
];
// Start from a clean request for every case so that only one field is set at a time.
for (field, set_field) in unsupported_cases {
let mut req = make_base_request();
(set_field)(&mut req.inner);
let result = validate_unsupported_fields(&req);
assert!(result.is_some(), "Expected rejection for `{field}`");
}
}
} }
...@@ -78,6 +78,9 @@ pub struct HttpServiceConfig { ...@@ -78,6 +78,9 @@ pub struct HttpServiceConfig {
#[builder(default = "true")] #[builder(default = "true")]
enable_embeddings_endpoints: bool, enable_embeddings_endpoints: bool,
#[builder(default = "true")]
enable_responses_endpoints: bool,
#[builder(default = "None")] #[builder(default = "None")]
request_template: Option<RequestTemplate>, request_template: Option<RequestTemplate>,
} }
...@@ -153,7 +156,7 @@ impl HttpServiceConfigBuilder { ...@@ -153,7 +156,7 @@ impl HttpServiceConfigBuilder {
if config.enable_chat_endpoints { if config.enable_chat_endpoints {
routes.push(super::openai::chat_completions_router( routes.push(super::openai::chat_completions_router(
state.clone(), state.clone(),
config.request_template, config.request_template.clone(), // TODO clone()? reference?
None, None,
)); ));
} }
...@@ -166,6 +169,14 @@ impl HttpServiceConfigBuilder { ...@@ -166,6 +169,14 @@ impl HttpServiceConfigBuilder {
routes.push(super::openai::embeddings_router(state.clone(), None)); routes.push(super::openai::embeddings_router(state.clone(), None));
} }
if config.enable_responses_endpoints {
routes.push(super::openai::responses_router(
state.clone(),
config.request_template,
None,
));
}
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) { // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
// router = router.merge(route); // router = router.merge(route);
// all_docs.extend(route_docs); // all_docs.extend(route_docs);
......
...@@ -28,6 +28,7 @@ pub mod completions; ...@@ -28,6 +28,7 @@ pub mod completions;
pub mod embeddings; pub mod embeddings;
pub mod models; pub mod models;
pub mod nvext; pub mod nvext;
pub mod responses;
/// Minimum allowed value for OpenAI's `temperature` sampling option /// Minimum allowed value for OpenAI's `temperature` sampling option
pub const MIN_TEMPERATURE: f32 = 0.0; pub const MIN_TEMPERATURE: f32 = 0.0;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_openai::types::responses::{
Content, Input, OutputContent, OutputMessage, OutputStatus, OutputText, Response,
Role as ResponseRole, Status,
};
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use validator::Validate;
use super::chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse};
use super::nvext::{NvExt, NvExtProvider};
use super::{OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider};
/// Request body of the Responses endpoint: the `async-openai` `CreateResponse` payload plus
/// an optional `nvext` extension object.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateResponse {
// Flattened, so the wrapped request's fields sit at the top level of the JSON object.
#[serde(flatten)]
pub inner: async_openai::types::responses::CreateResponse,
// Omitted from serialized output when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// Response body of the Responses endpoint: a thin wrapper around the `async-openai`
/// `Response` type.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvResponse {
// Flattened, so the wrapper adds no extra nesting level to the JSON object.
#[serde(flatten)]
pub inner: async_openai::types::responses::Response,
}
/// Implements `NvExtProvider` for `NvCreateResponse`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateResponse {
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
/// Always returns `None`: raw prompt extraction is not implemented for this request type.
fn raw_prompt(&self) -> Option<String> {
None
}
}
/// Implements `AnnotationsProvider` for `NvCreateResponse`.
/// Annotations live in the optional `nvext` extension of the request.
impl AnnotationsProvider for NvCreateResponse {
    /// Returns a copy of the annotation list stored in `NvExt`.
    ///
    /// Yields `None` when the request has no `nvext` or the extension has no annotations.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext.as_ref()?.annotations.clone()
    }

    /// Reports whether `annotation` is present in the request's annotation list.
    ///
    /// # Arguments
    /// * `annotation` - The annotation to look for (exact string match).
    ///
    /// # Returns
    /// `true` if the annotation exists, `false` otherwise (including when there is no
    /// `nvext` or no annotation list).
    fn has_annotation(&self, annotation: &str) -> bool {
        match self.nvext.as_ref().and_then(|ext| ext.annotations.as_ref()) {
            Some(list) => list.iter().any(|entry| entry == annotation),
            None => false,
        }
    }
}
/// Implements `OpenAISamplingOptionsProvider` for `NvCreateResponse`,
/// exposing the sampling parameters carried by a Responses request.
impl OpenAISamplingOptionsProvider for NvCreateResponse {
/// Returns the request's `temperature`, if set.
fn get_temperature(&self) -> Option<f32> {
self.inner.temperature
}
/// Returns the request's `top_p` (nucleus sampling) value, if set.
fn get_top_p(&self) -> Option<f32> {
self.inner.top_p
}
/// Always returns `None`: no frequency penalty is read from the request here.
fn get_frequency_penalty(&self) -> Option<f32> {
None // TODO setting as None for now
}
/// Always returns `None`: no presence penalty is read from the request here.
fn get_presence_penalty(&self) -> Option<f32> {
None // TODO setting as None for now
}
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`,
/// exposing the stop conditions carried by a Responses request.
impl OpenAIStopConditionsProvider for NvCreateResponse {
/// Returns the request's `max_output_tokens` as the maximum number of tokens.
// NOTE(review): nothing visibly deprecated is referenced in this method - confirm whether
// the `allow(deprecated)` below is still needed.
#[allow(deprecated)]
fn get_max_tokens(&self) -> Option<u32> {
self.inner.max_output_tokens
}
/// Returns the minimum number of tokens required in the response.
///
/// # Note
/// This is a placeholder that always returns `None`.
fn get_min_tokens(&self) -> Option<u32> {
None
}
/// Returns the stop sequences for the request.
///
/// # Returns
/// Always `None`: no stop sequences are read from the request here.
fn get_stop(&self) -> Option<Vec<String>> {
None // TODO returning None for now
}
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
type Error = anyhow::Error;
fn try_from(resp: NvCreateResponse) -> Result<Self, Self::Error> {
// Create messages from input
let input_text = match resp.inner.input {
Input::Text(text) => text,
Input::Items(_) => {
return Err(anyhow::anyhow!(
"Input::Items not supported in conversion to NvCreateChatCompletionRequest"
));
}
};
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(input_text),
name: None,
},
)];
// TODO: See this PR for details: https://github.com/64bit/async-openai/pull/398
let top_logprobs = convert_top_logprobs(resp.inner.top_logprobs);
// The below should encompass all of the allowed configurable parameters
Ok(NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
messages,
model: resp.inner.model,
temperature: resp.inner.temperature,
top_p: resp.inner.top_p,
max_completion_tokens: resp.inner.max_output_tokens,
top_logprobs,
stream: Some(true), // Set this to Some(True) by default to aggregate stream
..Default::default()
},
nvext: resp.nvext,
})
}
}
/// Narrows a `top_logprobs` value from `u32` to `u8`, capping it at 20.
///
/// `None` stays `None`; any value of 20 or more becomes `Some(20)`.
fn convert_top_logprobs(input: Option<u32>) -> Option<u8> {
    match input {
        None => None,
        Some(requested) if requested >= 20 => Some(20),
        // Below the cap, so the narrowing cast cannot truncate.
        Some(requested) => Some(requested as u8),
    }
}
/// Converts an aggregated chat completion into the Responses API shape.
impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
    type Error = anyhow::Error;

    /// Builds a completed response holding one assistant message whose text is the content
    /// of the first choice. Fresh `msg_…` and `resp_…` identifiers are generated; `model` and
    /// the creation time are copied over, and all remaining optional fields are left `None`.
    ///
    /// As written, this conversion has no failing path.
    fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> {
        let completion = nv_resp.inner;

        // Falls back to an empty string (with a warning) when there is no first choice or
        // that choice carries no content.
        let first_content = completion
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.content);
        let content_text = match first_content {
            Some(text) => text,
            None => {
                tracing::warn!("No choices in chat completion response, using empty content");
                String::new()
            }
        };

        let message_id = format!("msg_{}", Uuid::new_v4().simple());
        let response_id = format!("resp_{}", Uuid::new_v4().simple());

        let text_part = Content::OutputText(OutputText {
            text: content_text,
            annotations: vec![],
        });
        let assistant_message = OutputMessage {
            id: message_id,
            role: ResponseRole::Assistant,
            status: OutputStatus::Completed,
            content: vec![text_part],
        };

        let inner = Response {
            id: response_id,
            object: "response".to_string(),
            created_at: completion.created as u64,
            model: completion.model,
            status: Status::Completed,
            output: vec![OutputContent::Message(assistant_message)],
            output_text: None,
            parallel_tool_calls: None,
            reasoning: None,
            service_tier: None,
            store: None,
            truncation: None,
            temperature: None,
            top_p: None,
            tools: None,
            metadata: None,
            previous_response_id: None,
            error: None,
            incomplete_details: None,
            instructions: None,
            max_output_tokens: None,
            text: None,
            tool_choice: None,
            usage: None,
            user: None,
        };

        Ok(NvResponse { inner })
    }
}
/// Unit tests for the provider trait implementations and the request/response conversions.
#[cfg(test)]
mod tests {
use async_openai::types::responses::{CreateResponse, Input};
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageContent,
};
use super::*;
use crate::types::openai::chat_completions::NvCreateChatCompletionResponse;
/// Fixture: a text-input request with sampling options, a token limit and two annotations
/// ("debug", "trace") set; all other fields take their defaults.
fn make_response_with_input(text: &str) -> NvCreateResponse {
NvCreateResponse {
inner: CreateResponse {
input: Input::Text(text.into()),
model: "test-model".into(),
max_output_tokens: Some(1024),
temperature: Some(0.5),
top_p: Some(0.9),
top_logprobs: Some(15),
..Default::default()
},
nvext: Some(NvExt {
annotations: Some(vec!["debug".into(), "trace".into()]),
..Default::default()
}),
}
}
/// Annotations are read from `nvext`; lookups of missing annotations return `false`.
#[test]
fn test_annotations_trait_behavior() {
let req = make_response_with_input("hello");
assert_eq!(
req.annotations(),
Some(vec!["debug".to_string(), "trace".to_string()])
);
assert!(req.has_annotation("debug"));
assert!(req.has_annotation("trace"));
assert!(!req.has_annotation("missing"));
}
/// Temperature and top-p come from the request; both penalties are always `None`.
#[test]
fn test_openai_sampling_trait_behavior() {
let req = make_response_with_input("hello");
assert_eq!(req.get_temperature(), Some(0.5));
assert_eq!(req.get_top_p(), Some(0.9));
assert_eq!(req.get_frequency_penalty(), None);
assert_eq!(req.get_presence_penalty(), None);
}
/// Max tokens mirrors `max_output_tokens`; min tokens and stop are always `None`.
#[test]
fn test_openai_stop_conditions_trait_behavior() {
let req = make_response_with_input("hello");
assert_eq!(req.get_max_tokens(), Some(1024));
assert_eq!(req.get_min_tokens(), None);
assert_eq!(req.get_stop(), None);
}
/// The conversion carries over the configurable fields, forces `stream` on, and wraps the
/// text input in exactly one user message.
#[test]
fn test_into_nvcreate_chat_completion_request() {
let nv_req: NvCreateChatCompletionRequest =
make_response_with_input("hi there").try_into().unwrap();
assert_eq!(nv_req.inner.model, "test-model");
assert_eq!(nv_req.inner.temperature, Some(0.5));
assert_eq!(nv_req.inner.top_p, Some(0.9));
assert_eq!(nv_req.inner.max_completion_tokens, Some(1024));
assert_eq!(nv_req.inner.top_logprobs, Some(15));
assert_eq!(nv_req.inner.stream, Some(true));
let messages = &nv_req.inner.messages;
assert_eq!(messages.len(), 1);
match &messages[0] {
ChatCompletionRequestMessage::User(user_msg) => match &user_msg.content {
ChatCompletionRequestUserMessageContent::Text(t) => {
assert_eq!(t, "hi there");
}
_ => panic!("unexpected user content type"),
},
_ => panic!("expected user message"),
}
}
/// A single-choice chat completion becomes a completed response whose one assistant
/// message holds the choice's text.
// The struct literal below sets `function_call`, hence the `allow(deprecated)`
// (presumably a deprecated field in async-openai - confirm).
#[allow(deprecated)]
#[test]
fn test_into_nvresponse_from_chat_response() {
let now = 1_726_000_000;
let chat_resp = NvCreateChatCompletionResponse {
inner: async_openai::types::CreateChatCompletionResponse {
id: "chatcmpl-xyz".into(),
choices: vec![async_openai::types::ChatChoice {
index: 0,
message: async_openai::types::ChatCompletionResponseMessage {
content: Some("This is a reply".into()),
refusal: None,
tool_calls: None,
role: async_openai::types::Role::Assistant,
function_call: None,
audio: None,
},
finish_reason: None,
logprobs: None,
}],
created: now,
model: "llama-3.1-8b-instruct".into(),
service_tier: None,
system_fingerprint: None,
object: "chat.completion".to_string(),
usage: None,
},
};
let wrapped: NvResponse = chat_resp.try_into().unwrap();
assert_eq!(wrapped.inner.model, "llama-3.1-8b-instruct");
assert_eq!(wrapped.inner.status, Status::Completed);
assert_eq!(wrapped.inner.object, "response");
// The response id is freshly generated, so only its prefix can be checked.
assert!(wrapped.inner.id.starts_with("resp_"));
let msg = match &wrapped.inner.output[0] {
OutputContent::Message(m) => m,
_ => panic!("Expected Message variant"),
};
assert_eq!(msg.role, ResponseRole::Assistant);
match &msg.content[0] {
Content::OutputText(txt) => {
assert_eq!(txt.text, "This is a reply");
}
_ => panic!("Expected OutputText content"),
}
}
/// Values above 20 are capped at 20; smaller values and `None` pass through unchanged.
#[test]
fn test_convert_top_logprobs_clamped() {
assert_eq!(convert_top_logprobs(Some(5)), Some(5));
assert_eq!(convert_top_logprobs(Some(21)), Some(20));
assert_eq!(convert_top_logprobs(Some(1000)), Some(20));
assert_eq!(convert_top_logprobs(None), None);
}
}
...@@ -143,6 +143,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu ...@@ -143,6 +143,7 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
Endpoint::Completions => 0, Endpoint::Completions => 0,
Endpoint::ChatCompletions => 1, Endpoint::ChatCompletions => 1,
Endpoint::Embeddings => todo!(), Endpoint::Embeddings => todo!(),
Endpoint::Responses => todo!(),
}; };
let request_type = match request_type { let request_type = match request_type {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment