// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use super::*; use minijinja::{context, value::Value}; use crate::protocols::openai::{ chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest, }; use tracing; use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput}; impl OAIChatLikeRequest for NvCreateChatCompletionRequest { fn model(&self) -> String { self.inner.model.clone() } fn messages(&self) -> Value { Value::from_serialize(&self.inner.messages) } fn tools(&self) -> Option { if self.inner.tools.is_none() { None } else { Some(Value::from_serialize(&self.inner.tools)) } } fn tool_choice(&self) -> Option { if self.inner.tool_choice.is_none() { None } else { Some(Value::from_serialize(&self.inner.tool_choice)) } } fn should_add_generation_prompt(&self) -> bool { if let Some(last) = self.inner.messages.last() { matches!( last, async_openai::types::ChatCompletionRequestMessage::User(_) ) } else { true } } fn extract_text(&self) -> Option { Some(TextInput::Single(String::new())) } } impl OAIChatLikeRequest for NvCreateCompletionRequest { fn model(&self) -> String { self.inner.model.clone() } fn messages(&self) -> minijinja::value::Value { let message = async_openai::types::ChatCompletionRequestMessage::User( async_openai::types::ChatCompletionRequestUserMessage { content: async_openai::types::ChatCompletionRequestUserMessageContent::Text( crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt), ), name: None, }, ); minijinja::value::Value::from_serialize(vec![message]) } fn should_add_generation_prompt(&self) -> bool { true } fn prompt_input_type(&self) -> PromptInput { match &self.inner.prompt { async_openai::types::Prompt::IntegerArray(_) => { PromptInput::Tokens(TokenInput::Single(vec![])) } async_openai::types::Prompt::ArrayOfIntegerArray(_) => { PromptInput::Tokens(TokenInput::Batch(vec![])) } async_openai::types::Prompt::String(_) => { PromptInput::Text(TextInput::Single(String::new())) } async_openai::types::Prompt::StringArray(_) => { PromptInput::Text(TextInput::Batch(vec![])) } } } fn extract_tokens(&self) -> Option { match &self.inner.prompt { async_openai::types::Prompt::IntegerArray(tokens) => { Some(TokenInput::Single(tokens.clone())) } async_openai::types::Prompt::ArrayOfIntegerArray(arrays) => { Some(TokenInput::Batch(arrays.clone())) } _ => None, } } fn extract_text(&self) -> Option { match &self.inner.prompt { async_openai::types::Prompt::String(text) => Some(TextInput::Single(text.to_string())), async_openai::types::Prompt::StringArray(texts) => { Some(TextInput::Batch(texts.to_vec())) } _ => None, } } } impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { fn supports_add_generation_prompt(&self) -> bool { self.supports_add_generation_prompt } fn render(&self, req: &dyn OAIChatLikeRequest) -> Result { let mixins = Value::from_dyn_object(self.mixins.clone()); let tools = req.tools(); let has_tools = tools.is_some(); let add_generation_prompt = req.should_add_generation_prompt(); tracing::trace!( "Rendering prompt with tools: {:?}, add_generation_prompt: {}", has_tools, add_generation_prompt ); let ctx = context! { messages => req.messages(), tools => tools, bos_token => self.config.bos_tok(), eos_token => self.config.eos_tok(), unk_token => self.config.unk_tok(), add_generation_prompt => add_generation_prompt, ..mixins }; let ctx = context! { ..ctx, ..context! { }}; let tmpl = if has_tools { self.env.get_template("tool_use")? } else { self.env.get_template("default")? }; Ok(tmpl.render(&ctx)?) } }