Commit c7067fc2 authored by Graham King, committed by GitHub

feat: Build pre-processor from GGUF (#344)

This lets us do:
```
dynamo-run out=llamacpp <gguf_file>
```

Previously a `--model-config <hf-repo>` was also required, to configure our tokenizer.
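For comparison, the pre-change invocation needed the extra flag (exact flag placement illustrative, per the sentence above):
```
dynamo-run out=llamacpp <gguf_file> --model-config <hf-repo>
```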
parent d29f7fcc
```diff
@@ -26,7 +26,7 @@ mod oai;
 mod tokcfg;

 use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
-use tokcfg::{raise_exception, tojson, ChatTemplate as HfTokenizerConfig};
+use tokcfg::ChatTemplate;

 impl PromptFormatter {
     pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
@@ -36,16 +36,24 @@ impl PromptFormatter {
         {
             PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                 let content = std::fs::read_to_string(file)?;
-                let config: HfTokenizerConfig = serde_json::from_str(&content)?;
-                let formatter = HfTokenizerConfigJsonFormatter::new(
+                let config: ChatTemplate = serde_json::from_str(&content)?;
+                Self::from_parts(
                     config,
                     mdc.prompt_context
                         .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
-                )?;
-                Ok(Self::OAI(Arc::new(formatter)))
+                )
             }
+            PromptFormatterArtifact::GGUF(gguf_path) => {
+                let config = ChatTemplate::from_gguf(&gguf_path)?;
+                Self::from_parts(config, ContextMixins::default())
+            }
         }
     }

+    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
+        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
+        Ok(Self::OAI(Arc::new(formatter)))
+    }
 }

 /// Chat Template Jinja Renderer
```
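Both artifact arms now converge on the new `from_parts` helper. As a minimal sketch (assuming the crate-local `ChatTemplate`, `ContextMixins`, and `PromptFormatter` from this diff are in scope, and that the module's `Result` alias matches `anyhow::Result`), the GGUF path reduces to:

```rust
use std::path::Path;

// Sketch only, not the repo's actual call site.
fn formatter_for_gguf(gguf: &Path) -> anyhow::Result<PromptFormatter> {
    // bos/eos/unk ids and the Jinja chat template come straight from
    // GGUF metadata; no HF repo or tokenizer_config.json is needed.
    let config = ChatTemplate::from_gguf(gguf)?;
    PromptFormatter::from_parts(config, ContextMixins::default())
}
```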
```diff
@@ -74,7 +82,7 @@ struct JinjaEnvironment {
 #[derive(Debug)]
 struct HfTokenizerConfigJsonFormatter {
     env: Environment<'static>,
-    config: HfTokenizerConfig,
+    config: ChatTemplate,
     mixins: Arc<ContextMixins>,
     supports_add_generation_prompt: bool,
 }
```
```diff
@@ -13,8 +13,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use super::*;
 use std::sync::Arc;

+use super::tokcfg::{raise_exception, tojson, ChatTemplate};
+use super::{ContextMixins, HfTokenizerConfigJsonFormatter, JinjaEnvironment};
+use either::Either;
+use minijinja::Environment;
+use tracing;

 impl JinjaEnvironment {
@@ -35,7 +39,7 @@ impl Default for JinjaEnvironment {
 }

 impl HfTokenizerConfigJsonFormatter {
-    pub fn new(config: HfTokenizerConfig, mixins: ContextMixins) -> Result<Self> {
+    pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> {
         let mut env = JinjaEnvironment::default().env();
         let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
```
```diff
@@ -15,9 +15,11 @@
 //based on: https://github.com/EricLBuehler/mistral.rs/blob/d970bb5feb863acf8e8ec90de97e18221fb959f1/mistralrs-core/src/pipeline/chat_template.rs

-use std::collections::HashMap;
+use std::{collections::HashMap, fs::File, path::Path};

 use either::Either;
+use ggus::{GGufMetaKV, GGufReader};
+use memmap2::Mmap;
 use minijinja::{value::Kwargs, Error, ErrorKind, Value};
 use serde::{Deserialize, Serialize};
```
```diff
@@ -64,19 +66,22 @@ pub struct PadTokenValue(
 #[derive(Debug, Deserialize, Default)]
 /// Template for chat models including bos/eos/unk as well as the chat template.
 pub struct ChatTemplate {
-    add_bos_token: Option<bool>,
-    add_eos_token: Option<bool>,
-    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
-    additional_special_tokens: Option<Vec<String>>,
     pub bos_token: Option<BeginEndUnkTok>,
+    pub eos_token: Option<BeginEndUnkTok>,
+    pub unk_token: Option<BeginEndUnkTok>,
     /// Jinja format [chat templating] for chat completion.
     ///
     /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
     pub chat_template: Option<ChatTemplateValue>,
+    // future
+    add_bos_token: Option<bool>,
+    add_eos_token: Option<bool>,
+    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
+    additional_special_tokens: Option<Vec<String>>,
     clean_up_tokenization_spaces: Option<bool>,
     device_map: Option<String>,
-    pub eos_token: Option<BeginEndUnkTok>,
     legacy: Option<bool>,
     model_max_length: Option<f64>,
     pad_token: Option<PadTokenValue>,
@@ -84,11 +89,74 @@ pub struct ChatTemplate {
     spaces_between_special_tokens: Option<bool>,
     tokenizer_class: Option<String>,
     truncation_size: Option<String>,
-    pub unk_token: Option<BeginEndUnkTok>,
     use_default_system_prompt: Option<bool>,
 }

 impl ChatTemplate {
+    pub fn from_gguf(path: &Path) -> anyhow::Result<Self> {
+        let file = match File::open(path) {
+            Ok(f) => unsafe { Mmap::map(&f)? },
+            Err(e) => {
+                anyhow::bail!("Failed to open file '{}': {e:?}", path.display());
+            }
+        };
+        let mut reader = GGufReader::new(&file);
+        let header = match reader.read_header() {
+            Ok(header) => header,
+            Err(e) => {
+                anyhow::bail!("Failed to read GGUF header of {}: {e:?}", path.display());
+            }
+        };
+        let num_metadata = header.metadata_kv_count;
+        let mut out = ChatTemplate::default();
+
+        // bos/eos/unk token conversion
+        fn convert(kv: GGufMetaKV) -> Option<BeginEndUnkTok> {
+            let id_string: String = kv.read_unsigned().to_string();
+            Some(BeginEndUnkTok(Either::Left(id_string)))
+        }
+
+        let mut num_found = 0;
+        let num_expected = 4; // How many fields we need
+        for _ in 0..num_metadata {
+            let kv = match reader.read_meta_kv() {
+                Ok(kv) => kv,
+                Err(err) => anyhow::bail!("read_meta_kv error in '{}': {err:?}", path.display()),
+            };
+            match kv.key() {
+                "tokenizer.ggml.bos_token_id" => {
+                    out.bos_token = convert(kv);
+                    num_found += 1;
+                }
+                "tokenizer.ggml.eos_token_id" => {
+                    out.eos_token = convert(kv);
+                    num_found += 1;
+                }
+                "tokenizer.ggml.unknown_token_id" => {
+                    out.unk_token = convert(kv);
+                    num_found += 1;
+                }
+                "tokenizer.chat_template" => {
+                    out.chat_template = kv
+                        .value_reader()
+                        .read_str()
+                        .ok()
+                        .map(|s| ChatTemplateValue(Either::Left(s.to_string())));
+                    num_found += 1;
+                }
+                _ => {}
+            }
+            if num_found == num_expected {
+                // No need to look at any more keys
+                break;
+            }
+        }
+
+        Ok(out)
+    }
+
     // pub fn has_chat_template(&self) -> bool {
     //     self.chat_template.is_some()
     // }
```
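Note that GGUF stores special tokens as integer ids, so `convert` stringifies the unsigned id into the `Either::Left` branch, whereas an HF `tokenizer_config.json` carries literal token strings. A minimal, hypothetical caller of the new loader (the file name is a placeholder; only `pub` fields from this diff are touched):

```rust
use std::path::Path;

// Sketch only; assumes ChatTemplate from this diff is in scope.
fn inspect_gguf_template() -> anyhow::Result<()> {
    // "model.gguf" is a placeholder; any llama.cpp-style GGUF carrying
    // tokenizer.ggml.* metadata should work.
    let tpl = ChatTemplate::from_gguf(Path::new("model.gguf"))?;
    println!("chat template present: {}", tpl.chat_template.is_some());
    println!("bos id found: {}", tpl.bos_token.is_some());
    println!("eos id found: {}", tpl.eos_token.is_some());
    Ok(())
}
```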
```diff
@@ -44,9 +44,8 @@ pub struct BackendOutput {
     // TODO: Enrich this with more information as we can apply our first-level postprocessing
     // logic and return more detailed information
     pub finish_reason: Option<FinishReason>,
-    /// Model Deployment Card checksum
-    pub mdcsum: String,
+    // Model Deployment Card checksum
+    //pub mdcsum: String,
 }

 /// The LLM engine and backend will manage its own state, specifically translating how a
```
```diff
@@ -46,6 +46,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
     // Verify tokenizer file was found
     match mdc.tokenizer {
         TokenizerKind::HfTokenizerJson(_) => (),
+        TokenizerKind::GGUF(_) => (),
     }
 }
```
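The match above is exhaustive, so adding a `TokenizerKind::GGUF` variant forces every such site to account for it. A minimal sketch (hypothetical helper, assuming these two variants are the only ones, per this test):

```rust
// Hypothetical helper; TokenizerKind is the enum matched in the test above.
fn tokenizer_source(kind: &TokenizerKind) -> &'static str {
    match kind {
        TokenizerKind::HfTokenizerJson(_) => "hf tokenizer.json",
        TokenizerKind::GGUF(_) => "embedded in gguf",
    }
}
```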