Unverified commit cb5a657a, authored by Graham King, committed by GitHub
Browse files

fix: Load the tokenizer JSON once for chat and completions. (#2910)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 9ef13289
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*; use super::*;
use crate::llm::model_card::ModelDeploymentCard; use crate::llm::model_card::ModelDeploymentCard;
...@@ -33,10 +21,7 @@ pub(crate) struct Backend { ...@@ -33,10 +21,7 @@ pub(crate) struct Backend {
impl Backend { impl Backend {
#[new] #[new]
fn new(mdc: ModelDeploymentCard, endpoint: Endpoint) -> PyResult<Self> { fn new(mdc: ModelDeploymentCard, endpoint: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let backend = llm_rs::backend::Backend::from_mdc(&mdc.inner);
let backend = runtime
.block_on(llm_rs::backend::Backend::from_mdc(mdc.inner))
.map_err(to_pyerr)?;
Ok(Self { Ok(Self {
inner: backend, inner: backend,
endpoint, endpoint,
......
...@@ -16,14 +16,10 @@ impl ModelDeploymentCard {} ...@@ -16,14 +16,10 @@ impl ModelDeploymentCard {}
impl ModelDeploymentCard { impl ModelDeploymentCard {
// Previously called "from_local_path" // Previously called "from_local_path"
#[staticmethod] #[staticmethod]
fn load(path: String, model_name: String, py: Python<'_>) -> PyResult<Bound<'_, PyAny>> { fn load(path: String, model_name: String) -> PyResult<ModelDeploymentCard> {
pyo3_async_runtimes::tokio::future_into_py(py, async move { let mut card = RsModelDeploymentCard::load(&path, None).map_err(to_pyerr)?;
let mut card = RsModelDeploymentCard::load(&path, None) card.set_name(&model_name);
.await Ok(ModelDeploymentCard { inner: card })
.map_err(to_pyerr)?;
card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card })
})
} }
#[staticmethod] #[staticmethod]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*; use super::*;
use crate::llm::model_card::ModelDeploymentCard; use crate::llm::model_card::ModelDeploymentCard;
...@@ -42,10 +30,7 @@ pub(crate) struct OAIChatPreprocessor { ...@@ -42,10 +30,7 @@ pub(crate) struct OAIChatPreprocessor {
impl OAIChatPreprocessor { impl OAIChatPreprocessor {
#[new] #[new]
fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> { fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let preprocessor = OpenAIPreprocessor::new(mdc.inner.clone()).map_err(to_pyerr)?;
let preprocessor = runtime
.block_on(OpenAIPreprocessor::new(mdc.inner.clone()))
.map_err(to_pyerr)?;
Ok(Self { Ok(Self {
inner: preprocessor, inner: preprocessor,
current, current,
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use anyhow::{Error, Result}; use anyhow::Result;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use tracing as log; use tracing as log;
use crate::model_card::{ModelDeploymentCard, TokenizerKind}; use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::{ pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
...@@ -66,30 +66,27 @@ struct DecoderUnfoldState { ...@@ -66,30 +66,27 @@ struct DecoderUnfoldState {
} }
impl Backend { impl Backend {
pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> { pub fn from_tokenizer(tokenizer: HfTokenizer) -> Arc<Self> {
let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer); let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
let tokenizer = Tokenizer::from(Arc::new(tokenizer)); let tokenizer = Tokenizer::from(Arc::new(tokenizer));
Ok(Arc::new(Self { Arc::new(Self {
tokenizer: Some(tokenizer), tokenizer: Some(tokenizer),
validate_engine_decode: false, validate_engine_decode: false,
})) })
} }
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> { pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
let tokenizer = match &mdc.tokenizer { match mdc.tokenizer_hf() {
Some(TokenizerKind::HfTokenizerJson(file)) => { Ok(tokenizer) => Self::from_tokenizer(tokenizer),
HfTokenizer::from_file(file).map_err(Error::msg)? Err(err) => {
} tracing::warn!(%err, "tokenizer_hf error converting ModelDeploymentCard to HF tokenizer");
Some(TokenizerKind::GGUF(t)) => *t.clone(), Arc::new(Self {
None => {
return Ok(Arc::new(Self {
tokenizer: None, tokenizer: None,
validate_engine_decode: false, validate_engine_decode: false,
})); })
} }
}; }
Self::from_tokenizer(tokenizer).await
} }
fn decoder( fn decoder(
......
...@@ -176,13 +176,13 @@ impl ModelWatcher { ...@@ -176,13 +176,13 @@ impl ModelWatcher {
} }
WatchEvent::Delete(kv) => match self.handle_delete(&kv).await { WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
Ok(Some(model_name)) => { Ok(Some(model_name)) => {
tracing::info!("removed model {}", model_name); tracing::info!(model_name, "removed model");
} }
Ok(None) => { Ok(None) => {
// There are other instances running this model, nothing to do // There are other instances running this model, nothing to do
} }
Err(e) => { Err(e) => {
tracing::error!("error removing model: {}", e); tracing::error!(error = %e, "error removing model");
} }
}, },
} }
...@@ -271,7 +271,7 @@ impl ModelWatcher { ...@@ -271,7 +271,7 @@ impl ModelWatcher {
Some(card) Some(card)
} }
Err(err) => { Err(err) => {
tracing::info!(%err, "load_mdc did not complete"); tracing::info!(error = %err, "load_mdc did not complete");
None None
} }
}; };
...@@ -308,6 +308,9 @@ impl ModelWatcher { ...@@ -308,6 +308,9 @@ impl ModelWatcher {
None None
}; };
// This is expensive, we are loading ~10MiB JSON, so only do it once
let tokenizer_hf = card.tokenizer_hf()?;
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() { if model_entry.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
...@@ -319,18 +322,23 @@ impl ModelWatcher { ...@@ -319,18 +322,23 @@ impl ModelWatcher {
self.router_mode, self.router_mode,
self.busy_threshold, self.busy_threshold,
kv_chooser.clone(), kv_chooser.clone(),
tokenizer_hf.clone(),
) )
.await?; .await?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)?;
tracing::info!("Chat completions is ready");
} }
// Add completions engine only if the model supports completions // Add completions engine only if the model supports completions
if model_entry.model_type.supports_completions() { if model_entry.model_type.supports_completions() {
let formatter = PromptFormatter::no_op(); let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter; let PromptFormatter::OAI(formatter) = formatter;
let preprocessor = let preprocessor = OpenAIPreprocessor::new_with_parts(
OpenAIPreprocessor::new_with_formatter(card.clone(), formatter).await?; card.clone(),
formatter,
tokenizer_hf.clone(),
)?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::< let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
...@@ -341,10 +349,12 @@ impl ModelWatcher { ...@@ -341,10 +349,12 @@ impl ModelWatcher {
self.busy_threshold, self.busy_threshold,
kv_chooser, kv_chooser,
preprocessor, preprocessor,
tokenizer_hf,
) )
.await?; .await?;
self.manager self.manager
.add_completions_model(&model_entry.name, completions_engine)?; .add_completions_model(&model_entry.name, completions_engine)?;
tracing::info!("Completions is ready");
} }
} else if model_entry.model_input == ModelInput::Text } else if model_entry.model_input == ModelInput::Text
&& model_entry.model_type.supports_chat() && model_entry.model_type.supports_chat()
...@@ -391,8 +401,8 @@ impl ModelWatcher { ...@@ -391,8 +401,8 @@ impl ModelWatcher {
ManyOut<Annotated<NvCreateEmbeddingResponse>>, ManyOut<Annotated<NvCreateEmbeddingResponse>>,
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(&card).into_operator();
let router = PushRouter::< let router = PushRouter::<
PreprocessedEmbeddingRequest, PreprocessedEmbeddingRequest,
......
...@@ -67,7 +67,9 @@ pub async fn run( ...@@ -67,7 +67,9 @@ pub async fn run(
let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?; let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let pre_processor = if prepared_engine.has_tokenizer() { let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?) Some(OpenAIPreprocessor::new(
prepared_engine.card.take().unwrap(),
)?)
} else { } else {
None None
}; };
......
...@@ -11,7 +11,7 @@ use crate::{ ...@@ -11,7 +11,7 @@ use crate::{
kv_router::{KvPushRouter, KvRouter}, kv_router::{KvPushRouter, KvRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
types::{ types::{
...@@ -131,10 +131,18 @@ pub async fn prepare_engine( ...@@ -131,10 +131,18 @@ pub async fn prepare_engine(
None None
}; };
let hf_tokenizer = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone()) >(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
hf_tokenizer,
)
.await?; .await?;
let service_name = local_model.service_name().to_string(); let service_name = local_model.service_name().to_string();
...@@ -167,7 +175,7 @@ pub async fn prepare_engine( ...@@ -167,7 +175,7 @@ pub async fn prepare_engine(
let pipeline = build_pipeline::< let pipeline = build_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine) >(model.card(), inner_engine, model.card().tokenizer_hf()?)
.await?; .await?;
let service_name = model.service_name().to_string(); let service_name = model.service_name().to_string();
...@@ -186,6 +194,7 @@ pub async fn prepare_engine( ...@@ -186,6 +194,7 @@ pub async fn prepare_engine(
pub async fn build_pipeline<Req, Resp>( pub async fn build_pipeline<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
engine: ExecutionContext, engine: ExecutionContext,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>> ) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where where
Req: Data, Req: Data,
...@@ -198,10 +207,11 @@ where ...@@ -198,10 +207,11 @@ where
>, >,
{ {
let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new(); let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor = OpenAIPreprocessor::new((*card).clone()) let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
.await? let preprocessor =
.into_operator(); OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?
let backend = Backend::from_mdc((*card).clone()).await?.into_operator(); .into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let engine = ServiceBackend::from_engine(engine); let engine = ServiceBackend::from_engine(engine);
Ok(frontend Ok(frontend
...@@ -219,6 +229,7 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -219,6 +229,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
router_mode: RouterMode, router_mode: RouterMode,
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
Req: Data, Req: Data,
...@@ -230,7 +241,9 @@ where ...@@ -230,7 +241,9 @@ where
Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>, Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
>, >,
{ {
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?; let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?;
build_routed_pipeline_with_preprocessor( build_routed_pipeline_with_preprocessor(
card, card,
client, client,
...@@ -238,6 +251,7 @@ where ...@@ -238,6 +251,7 @@ where
busy_threshold, busy_threshold,
chooser, chooser,
preprocessor, preprocessor,
hf_tokenizer,
) )
.await .await
} }
...@@ -249,6 +263,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>( ...@@ -249,6 +263,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>, preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
Req: Data, Req: Data,
...@@ -262,8 +277,8 @@ where ...@@ -262,8 +277,8 @@ where
{ {
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new(); let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator(); let migration = Migration::from_mdc(card).into_operator();
let router = let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold( PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(), client.clone(),
...@@ -312,14 +327,14 @@ mod tests { ...@@ -312,14 +327,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH, None).await?; let card = ModelDeploymentCard::load(HF_PATH, None)?;
let engine = crate::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions // Build pipeline for chat completions
let pipeline = build_pipeline::< let pipeline = build_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(&card, engine) >(&card, engine, card.tokenizer_hf()?)
.await?; .await?;
// Verify pipeline was created // Verify pipeline was created
...@@ -331,13 +346,16 @@ mod tests { ...@@ -331,13 +346,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> { async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card // Create test model card
let card = ModelDeploymentCard::load(HF_PATH, None).await?; let card = ModelDeploymentCard::load(HF_PATH, None)?;
let engine = crate::engines::make_engine_core(); let engine = crate::engines::make_engine_core();
// Build pipeline for completions // Build pipeline for completions
let pipeline = let pipeline = build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(
build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(&card, engine) &card,
.await?; engine,
card.tokenizer_hf()?,
)
.await?;
// Verify pipeline was created // Verify pipeline was created
assert!(Arc::strong_count(&pipeline) >= 1); assert!(Arc::strong_count(&pipeline) >= 1);
......
...@@ -73,9 +73,7 @@ pub async fn run( ...@@ -73,9 +73,7 @@ pub async fn run(
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>, ManyOut<Annotated<BackendOutput>>,
>::new(); >::new();
let backend = Backend::from_mdc(model.card().clone()) let backend = Backend::from_mdc(model.card()).into_operator();
.await?
.into_operator();
let engine = ServiceBackend::from_engine(inner_engine); let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend let pipeline = frontend
.link(backend.forward_edge())? .link(backend.forward_edge())?
......
...@@ -90,18 +90,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -90,18 +90,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None None
}; };
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone()) >(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
)
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::< let completions_engine =
NvCreateCompletionRequest, entrypoint::build_routed_pipeline::<
NvCreateCompletionResponse, NvCreateCompletionRequest,
>(card, &client, router_mode, None, kv_chooser) NvCreateCompletionResponse,
.await?; >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(local_model.display_name(), completions_engine)?;
grpc_service grpc_service
...@@ -122,17 +131,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -122,17 +131,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
let chat_pipeline = common::build_pipeline::< let tokenizer_hf = model.card().tokenizer_hf()?;
NvCreateChatCompletionRequest, let chat_pipeline =
NvCreateChatCompletionStreamResponse, common::build_pipeline::<
>(model.card(), inner_engine.clone()) NvCreateChatCompletionRequest,
.await?; NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
grpc_service grpc_service
......
...@@ -119,18 +119,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -119,18 +119,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None None
}; };
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone()) >(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
)
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::< let completions_engine =
NvCreateCompletionRequest, entrypoint::build_routed_pipeline::<
NvCreateCompletionResponse, NvCreateCompletionRequest,
>(card, &client, router_mode, None, kv_chooser) NvCreateCompletionResponse,
.await?; >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
...@@ -160,17 +169,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -160,17 +169,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let manager = http_service.model_manager(); let manager = http_service.model_manager();
let chat_pipeline = common::build_pipeline::< let tokenizer_hf = model.card().tokenizer_hf()?;
NvCreateChatCompletionRequest, let chat_pipeline =
NvCreateChatCompletionStreamResponse, common::build_pipeline::<
>(model.card(), inner_engine.clone()) NvCreateChatCompletionRequest,
.await?; NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints // Enable all endpoints
......
...@@ -253,8 +253,7 @@ impl LocalModelBuilder { ...@@ -253,8 +253,7 @@ impl LocalModelBuilder {
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path); let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = let mut card =
ModelDeploymentCard::load(&model_config_path, self.custom_template_path.as_deref()) ModelDeploymentCard::load(model_config_path, self.custom_template_path.as_deref())?;
.await?;
// Usually we infer from the path, self.model_name is user override // Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| { let model_name = self.model_name.take().unwrap_or_else(|| {
......
...@@ -28,15 +28,15 @@ pub struct Migration { ...@@ -28,15 +28,15 @@ pub struct Migration {
} }
impl Migration { impl Migration {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> { pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
tracing::debug!( tracing::debug!(
"model {} migration limit {}", "model {} migration limit {}",
mdc.display_name, mdc.display_name,
mdc.migration_limit mdc.migration_limit
); );
Ok(Arc::new(Self { Arc::new(Self {
migration_limit: mdc.migration_limit, migration_limit: mdc.migration_limit,
})) })
} }
} }
......
...@@ -372,19 +372,19 @@ impl ModelDeploymentCard { ...@@ -372,19 +372,19 @@ impl ModelDeploymentCard {
/// - a folder containing config.json, tokenizer.json and token_config.json /// - a folder containing config.json, tokenizer.json and token_config.json
/// - a GGUF file /// - a GGUF file
/// With an optional custom template /// With an optional custom template
pub async fn load( pub fn load(
config_path: impl AsRef<Path>, config_path: impl AsRef<Path>,
custom_template_path: Option<&Path>, custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> { ) -> anyhow::Result<ModelDeploymentCard> {
let config_path = config_path.as_ref(); let config_path = config_path.as_ref();
if config_path.is_dir() { if config_path.is_dir() {
Self::from_local_path(config_path, custom_template_path).await Self::from_local_path(config_path, custom_template_path)
} else { } else {
// GGUF files don't support custom templates yet // GGUF files don't support custom templates yet
if custom_template_path.is_some() { if custom_template_path.is_some() {
anyhow::bail!("Custom templates are not supported for GGUF files"); anyhow::bail!("Custom templates are not supported for GGUF files");
} }
Self::from_gguf(config_path).await Self::from_gguf(config_path)
} }
} }
...@@ -403,7 +403,7 @@ impl ModelDeploymentCard { ...@@ -403,7 +403,7 @@ impl ModelDeploymentCard {
/// - The path doesn't exist or isn't a directory /// - The path doesn't exist or isn't a directory
/// - The path contains invalid Unicode characters /// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid /// - Required model files are missing or invalid
async fn from_local_path( fn from_local_path(
local_root_dir: impl AsRef<Path>, local_root_dir: impl AsRef<Path>,
custom_template_path: Option<&Path>, custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> { ) -> anyhow::Result<Self> {
...@@ -419,10 +419,10 @@ impl ModelDeploymentCard { ...@@ -419,10 +419,10 @@ impl ModelDeploymentCard {
.and_then(|n| n.to_str()) .and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?; .ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
Self::from_repo(&repo_id, model_name, custom_template_path).await Self::from_repo(&repo_id, model_name, custom_template_path)
} }
async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> { fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
let model_name = gguf_file let model_name = gguf_file
.iter() .iter()
.next_back() .next_back()
...@@ -461,14 +461,7 @@ impl ModelDeploymentCard { ...@@ -461,14 +461,7 @@ impl ModelDeploymentCard {
}) })
} }
#[allow(dead_code)] fn from_repo(
async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> {
Err(anyhow::anyhow!(
"ModelDeploymentCard::from_ngc_repo is not implemented"
))
}
async fn from_repo(
repo_id: &str, repo_id: &str,
model_name: &str, model_name: &str,
custom_template_path: Option<&Path>, custom_template_path: Option<&Path>,
...@@ -509,16 +502,16 @@ impl ModelDeploymentCard { ...@@ -509,16 +502,16 @@ impl ModelDeploymentCard {
template_path.display().to_string(), template_path.display().to_string(),
)) ))
} else { } else {
PromptFormatterArtifact::chat_template_from_repo(repo_id).await? PromptFormatterArtifact::chat_template_from_repo(repo_id)?
}; };
Ok(Self { Ok(Self {
display_name: model_name.to_string(), display_name: model_name.to_string(),
slug: Slug::from_string(model_name), slug: Slug::from_string(model_name),
model_info: Some(ModelInfoType::from_repo(repo_id).await?), model_info: Some(ModelInfoType::from_repo(repo_id)?),
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?), tokenizer: Some(TokenizerKind::from_repo(repo_id)?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional gen_config: GenerationConfig::from_repo(repo_id).ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?, prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
chat_template_file, chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
...@@ -568,9 +561,9 @@ pub trait ModelInfo: Send + Sync { ...@@ -568,9 +561,9 @@ pub trait ModelInfo: Send + Sync {
} }
impl ModelInfoType { impl ModelInfoType {
pub async fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> { pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
match self { match self {
Self::HfConfigJson(info) => HFConfig::from_json_file(info).await, Self::HfConfigJson(info) => HFConfig::from_json_file(info),
Self::GGUF(path) => HFConfig::from_gguf(path), Self::GGUF(path) => HFConfig::from_gguf(path),
} }
} }
...@@ -622,7 +615,7 @@ struct HFTextConfig { ...@@ -622,7 +615,7 @@ struct HFTextConfig {
} }
impl HFConfig { impl HFConfig {
async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> { fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let file_pathbuf = PathBuf::from(file); let file_pathbuf = PathBuf::from(file);
let contents = std::fs::read_to_string(file)?; let contents = std::fs::read_to_string(file)?;
let mut config: Self = serde_json::from_str(&contents)?; let mut config: Self = serde_json::from_str(&contents)?;
...@@ -800,79 +793,76 @@ fn capitalize(s: &str) -> String { ...@@ -800,79 +793,76 @@ fn capitalize(s: &str) -> String {
} }
impl ModelInfoType { impl ModelInfoType {
pub async fn from_repo(repo_id: &str) -> Result<Self> { pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id) Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract model info from repo {}", repo_id)) .with_context(|| format!("unable to extract model info from repo {}", repo_id))
} }
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfConfigJson( Ok(Self::HfConfigJson(check_for_file(repo, "config.json")?))
check_for_file(repo, "config.json").await?,
))
} }
} }
impl PromptFormatterArtifact { impl PromptFormatterArtifact {
pub async fn from_repo(repo_id: &str) -> Result<Option<Self>> { pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
// we should only error if we expect a prompt formatter and it's not found // we should only error if we expect a prompt formatter and it's not found
// right now, we don't know when to expect it, so we just return Ok(Some/None) // right now, we don't know when to expect it, so we just return Ok(Some/None)
Ok(Self::try_is_hf_repo(repo_id) Ok(Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) .with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
.ok()) .ok())
} }
pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> { pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
Ok(Self::chat_template_try_is_hf_repo(repo_id) Ok(Self::chat_template_try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id)) .with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
.ok()) .ok())
} }
async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate( Ok(Self::HfChatTemplate(check_for_file(
check_for_file(repo, "chat_template.jinja").await?, repo,
)) "chat_template.jinja",
)?))
} }
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson( Ok(Self::HfTokenizerConfigJson(check_for_file(
check_for_file(repo, "tokenizer_config.json").await?, repo,
)) "tokenizer_config.json",
)?))
} }
} }
impl TokenizerKind { impl TokenizerKind {
pub async fn from_repo(repo_id: &str) -> Result<Self> { pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id) Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id)) .with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id))
} }
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerJson( Ok(Self::HfTokenizerJson(check_for_file(
check_for_file(repo, "tokenizer.json").await?, repo,
)) "tokenizer.json",
)?))
} }
} }
impl GenerationConfig { impl GenerationConfig {
pub async fn from_repo(repo_id: &str) -> Result<Self> { pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id) Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract generation config from repo {repo_id}")) .with_context(|| format!("unable to extract generation config from repo {repo_id}"))
} }
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> { fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfGenerationConfigJson( Ok(Self::HfGenerationConfigJson(check_for_file(
check_for_file(repo, "generation_config.json").await?, repo,
)) "generation_config.json",
)?))
} }
} }
/// Checks if the provided path contains the expected file. /// Checks if the provided path contains the expected file.
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> { fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
let p = PathBuf::from(repo_id).join(file); let p = PathBuf::from(repo_id).join(file);
let name = p.display().to_string(); let name = p.display().to_string();
if !p.exists() { if !p.exists() {
...@@ -911,20 +901,20 @@ mod tests { ...@@ -911,20 +901,20 @@ mod tests {
use super::HFConfig; use super::HFConfig;
use std::path::Path; use std::path::Path;
#[tokio::test] #[test]
pub async fn test_config_json_llama3() -> anyhow::Result<()> { pub fn test_config_json_llama3() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json"); .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string()).await?; let config = HFConfig::from_json_file(&config_file.display().to_string())?;
assert_eq!(config.bos_token_id(), 128000); assert_eq!(config.bos_token_id(), 128000);
Ok(()) Ok(())
} }
#[tokio::test] #[test]
pub async fn test_config_json_llama4() -> anyhow::Result<()> { pub fn test_config_json_llama4() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json"); .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string()).await?; let config = HFConfig::from_json_file(&config_file.display().to_string())?;
assert_eq!(config.bos_token_id(), 200000); assert_eq!(config.bos_token_id(), 200000);
Ok(()) Ok(())
} }
......
...@@ -22,7 +22,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; ...@@ -22,7 +22,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use tracing; use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind}; use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder; use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding; use crate::tokenizers::Encoding;
...@@ -98,38 +98,27 @@ pub struct OpenAIPreprocessor { ...@@ -98,38 +98,27 @@ pub struct OpenAIPreprocessor {
} }
impl OpenAIPreprocessor { impl OpenAIPreprocessor {
pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> { pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?; let formatter = PromptFormatter::from_mdc(&mdc)?;
let tokenizer = mdc.tokenizer_hf()?;
match formatter { match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_formatter(mdc, formatter).await, PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
} }
} }
pub async fn new_with_formatter( pub fn new_with_parts(
mdc: ModelDeploymentCard, mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum(); let mdcsum = mdc.mdcsum();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => {
HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
}
None => {
anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
);
}
};
let tokenizer = Arc::new(tokenizer);
let Some(model_info) = mdc.model_info else { let Some(model_info) = mdc.model_info else {
anyhow::bail!( anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no model_info" "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
); );
}; };
let model_info = model_info.get_model_info().await?; let model_info = model_info.get_model_info()?;
Ok(Arc::new(Self { Ok(Arc::new(Self {
formatter, formatter,
......
...@@ -17,13 +17,14 @@ use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter}; ...@@ -17,13 +17,14 @@ use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{ChatTemplate, ChatTemplateValue}; use tokcfg::{ChatTemplate, ChatTemplateValue};
impl PromptFormatter { impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> { pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
match mdc match mdc
.prompt_formatter .prompt_formatter
.as_ref()
.ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))? .ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
{ {
PromptFormatterArtifact::HfTokenizerConfigJson(file) => { PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(&file) let content = std::fs::read_to_string(file)
.with_context(|| format!("fs:read_to_string '{file}'"))?; .with_context(|| format!("fs:read_to_string '{file}'"))?;
let mut config: ChatTemplate = serde_json::from_str(&content)?; let mut config: ChatTemplate = serde_json::from_str(&content)?;
...@@ -32,9 +33,9 @@ impl PromptFormatter { ...@@ -32,9 +33,9 @@ impl PromptFormatter {
// put the chat template into config as normalization. // put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag. // This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) = if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file mdc.chat_template_file.as_ref()
{ {
let chat_template = std::fs::read_to_string(&chat_template_file) let chat_template = std::fs::read_to_string(chat_template_file)
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?; .with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
// clean up the string to remove newlines // clean up the string to remove newlines
let chat_template = chat_template.replace('\n', ""); let chat_template = chat_template.replace('\n', "");
...@@ -43,6 +44,7 @@ impl PromptFormatter { ...@@ -43,6 +44,7 @@ impl PromptFormatter {
Self::from_parts( Self::from_parts(
config, config,
mdc.prompt_context mdc.prompt_context
.clone()
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)), .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
) )
} }
...@@ -50,7 +52,7 @@ impl PromptFormatter { ...@@ -50,7 +52,7 @@ impl PromptFormatter {
"prompt_formatter should not have type HfChatTemplate" "prompt_formatter should not have type HfChatTemplate"
)), )),
PromptFormatterArtifact::GGUF(gguf_path) => { PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?; let config = ChatTemplate::from_gguf(gguf_path)?;
Self::from_parts(config, ContextMixins::default()) Self::from_parts(config, ContextMixins::default())
} }
} }
......
...@@ -4,13 +4,11 @@ ...@@ -4,13 +4,11 @@
use dynamo_llm::backend::Backend; use dynamo_llm::backend::Backend;
use dynamo_llm::model_card::ModelDeploymentCard; use dynamo_llm::model_card::ModelDeploymentCard;
#[tokio::test] #[test]
async fn test_sequence_factory() { fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None) let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None).unwrap();
.await
.unwrap();
let operator = Backend::from_mdc(mdc).await.unwrap(); let operator = Backend::from_mdc(&mdc);
let mut decode_stream = operator let mut decode_stream = operator
.tokenizer .tokenizer
......
...@@ -8,8 +8,8 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1"; ...@@ -8,8 +8,8 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
#[tokio::test] #[tokio::test]
async fn test_model_info_from_hf_like_local_repo() { async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
let info = mdc.model_info.unwrap().get_model_info().await.unwrap(); let info = mdc.model_info.unwrap().get_model_info().unwrap();
assert_eq!(info.model_type(), "llama"); assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1); assert_eq!(info.bos_token_id(), 1);
assert_eq!(info.eos_token_ids(), vec![2]); assert_eq!(info.eos_token_ids(), vec![2]);
...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_model_info_from_non_existent_local_repo() { async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist"; let path = "tests/data/sample-models/this-model-does-not-exist";
let result = ModelDeploymentCard::load(path, None).await; let result = ModelDeploymentCard::load(path, None);
assert!(result.is_err()); assert!(result.is_err());
} }
#[tokio::test] #[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() { async fn test_tokenizer_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
// Verify tokenizer file was found // Verify tokenizer file was found
match mdc.tokenizer.unwrap() { match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (), TokenizerKind::HfTokenizerJson(_) => (),
...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() { ...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() { async fn test_prompt_formatter_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap(); let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
// Verify prompt formatter was found // Verify prompt formatter was found
match mdc.prompt_formatter { match mdc.prompt_formatter {
Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (), Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() { ...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
async fn test_missing_required_files() { async fn test_missing_required_files() {
// Create empty temp directory // Create empty temp directory
let temp_dir = tempdir().unwrap(); let temp_dir = tempdir().unwrap();
let result = ModelDeploymentCard::load(temp_dir.path(), None).await; let result = ModelDeploymentCard::load(temp_dir.path(), None);
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err().to_string(); let err = result.unwrap_err().to_string();
// Should fail because config.json is missing // Should fail because config.json is missing
......
...@@ -57,9 +57,7 @@ async fn make_mdc_from_repo( ...@@ -57,9 +57,7 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above. //TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await; let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision); let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load(downloaded_path, None) let mut mdc = ModelDeploymentCard::load(downloaded_path, None).unwrap();
.await
.unwrap();
mdc.set_name(&display_name); mdc.set_name(&display_name);
mdc.prompt_context = mixins; mdc.prompt_context = mixins;
mdc mdc
...@@ -285,7 +283,7 @@ async fn test_single_turn() { ...@@ -285,7 +283,7 @@ async fn test_single_turn() {
let mdcs = make_mdcs().await; let mdcs = make_mdcs().await;
for mdc in mdcs.iter() { for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
...@@ -317,7 +315,7 @@ async fn test_single_turn_with_tools() { ...@@ -317,7 +315,7 @@ async fn test_single_turn_with_tools() {
let mdcs = make_mdcs().await; let mdcs = make_mdcs().await;
for mdc in mdcs.iter() { for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
...@@ -354,7 +352,7 @@ async fn test_mulit_turn_without_system() { ...@@ -354,7 +352,7 @@ async fn test_mulit_turn_without_system() {
let mdcs = make_mdcs().await; let mdcs = make_mdcs().await;
for mdc in mdcs.iter() { for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
...@@ -386,7 +384,7 @@ async fn test_mulit_turn_with_system() { ...@@ -386,7 +384,7 @@ async fn test_mulit_turn_with_system() {
let mdcs = make_mdcs().await; let mdcs = make_mdcs().await;
for mdc in mdcs.iter() { for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
...@@ -424,7 +422,7 @@ async fn test_multi_turn_with_system_with_tools() { ...@@ -424,7 +422,7 @@ async fn test_multi_turn_with_system_with_tools() {
let mdcs = make_mdcs().await; let mdcs = make_mdcs().await;
for mdc in mdcs.iter() { for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
...@@ -467,7 +465,7 @@ async fn test_multi_turn_with_continuation() { ...@@ -467,7 +465,7 @@ async fn test_multi_turn_with_continuation() {
) )
.await; .await;
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap(); let formatter = PromptFormatter::from_mdc(&mdc).unwrap();
// assert its an OAI formatter // assert its an OAI formatter
let formatter = match formatter { let formatter = match formatter {
......
...@@ -2265,7 +2265,7 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" ...@@ -2265,7 +2265,7 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.5.0" version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [ dependencies = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment