Commit cb5a657a (unverified), authored by Graham King, committed by GitHub
Browse files

fix: Load the tokenizer JSON once for chat and completions. (#2910)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 9ef13289
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use crate::llm::model_card::ModelDeploymentCard;
......@@ -33,10 +21,7 @@ pub(crate) struct Backend {
impl Backend {
#[new]
fn new(mdc: ModelDeploymentCard, endpoint: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
let backend = runtime
.block_on(llm_rs::backend::Backend::from_mdc(mdc.inner))
.map_err(to_pyerr)?;
let backend = llm_rs::backend::Backend::from_mdc(&mdc.inner);
Ok(Self {
inner: backend,
endpoint,
......
......@@ -16,14 +16,10 @@ impl ModelDeploymentCard {}
impl ModelDeploymentCard {
// Previously called "from_local_path"
#[staticmethod]
fn load(path: String, model_name: String, py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut card = RsModelDeploymentCard::load(&path, None)
.await
.map_err(to_pyerr)?;
card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card })
})
fn load(path: String, model_name: String) -> PyResult<ModelDeploymentCard> {
let mut card = RsModelDeploymentCard::load(&path, None).map_err(to_pyerr)?;
card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card })
}
#[staticmethod]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
use crate::llm::model_card::ModelDeploymentCard;
......@@ -42,10 +30,7 @@ pub(crate) struct OAIChatPreprocessor {
impl OAIChatPreprocessor {
#[new]
fn new(mdc: ModelDeploymentCard, current: Endpoint, next: Endpoint) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
let preprocessor = runtime
.block_on(OpenAIPreprocessor::new(mdc.inner.clone()))
.map_err(to_pyerr)?;
let preprocessor = OpenAIPreprocessor::new(mdc.inner.clone()).map_err(to_pyerr)?;
Ok(Self {
inner: preprocessor,
current,
......
......@@ -17,11 +17,11 @@
use std::{collections::HashSet, sync::Arc};
use anyhow::{Error, Result};
use anyhow::Result;
use futures::stream::{self, StreamExt};
use tracing as log;
use crate::model_card::{ModelDeploymentCard, TokenizerKind};
use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
......@@ -66,30 +66,27 @@ struct DecoderUnfoldState {
}
impl Backend {
pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Arc<Self> {
let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
let tokenizer = Tokenizer::from(Arc::new(tokenizer));
Ok(Arc::new(Self {
Arc::new(Self {
tokenizer: Some(tokenizer),
validate_engine_decode: false,
}))
})
}
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => {
HfTokenizer::from_file(file).map_err(Error::msg)?
}
Some(TokenizerKind::GGUF(t)) => *t.clone(),
None => {
return Ok(Arc::new(Self {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
match mdc.tokenizer_hf() {
Ok(tokenizer) => Self::from_tokenizer(tokenizer),
Err(err) => {
tracing::warn!(%err, "tokenizer_hf error converting ModelDeploymentCard to HF tokenizer");
Arc::new(Self {
tokenizer: None,
validate_engine_decode: false,
}));
})
}
};
Self::from_tokenizer(tokenizer).await
}
}
fn decoder(
......
......@@ -176,13 +176,13 @@ impl ModelWatcher {
}
WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
Ok(Some(model_name)) => {
tracing::info!("removed model {}", model_name);
tracing::info!(model_name, "removed model");
}
Ok(None) => {
// There are other instances running this model, nothing to do
}
Err(e) => {
tracing::error!("error removing model: {}", e);
tracing::error!(error = %e, "error removing model");
}
},
}
......@@ -271,7 +271,7 @@ impl ModelWatcher {
Some(card)
}
Err(err) => {
tracing::info!(%err, "load_mdc did not complete");
tracing::info!(error = %err, "load_mdc did not complete");
None
}
};
......@@ -308,6 +308,9 @@ impl ModelWatcher {
None
};
// Loading the tokenizer is expensive (~10MiB of JSON), so only do it once
let tokenizer_hf = card.tokenizer_hf()?;
// Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::<
......@@ -319,18 +322,23 @@ impl ModelWatcher {
self.router_mode,
self.busy_threshold,
kv_chooser.clone(),
tokenizer_hf.clone(),
)
.await?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
tracing::info!("Chat completions is ready");
}
// Add completions engine only if the model supports completions
if model_entry.model_type.supports_completions() {
let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter;
let preprocessor =
OpenAIPreprocessor::new_with_formatter(card.clone(), formatter).await?;
let preprocessor = OpenAIPreprocessor::new_with_parts(
card.clone(),
formatter,
tokenizer_hf.clone(),
)?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
......@@ -341,10 +349,12 @@ impl ModelWatcher {
self.busy_threshold,
kv_chooser,
preprocessor,
tokenizer_hf,
)
.await?;
self.manager
.add_completions_model(&model_entry.name, completions_engine)?;
tracing::info!("Completions is ready");
}
} else if model_entry.model_input == ModelInput::Text
&& model_entry.model_type.supports_chat()
......@@ -391,8 +401,8 @@ impl ModelWatcher {
ManyOut<Annotated<NvCreateEmbeddingResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
let backend = Backend::from_mdc(&card).into_operator();
let router = PushRouter::<
PreprocessedEmbeddingRequest,
......
......@@ -67,7 +67,9 @@ pub async fn run(
let mut prepared_engine = common::prepare_engine(runtime, engine_config).await?;
let pre_processor = if prepared_engine.has_tokenizer() {
Some(OpenAIPreprocessor::new(prepared_engine.card.take().unwrap()).await?)
Some(OpenAIPreprocessor::new(
prepared_engine.card.take().unwrap(),
)?)
} else {
None
};
......
......@@ -11,7 +11,7 @@ use crate::{
kv_router::{KvPushRouter, KvRouter},
migration::Migration,
model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
types::{
......@@ -131,10 +131,18 @@ pub async fn prepare_engine(
None
};
let hf_tokenizer = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone())
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
hf_tokenizer,
)
.await?;
let service_name = local_model.service_name().to_string();
......@@ -167,7 +175,7 @@ pub async fn prepare_engine(
let pipeline = build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine)
>(model.card(), inner_engine, model.card().tokenizer_hf()?)
.await?;
let service_name = model.service_name().to_string();
......@@ -186,6 +194,7 @@ pub async fn prepare_engine(
pub async fn build_pipeline<Req, Resp>(
card: &ModelDeploymentCard,
engine: ExecutionContext,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<Arc<ServiceFrontend<SingleIn<Req>, ManyOut<Annotated<Resp>>>>>
where
Req: Data,
......@@ -198,10 +207,11 @@ where
>,
{
let frontend = ServiceFrontend::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor = OpenAIPreprocessor::new((*card).clone())
.await?
.into_operator();
let backend = Backend::from_mdc((*card).clone()).await?.into_operator();
let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?
.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let engine = ServiceBackend::from_engine(engine);
Ok(frontend
......@@ -219,6 +229,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
router_mode: RouterMode,
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -230,7 +241,9 @@ where
Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>>,
>,
{
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?;
let PromptFormatter::OAI(formatter) = PromptFormatter::from_mdc(card)?;
let preprocessor =
OpenAIPreprocessor::new_with_parts(card.clone(), formatter, hf_tokenizer.clone())?;
build_routed_pipeline_with_preprocessor(
card,
client,
......@@ -238,6 +251,7 @@ where
busy_threshold,
chooser,
preprocessor,
hf_tokenizer,
)
.await
}
......@@ -249,6 +263,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -262,8 +277,8 @@ where
{
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card).into_operator();
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client.clone(),
......@@ -312,14 +327,14 @@ mod tests {
#[tokio::test]
async fn test_build_chat_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let card = ModelDeploymentCard::load(HF_PATH, None)?;
let engine = crate::engines::make_engine_core();
// Build pipeline for chat completions
let pipeline = build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(&card, engine)
>(&card, engine, card.tokenizer_hf()?)
.await?;
// Verify pipeline was created
......@@ -331,13 +346,16 @@ mod tests {
#[tokio::test]
async fn test_build_completions_pipeline_core_engine_succeeds() -> anyhow::Result<()> {
// Create test model card
let card = ModelDeploymentCard::load(HF_PATH, None).await?;
let card = ModelDeploymentCard::load(HF_PATH, None)?;
let engine = crate::engines::make_engine_core();
// Build pipeline for completions
let pipeline =
build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(&card, engine)
.await?;
let pipeline = build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(
&card,
engine,
card.tokenizer_hf()?,
)
.await?;
// Verify pipeline was created
assert!(Arc::strong_count(&pipeline) >= 1);
......
......@@ -73,9 +73,7 @@ pub async fn run(
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
>::new();
let backend = Backend::from_mdc(model.card().clone())
.await?
.into_operator();
let backend = Backend::from_mdc(model.card()).into_operator();
let engine = ServiceBackend::from_engine(inner_engine);
let pipeline = frontend
.link(backend.forward_edge())?
......
......@@ -90,18 +90,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone())
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser)
.await?;
let completions_engine =
entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
grpc_service
......@@ -122,17 +131,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
let chat_pipeline = common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone())
.await?;
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine)
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
grpc_service
......
......@@ -119,18 +119,27 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None
};
let tokenizer_hf = card.tokenizer_hf()?;
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(card, &client, router_mode, None, kv_chooser.clone())
>(
card,
&client,
router_mode,
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser)
.await?;
let completions_engine =
entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() {
......@@ -160,17 +169,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let chat_pipeline = common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone())
.await?;
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
common::build_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine)
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints
......
......@@ -253,8 +253,7 @@ impl LocalModelBuilder {
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card =
ModelDeploymentCard::load(&model_config_path, self.custom_template_path.as_deref())
.await?;
ModelDeploymentCard::load(model_config_path, self.custom_template_path.as_deref())?;
// Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| {
......
......@@ -28,15 +28,15 @@ pub struct Migration {
}
impl Migration {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Arc<Self> {
tracing::debug!(
"model {} migration limit {}",
mdc.display_name,
mdc.migration_limit
);
Ok(Arc::new(Self {
Arc::new(Self {
migration_limit: mdc.migration_limit,
}))
})
}
}
......
......@@ -372,19 +372,19 @@ impl ModelDeploymentCard {
/// - a folder containing config.json, tokenizer.json and tokenizer_config.json
/// - a GGUF file
/// With an optional custom template
pub async fn load(
pub fn load(
config_path: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> {
let config_path = config_path.as_ref();
if config_path.is_dir() {
Self::from_local_path(config_path, custom_template_path).await
Self::from_local_path(config_path, custom_template_path)
} else {
// GGUF files don't support custom templates yet
if custom_template_path.is_some() {
anyhow::bail!("Custom templates are not supported for GGUF files");
}
Self::from_gguf(config_path).await
Self::from_gguf(config_path)
}
}
......@@ -403,7 +403,7 @@ impl ModelDeploymentCard {
/// - The path doesn't exist or isn't a directory
/// - The path contains invalid Unicode characters
/// - Required model files are missing or invalid
async fn from_local_path(
fn from_local_path(
local_root_dir: impl AsRef<Path>,
custom_template_path: Option<&Path>,
) -> anyhow::Result<Self> {
......@@ -419,10 +419,10 @@ impl ModelDeploymentCard {
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid model directory name"))?;
Self::from_repo(&repo_id, model_name, custom_template_path).await
Self::from_repo(&repo_id, model_name, custom_template_path)
}
async fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
fn from_gguf(gguf_file: &Path) -> anyhow::Result<Self> {
let model_name = gguf_file
.iter()
.next_back()
......@@ -461,14 +461,7 @@ impl ModelDeploymentCard {
})
}
#[allow(dead_code)]
async fn from_ngc_repo(_: &str) -> anyhow::Result<Self> {
Err(anyhow::anyhow!(
"ModelDeploymentCard::from_ngc_repo is not implemented"
))
}
async fn from_repo(
fn from_repo(
repo_id: &str,
model_name: &str,
custom_template_path: Option<&Path>,
......@@ -509,16 +502,16 @@ impl ModelDeploymentCard {
template_path.display().to_string(),
))
} else {
PromptFormatterArtifact::chat_template_from_repo(repo_id).await?
PromptFormatterArtifact::chat_template_from_repo(repo_id)?
};
Ok(Self {
display_name: model_name.to_string(),
slug: Slug::from_string(model_name),
model_info: Some(ModelInfoType::from_repo(repo_id).await?),
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
model_info: Some(ModelInfoType::from_repo(repo_id)?),
tokenizer: Some(TokenizerKind::from_repo(repo_id)?),
gen_config: GenerationConfig::from_repo(repo_id).ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
......@@ -568,9 +561,9 @@ pub trait ModelInfo: Send + Sync {
}
impl ModelInfoType {
pub async fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
pub fn get_model_info(&self) -> Result<Arc<dyn ModelInfo>> {
match self {
Self::HfConfigJson(info) => HFConfig::from_json_file(info).await,
Self::HfConfigJson(info) => HFConfig::from_json_file(info),
Self::GGUF(path) => HFConfig::from_gguf(path),
}
}
......@@ -622,7 +615,7 @@ struct HFTextConfig {
}
impl HFConfig {
async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let file_pathbuf = PathBuf::from(file);
let contents = std::fs::read_to_string(file)?;
let mut config: Self = serde_json::from_str(&contents)?;
......@@ -800,79 +793,76 @@ fn capitalize(s: &str) -> String {
}
impl ModelInfoType {
pub async fn from_repo(repo_id: &str) -> Result<Self> {
pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract model info from repo {}", repo_id))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfConfigJson(
check_for_file(repo, "config.json").await?,
))
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfConfigJson(check_for_file(repo, "config.json")?))
}
}
impl PromptFormatterArtifact {
pub async fn from_repo(repo_id: &str) -> Result<Option<Self>> {
pub fn from_repo(repo_id: &str) -> Result<Option<Self>> {
// we should only error if we expect a prompt formatter and it's not found
// right now, we don't know when to expect it, so we just return Ok(Some/None)
Ok(Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
.ok())
}
pub async fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
pub fn chat_template_from_repo(repo_id: &str) -> Result<Option<Self>> {
Ok(Self::chat_template_try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract prompt format from repo {}", repo_id))
.ok())
}
async fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate(
check_for_file(repo, "chat_template.jinja").await?,
))
fn chat_template_try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfChatTemplate(check_for_file(
repo,
"chat_template.jinja",
)?))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson(
check_for_file(repo, "tokenizer_config.json").await?,
))
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerConfigJson(check_for_file(
repo,
"tokenizer_config.json",
)?))
}
}
impl TokenizerKind {
pub async fn from_repo(repo_id: &str) -> Result<Self> {
pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract tokenizer kind from repo {}", repo_id))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerJson(
check_for_file(repo, "tokenizer.json").await?,
))
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfTokenizerJson(check_for_file(
repo,
"tokenizer.json",
)?))
}
}
impl GenerationConfig {
pub async fn from_repo(repo_id: &str) -> Result<Self> {
pub fn from_repo(repo_id: &str) -> Result<Self> {
Self::try_is_hf_repo(repo_id)
.await
.with_context(|| format!("unable to extract generation config from repo {repo_id}"))
}
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfGenerationConfigJson(
check_for_file(repo, "generation_config.json").await?,
))
fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfGenerationConfigJson(check_for_file(
repo,
"generation_config.json",
)?))
}
}
/// Checks if the provided path contains the expected file.
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
let p = PathBuf::from(repo_id).join(file);
let name = p.display().to_string();
if !p.exists() {
......@@ -911,20 +901,20 @@ mod tests {
use super::HFConfig;
use std::path::Path;
#[tokio::test]
pub async fn test_config_json_llama3() -> anyhow::Result<()> {
#[test]
pub fn test_config_json_llama3() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
let config = HFConfig::from_json_file(&config_file.display().to_string())?;
assert_eq!(config.bos_token_id(), 128000);
Ok(())
}
#[tokio::test]
pub async fn test_config_json_llama4() -> anyhow::Result<()> {
#[test]
pub fn test_config_json_llama4() -> anyhow::Result<()> {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
let config = HFConfig::from_json_file(&config_file.display().to_string())?;
assert_eq!(config.bos_token_id(), 200000);
Ok(())
}
......
......@@ -22,7 +22,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, sync::Arc};
use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;
......@@ -98,38 +98,27 @@ pub struct OpenAIPreprocessor {
}
impl OpenAIPreprocessor {
pub async fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let formatter = PromptFormatter::from_mdc(&mdc)?;
let tokenizer = mdc.tokenizer_hf()?;
match formatter {
PromptFormatter::OAI(formatter) => Self::new_with_formatter(mdc, formatter).await,
PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
}
}
pub async fn new_with_formatter(
pub fn new_with_parts(
mdc: ModelDeploymentCard,
formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => {
HuggingFaceTokenizer::from_tokenizer(*tokenizer.clone())
}
None => {
anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no tokenizer"
);
}
};
let tokenizer = Arc::new(tokenizer);
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else {
anyhow::bail!(
"Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
);
};
let model_info = model_info.get_model_info().await?;
let model_info = model_info.get_model_info()?;
Ok(Arc::new(Self {
formatter,
......
......@@ -17,13 +17,14 @@ use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{ChatTemplate, ChatTemplateValue};
impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
pub fn from_mdc(mdc: &ModelDeploymentCard) -> Result<PromptFormatter> {
match mdc
.prompt_formatter
.as_ref()
.ok_or(anyhow::anyhow!("MDC does not contain a prompt formatter"))?
{
PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(&file)
let content = std::fs::read_to_string(file)
.with_context(|| format!("fs:read_to_string '{file}'"))?;
let mut config: ChatTemplate = serde_json::from_str(&content)?;
......@@ -32,9 +33,9 @@ impl PromptFormatter {
// put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate(chat_template_file)) =
mdc.chat_template_file
mdc.chat_template_file.as_ref()
{
let chat_template = std::fs::read_to_string(&chat_template_file)
let chat_template = std::fs::read_to_string(chat_template_file)
.with_context(|| format!("fs:read_to_string '{}'", chat_template_file))?;
// clean up the string to remove newlines
let chat_template = chat_template.replace('\n', "");
......@@ -43,6 +44,7 @@ impl PromptFormatter {
Self::from_parts(
config,
mdc.prompt_context
.clone()
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
)
}
......@@ -50,7 +52,7 @@ impl PromptFormatter {
"prompt_formatter should not have type HfChatTemplate"
)),
PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?;
let config = ChatTemplate::from_gguf(gguf_path)?;
Self::from_parts(config, ContextMixins::default())
}
}
......
......@@ -4,13 +4,11 @@
use dynamo_llm::backend::Backend;
use dynamo_llm::model_card::ModelDeploymentCard;
#[tokio::test]
async fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None)
.await
.unwrap();
#[test]
fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None).unwrap();
let operator = Backend::from_mdc(mdc).await.unwrap();
let operator = Backend::from_mdc(&mdc);
let mut decode_stream = operator
.tokenizer
......
......@@ -8,8 +8,8 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
#[tokio::test]
async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
let info = mdc.model_info.unwrap().get_model_info().await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
let info = mdc.model_info.unwrap().get_model_info().unwrap();
assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1);
assert_eq!(info.eos_token_ids(), vec![2]);
......@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() {
#[tokio::test]
async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist";
let result = ModelDeploymentCard::load(path, None).await;
let result = ModelDeploymentCard::load(path, None);
assert!(result.is_err());
}
#[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
// Verify tokenizer file was found
match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (),
......@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
#[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).await.unwrap();
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap();
// Verify prompt formatter was found
match mdc.prompt_formatter {
Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
......@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
async fn test_missing_required_files() {
// Create empty temp directory
let temp_dir = tempdir().unwrap();
let result = ModelDeploymentCard::load(temp_dir.path(), None).await;
let result = ModelDeploymentCard::load(temp_dir.path(), None);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
// Should fail because config.json is missing
......
......@@ -57,9 +57,7 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load(downloaded_path, None)
.await
.unwrap();
let mut mdc = ModelDeploymentCard::load(downloaded_path, None).unwrap();
mdc.set_name(&display_name);
mdc.prompt_context = mixins;
mdc
......@@ -285,7 +283,7 @@ async fn test_single_turn() {
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......@@ -317,7 +315,7 @@ async fn test_single_turn_with_tools() {
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......@@ -354,7 +352,7 @@ async fn test_mulit_turn_without_system() {
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......@@ -386,7 +384,7 @@ async fn test_mulit_turn_with_system() {
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......@@ -424,7 +422,7 @@ async fn test_multi_turn_with_system_with_tools() {
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......@@ -467,7 +465,7 @@ async fn test_multi_turn_with_continuation() {
)
.await;
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
let formatter = PromptFormatter::from_mdc(&mdc).unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
......
......@@ -2265,7 +2265,7 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustc_version"
version = "0.5.0"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment