Commit c7067fc2 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Build pre-processor from GGUF (#344)

This lets us do:
```
dynamo-run out=llamacpp <gguf_file>
```

Previously a `--model-config <hf-repo>` was also required to configure our tokenizer.
parent d29f7fcc
...@@ -26,7 +26,7 @@ mod oai; ...@@ -26,7 +26,7 @@ mod oai;
mod tokcfg; mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter}; use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::{raise_exception, tojson, ChatTemplate as HfTokenizerConfig}; use tokcfg::ChatTemplate;
impl PromptFormatter { impl PromptFormatter {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> { pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
...@@ -36,16 +36,24 @@ impl PromptFormatter { ...@@ -36,16 +36,24 @@ impl PromptFormatter {
{ {
PromptFormatterArtifact::HfTokenizerConfigJson(file) => { PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
let content = std::fs::read_to_string(file)?; let content = std::fs::read_to_string(file)?;
let config: HfTokenizerConfig = serde_json::from_str(&content)?; let config: ChatTemplate = serde_json::from_str(&content)?;
let formatter = HfTokenizerConfigJsonFormatter::new( Self::from_parts(
config, config,
mdc.prompt_context mdc.prompt_context
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)), .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
)?; )
Ok(Self::OAI(Arc::new(formatter))) }
PromptFormatterArtifact::GGUF(gguf_path) => {
let config = ChatTemplate::from_gguf(&gguf_path)?;
Self::from_parts(config, ContextMixins::default())
} }
} }
} }
/// Build an OAI prompt formatter from an already-parsed chat template and
/// the given context mixins.
pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
    Ok(Self::OAI(Arc::new(HfTokenizerConfigJsonFormatter::new(
        config, context,
    )?)))
}
} }
/// Chat Template Jinja Renderer /// Chat Template Jinja Renderer
...@@ -74,7 +82,7 @@ struct JinjaEnvironment { ...@@ -74,7 +82,7 @@ struct JinjaEnvironment {
#[derive(Debug)] #[derive(Debug)]
struct HfTokenizerConfigJsonFormatter { struct HfTokenizerConfigJsonFormatter {
env: Environment<'static>, env: Environment<'static>,
config: HfTokenizerConfig, config: ChatTemplate,
mixins: Arc<ContextMixins>, mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool, supports_add_generation_prompt: bool,
} }
......
...@@ -13,8 +13,12 @@ ...@@ -13,8 +13,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::*; use std::sync::Arc;
use super::tokcfg::{raise_exception, tojson, ChatTemplate};
use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
use either::Either; use either::Either;
use minijinja::Environment;
use tracing; use tracing;
impl JinjaEnvironment { impl JinjaEnvironment {
...@@ -35,7 +39,7 @@ impl Default for JinjaEnvironment { ...@@ -35,7 +39,7 @@ impl Default for JinjaEnvironment {
} }
impl HfTokenizerConfigJsonFormatter { impl HfTokenizerConfigJsonFormatter {
pub fn new(config: HfTokenizerConfig, mixins: ContextMixins) -> Result<Self> { pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> {
let mut env = JinjaEnvironment::default().env(); let mut env = JinjaEnvironment::default().env();
let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!( let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
//based on: https://github.com/EricLBuehler/mistral.rs/blob/d970bb5feb863acf8e8ec90de97e18221fb959f1/mistralrs-core/src/pipeline/chat_template.rs //based on: https://github.com/EricLBuehler/mistral.rs/blob/d970bb5feb863acf8e8ec90de97e18221fb959f1/mistralrs-core/src/pipeline/chat_template.rs
use std::collections::HashMap; use std::{collections::HashMap, fs::File, path::Path};
use either::Either; use either::Either;
use ggus::{GGufMetaKV, GGufReader};
use memmap2::Mmap;
use minijinja::{value::Kwargs, Error, ErrorKind, Value}; use minijinja::{value::Kwargs, Error, ErrorKind, Value};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -64,19 +66,22 @@ pub struct PadTokenValue( ...@@ -64,19 +66,22 @@ pub struct PadTokenValue(
#[derive(Debug, Deserialize, Default)] #[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template. /// Template for chat models including bos/eos/unk as well as the chat template.
pub struct ChatTemplate { pub struct ChatTemplate {
add_bos_token: Option<bool>,
add_eos_token: Option<bool>,
added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
additional_special_tokens: Option<Vec<String>>,
pub bos_token: Option<BeginEndUnkTok>, pub bos_token: Option<BeginEndUnkTok>,
pub eos_token: Option<BeginEndUnkTok>,
pub unk_token: Option<BeginEndUnkTok>,
/// Jinja format [chat templating] for chat completion. /// Jinja format [chat templating] for chat completion.
/// ///
/// [chat templating]: https://huggingface.co/docs/transformers/chat_templating /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
pub chat_template: Option<ChatTemplateValue>, pub chat_template: Option<ChatTemplateValue>,
// future
add_bos_token: Option<bool>,
add_eos_token: Option<bool>,
added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
additional_special_tokens: Option<Vec<String>>,
clean_up_tokenization_spaces: Option<bool>, clean_up_tokenization_spaces: Option<bool>,
device_map: Option<String>, device_map: Option<String>,
pub eos_token: Option<BeginEndUnkTok>,
legacy: Option<bool>, legacy: Option<bool>,
model_max_length: Option<f64>, model_max_length: Option<f64>,
pad_token: Option<PadTokenValue>, pad_token: Option<PadTokenValue>,
...@@ -84,11 +89,74 @@ pub struct ChatTemplate { ...@@ -84,11 +89,74 @@ pub struct ChatTemplate {
spaces_between_special_tokens: Option<bool>, spaces_between_special_tokens: Option<bool>,
tokenizer_class: Option<String>, tokenizer_class: Option<String>,
truncation_size: Option<String>, truncation_size: Option<String>,
pub unk_token: Option<BeginEndUnkTok>,
use_default_system_prompt: Option<bool>, use_default_system_prompt: Option<bool>,
} }
impl ChatTemplate { impl ChatTemplate {
/// Construct a `ChatTemplate` from the metadata section of a GGUF model file.
///
/// Only the key-value metadata of the GGUF header is read; the tensor data is
/// never touched. The bos/eos/unk token ids and the embedded Jinja chat
/// template are extracted; any field missing from the file is left at the
/// struct's default (`None`).
///
/// # Errors
/// Fails if the file cannot be opened or memory-mapped, or if the GGUF header
/// or a metadata entry cannot be parsed.
pub fn from_gguf(path: &Path) -> anyhow::Result<Self> {
    let file = File::open(path)
        .map_err(|e| anyhow::anyhow!("Failed to open file '{}': {e:?}", path.display()))?;
    // SAFETY: the mapping is read-only and dropped before this function
    // returns. We assume the underlying file is not truncated while mapped;
    // TODO confirm callers do not modify the model file concurrently.
    let mmap = unsafe { Mmap::map(&file)? };

    let mut reader = GGufReader::new(&mmap);
    let header = reader
        .read_header()
        .map_err(|e| anyhow::anyhow!("Failed to read GGUF header of {}: {e:?}", path.display()))?;

    // GGUF stores token ids as unsigned integers; we keep the decimal string
    // form. NOTE(review): this records the numeric id (e.g. "1"), not the
    // token text — confirm downstream consumers of BeginEndUnkTok expect that.
    fn convert(kv: GGufMetaKV) -> Option<BeginEndUnkTok> {
        Some(BeginEndUnkTok(Either::Left(kv.read_unsigned().to_string())))
    }

    let mut out = ChatTemplate::default();
    let mut num_found = 0;
    const NUM_EXPECTED: usize = 4; // bos, eos, unk, chat_template

    for _ in 0..header.metadata_kv_count {
        let kv = match reader.read_meta_kv() {
            Ok(kv) => kv,
            Err(err) => anyhow::bail!("read_meta_kv error in '{}': {err:?}", path.display()),
        };
        match kv.key() {
            "tokenizer.ggml.bos_token_id" => {
                out.bos_token = convert(kv);
                num_found += 1;
            }
            "tokenizer.ggml.eos_token_id" => {
                out.eos_token = convert(kv);
                num_found += 1;
            }
            "tokenizer.ggml.unknown_token_id" => {
                out.unk_token = convert(kv);
                num_found += 1;
            }
            "tokenizer.chat_template" => {
                // A read failure leaves chat_template as None but still counts
                // as "found": the key occurred, so no other entry can supply it.
                out.chat_template = kv
                    .value_reader()
                    .read_str()
                    .ok()
                    .map(|s| ChatTemplateValue(Either::Left(s.to_string())));
                num_found += 1;
            }
            _ => {}
        }
        if num_found == NUM_EXPECTED {
            // All fields of interest located; skip the remaining metadata.
            break;
        }
    }
    Ok(out)
}
// pub fn has_chat_template(&self) -> bool { // pub fn has_chat_template(&self) -> bool {
// self.chat_template.is_some() // self.chat_template.is_some()
// } // }
......
...@@ -44,9 +44,8 @@ pub struct BackendOutput { ...@@ -44,9 +44,8 @@ pub struct BackendOutput {
// TODO: Enrich this with more information as can apply our first-level postprocessing // TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
// Model Deployment Card checksum
/// Model Deployment Card checksum //pub mdcsum: String,
pub mdcsum: String,
} }
/// The LLM engine and backend will manage its own state, specifically translating how a /// The LLM engine and backend will manage its own state, specifically translating how a
......
...@@ -46,6 +46,7 @@ async fn test_tokenizer_from_hf_like_local_repo() { ...@@ -46,6 +46,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
// Verify tokenizer file was found // Verify tokenizer file was found
match mdc.tokenizer { match mdc.tokenizer {
TokenizerKind::HfTokenizerJson(_) => (), TokenizerKind::HfTokenizerJson(_) => (),
TokenizerKind::GGUF(_) => (),
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment