Unverified Commit c443528f authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(preprocessor): Populate model ID in PreprocessedRequest (#2397)

parent b5efb957
...@@ -35,6 +35,7 @@ pub async fn run( ...@@ -35,6 +35,7 @@ pub async fn run(
.or(flags.model_path_flag.clone()), .or(flags.model_path_flag.clone()),
) )
.model_name(flags.model_name.clone()) .model_name(flags.model_name.clone())
.model_config(flags.model_config.clone())
.kv_cache_block_size(flags.kv_cache_block_size) .kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json // Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length) .context_length(flags.context_length)
......
...@@ -175,6 +175,7 @@ mod tests { ...@@ -175,6 +175,7 @@ mod tests {
// Helper to create a mock preprocessed request // Helper to create a mock preprocessed request
fn create_mock_request(max_tokens: u32) -> PreprocessedRequest { fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
PreprocessedRequest { PreprocessedRequest {
model: "mock".to_string(),
token_ids: vec![1, 2, 3], token_ids: vec![1, 2, 3],
batch_token_ids: None, batch_token_ids: None,
stop_conditions: StopConditions { stop_conditions: StopConditions {
......
...@@ -633,6 +633,7 @@ mod integration_tests { ...@@ -633,6 +633,7 @@ mod integration_tests {
// Create test requests for both DP workers // Create test requests for both DP workers
let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest { let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest {
model: "mock".to_string(),
token_ids: tokens, token_ids: tokens,
batch_token_ids: None, batch_token_ids: None,
stop_conditions: StopConditions { stop_conditions: StopConditions {
......
...@@ -153,6 +153,7 @@ impl OpenAIPreprocessor { ...@@ -153,6 +153,7 @@ impl OpenAIPreprocessor {
) -> Result<(PreprocessedRequest, HashMap<String, String>)> { ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new(); let mut annotations = HashMap::new();
let mut builder = PreprocessedRequest::builder(); let mut builder = PreprocessedRequest::builder();
builder.model(request.model());
// match request type before any conversion/processing // match request type before any conversion/processing
match request.prompt_input_type() { match request.prompt_input_type() {
......
...@@ -58,6 +58,7 @@ pub enum PromptInput { ...@@ -58,6 +58,7 @@ pub enum PromptInput {
/// Trait that defines a request that can map to an OpenAI-like request. /// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest { pub trait OAIChatLikeRequest {
fn model(&self) -> String;
fn messages(&self) -> Value; fn messages(&self) -> Value;
fn tools(&self) -> Option<Value> { fn tools(&self) -> Option<Value> {
None None
......
...@@ -25,6 +25,10 @@ use tracing; ...@@ -25,6 +25,10 @@ use tracing;
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput}; use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};
impl OAIChatLikeRequest for NvCreateChatCompletionRequest { impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn model(&self) -> String {
self.inner.model.clone()
}
fn messages(&self) -> Value { fn messages(&self) -> Value {
Value::from_serialize(&self.inner.messages) Value::from_serialize(&self.inner.messages)
} }
...@@ -62,6 +66,9 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -62,6 +66,9 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
} }
impl OAIChatLikeRequest for NvCreateCompletionRequest { impl OAIChatLikeRequest for NvCreateCompletionRequest {
fn model(&self) -> String {
self.inner.model.clone()
}
fn messages(&self) -> minijinja::value::Value { fn messages(&self) -> minijinja::value::Value {
let message = async_openai::types::ChatCompletionRequestMessage::User( let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { async_openai::types::ChatCompletionRequestUserMessage {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -23,6 +11,9 @@ use crate::protocols::TokenIdType; ...@@ -23,6 +11,9 @@ use crate::protocols::TokenIdType;
/// crate is responsible for converting request from the public APIs to this internal representation. /// crate is responsible for converting request from the public APIs to this internal representation.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)] #[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct PreprocessedRequest { pub struct PreprocessedRequest {
/// ID of the model to use
pub model: String,
/// Type of prompt /// Type of prompt
pub token_ids: Vec<TokenIdType>, pub token_ids: Vec<TokenIdType>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment