Unverified Commit c443528f authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(preprocessor): Populate model ID in PreprocessedRequest (#2397)

parent b5efb957
......@@ -35,6 +35,7 @@ pub async fn run(
.or(flags.model_path_flag.clone()),
)
.model_name(flags.model_name.clone())
.model_config(flags.model_config.clone())
.kv_cache_block_size(flags.kv_cache_block_size)
// Only set if user provides. Usually loaded from tokenizer_config.json
.context_length(flags.context_length)
......
......@@ -175,6 +175,7 @@ mod tests {
// Helper to create a mock preprocessed request
fn create_mock_request(max_tokens: u32) -> PreprocessedRequest {
PreprocessedRequest {
model: "mock".to_string(),
token_ids: vec![1, 2, 3],
batch_token_ids: None,
stop_conditions: StopConditions {
......
......@@ -633,6 +633,7 @@ mod integration_tests {
// Create test requests for both DP workers
let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest {
model: "mock".to_string(),
token_ids: tokens,
batch_token_ids: None,
stop_conditions: StopConditions {
......
......@@ -153,6 +153,7 @@ impl OpenAIPreprocessor {
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new();
let mut builder = PreprocessedRequest::builder();
builder.model(request.model());
// match request type before any conversion/processing
match request.prompt_input_type() {
......
......@@ -58,6 +58,7 @@ pub enum PromptInput {
/// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest {
fn model(&self) -> String;
fn messages(&self) -> Value;
fn tools(&self) -> Option<Value> {
None
......
......@@ -25,6 +25,10 @@ use tracing;
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn model(&self) -> String {
self.inner.model.clone()
}
fn messages(&self) -> Value {
Value::from_serialize(&self.inner.messages)
}
......@@ -62,6 +66,9 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
}
impl OAIChatLikeRequest for NvCreateCompletionRequest {
fn model(&self) -> String {
self.inner.model.clone()
}
fn messages(&self) -> minijinja::value::Value {
let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
......@@ -23,6 +11,9 @@ use crate::protocols::TokenIdType;
/// crate is responsible for converting request from the public APIs to this internal representation.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct PreprocessedRequest {
/// ID of the model to use
pub model: String,
/// Type of prompt
pub token_ids: Vec<TokenIdType>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment