Commit 4f6f63cd authored by Biswa Panda, committed by GitHub

feat: add rust-based tokenizer

parent 53163693
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Clone, Debug, serde::Serialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
Function,
}
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CalledFunction {
pub name: String,
pub arguments: String,
}
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Clone, Debug, serde::Serialize)]
pub struct ToolCallResponse {
pub id: String,
#[serde(rename = "type")]
pub tp: ToolCallType,
pub function: CalledFunction,
}
@@ -40,7 +40,7 @@ use super::{
     validate_logit_bias, ContentProvider, OpenAISamplingOptionsProvider,
     OpenAIStopConditionsProvider,
 };
-// use crate::AnnotationsProvider;
+use triton_distributed::protocols::annotated::AnnotationsProvider;
 
 /// Request object which is used to generate chat completions.
 #[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
@@ -791,21 +791,21 @@ impl NvExtProvider for ChatCompletionRequest {
     }
 }
 
-// impl AnnotationsProvider for ChatCompletionRequest {
-//     fn annotations(&self) -> Option<Vec<String>> {
-//         self.nvext
-//             .as_ref()
-//             .and_then(|nvext| nvext.annotations.clone())
-//     }
-//     fn has_annotation(&self, annotation: &str) -> bool {
-//         self.nvext
-//             .as_ref()
-//             .and_then(|nvext| nvext.annotations.as_ref())
-//             .map(|annotations| annotations.contains(&annotation.to_string()))
-//             .unwrap_or(false)
-//     }
-// }
+impl AnnotationsProvider for ChatCompletionRequest {
+    fn annotations(&self) -> Option<Vec<String>> {
+        self.nvext
+            .as_ref()
+            .and_then(|nvext| nvext.annotations.clone())
+    }
+    fn has_annotation(&self, annotation: &str) -> bool {
+        self.nvext
+            .as_ref()
+            .and_then(|nvext| nvext.annotations.as_ref())
+            .map(|annotations| annotations.contains(&annotation.to_string()))
+            .unwrap_or(false)
+    }
+}
 
 impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
     fn get_temperature(&self) -> Option<f32> {
...
@@ -32,6 +32,8 @@ use super::{
     MAX_TOP_P, MIN_FREQUENCY_PENALTY, MIN_PRESENCE_PENALTY, MIN_TEMPERATURE, MIN_TOP_P,
 };
 
+use triton_distributed::protocols::annotated::AnnotationsProvider;
+
 /// Legacy OpenAI CompletionRequest
 ///
 /// Reference: <https://platform.openai.com/docs/api-reference/completions>
@@ -392,21 +394,21 @@ impl NvExtProvider for CompletionRequest {
     }
 }
 
-// impl AnnotationsProvider for CompletionRequest {
-//     fn annotations(&self) -> Option<Vec<String>> {
-//         self.nvext
-//             .as_ref()
-//             .and_then(|nvext| nvext.annotations.clone())
-//     }
-//     fn has_annotation(&self, annotation: &str) -> bool {
-//         self.nvext
-//             .as_ref()
-//             .and_then(|nvext| nvext.annotations.as_ref())
-//             .map(|annotations| annotations.contains(&annotation.to_string()))
-//             .unwrap_or(false)
-//     }
-// }
+impl AnnotationsProvider for CompletionRequest {
+    fn annotations(&self) -> Option<Vec<String>> {
+        self.nvext
+            .as_ref()
+            .and_then(|nvext| nvext.annotations.clone())
+    }
+    fn has_annotation(&self, annotation: &str) -> bool {
+        self.nvext
+            .as_ref()
+            .and_then(|nvext| nvext.annotations.as_ref())
+            .map(|annotations| annotations.contains(&annotation.to_string()))
+            .unwrap_or(false)
+    }
+}
 
 impl OpenAISamplingOptionsProvider for CompletionRequest {
     fn get_temperature(&self) -> Option<f32> {
...
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod hf;
#[cfg(feature = "sentencepiece")]
pub mod sp;
// TODO: Add tokenizer benchmarks
// TODO: Enable README.md as a module doc
// #[doc = include_str!("../README.md")]
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::{ops::Deref, path::Path};
use crate::protocols::TokenIdType;
pub use anyhow::{Error, Result};
pub use hf::HuggingFaceTokenizer;
#[cfg(feature = "sentencepiece")]
pub use sp::SentencePieceTokenizer;
/// Represents the type of tokenizer being used
#[derive(Debug)]
pub enum TokenizerType {
HuggingFace(String),
#[cfg(feature = "sentencepiece")]
SentencePiece(String),
}
/// character offsets in the original text
pub type Offsets = (usize, usize);
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Hash)]
pub struct Encoding {
pub token_ids: Vec<TokenIdType>,
pub tokens: Vec<String>,
pub spans: Vec<Offsets>,
}
pub mod traits {
use super::*;
pub trait Encoder: Send + Sync {
fn encode(&self, input: &str) -> Result<Encoding>;
}
pub trait Decoder: Send + Sync {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
}
pub trait Tokenizer: Encoder + Decoder {
// fn get_vocab_size(&self) -> usize;
// fn make_unique_clone(&self) -> Box<dyn Tokenizer>;
}
}
impl Encoding {
pub fn get_hash(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}
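
// A minimal usage sketch (not part of the original API): two encodings of the
// same input hash identically, which is what the hash-verification feature in
// the README relies on. Assumes the tokenizer encodes deterministically.
#[allow(dead_code)]
fn encodings_match(tokenizer: &Tokenizer, input: &str) -> Result<bool> {
    let a = tokenizer.encode(input)?;
    let b = tokenizer.encode(input)?;
    Ok(a.get_hash() == b.get_hash())
}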
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
impl Tokenizer {
pub fn from_file(file_path: &str) -> Result<Tokenizer> {
Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
}
/// Create a stateful sequence object for decoding token_ids into text
pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream {
DecodeStream::new(self.0.clone(), skip_special_tokens)
}
}
impl Deref for Tokenizer {
type Target = Arc<dyn traits::Tokenizer>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
}
}
impl<T> From<Arc<T>> for Tokenizer
where
T: traits::Tokenizer + 'static, // 'static is required to ensure T can be safely put into an Arc
{
fn from(tokenizer: Arc<T>) -> Self {
Tokenizer(tokenizer)
}
}
/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - model: SentencePiece tokenizer
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
let path = Path::new(file_path);
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.ok_or_else(|| Error::msg("Failed to read file extension".to_string()))?;
match extension {
"json" => {
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
Ok(Arc::new(tokenizer))
}
"model" => {
#[cfg(feature = "sentencepiece")]
{
let tokenizer = SentencePieceTokenizer::from_file(file_path)?;
Ok(Arc::new(tokenizer))
}
#[cfg(not(feature = "sentencepiece"))]
{
Err(Error::msg(
"SentencePiece tokenizer not supported".to_string(),
))
}
}
_ => Err(Error::msg("Unsupported file type".to_string())),
}
}
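
// A minimal usage sketch: load a tokenizer from a file and round-trip a string.
// The path points at the TinyLlama tokenizer shipped in the test data; adjust as needed.
#[allow(dead_code)]
fn example_load_and_roundtrip() -> Result<()> {
    let tokenizer = Tokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")?;
    let encoding = tokenizer.encode("Your sample text here")?;
    let text = tokenizer.decode(&encoding.token_ids, false)?;
    println!("{} tokens -> {:?}", encoding.token_ids.len(), text);
    Ok(())
}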
/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding token_ids one at a time does not work in
/// general: the text for a given id depends on the surrounding ids (typically
/// via stripped extra spaces), so naive per-id decoding yields invalid strings.
pub struct DecodeStream {
/// The tokenizer used to decode token_ids
tokenizer: Arc<dyn traits::Tokenizer>,
skip_special_tokens: bool,
/// A temporary buffer of the necessary token_ids needed
/// to produce valid string chunks.
/// This typically contains 3 parts:
/// - read
/// - prefix
/// - rest
///
/// Read is the bit necessary to surround the prefix
/// so decoding the whole ids produces a valid prefix.
/// Prefix is the previously produced string, kept around so it can be trimmed
/// from the next valid chunk.
ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
/// We need to keep 2 offsets: `prefix_index` tracks the ids whose decoded
/// text was already emitted as the prefix, while `read_index` keeps earlier
/// ids around only to reproduce the decoding side effects at the start of
/// the prefix.
read_index: usize,
}
impl DecodeStream {
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>, skip_special_tokens: bool) -> Self {
Self {
tokenizer,
skip_special_tokens,
ids: Vec::new(),
prefix: "".to_string(),
prefix_index: 0,
read_index: 0,
}
}
/// Step appends a token_id to the internal state and tries to produce a text chunk.
///
/// The method only fails if the internal state is corrupted.
///
/// Returning `None` means the given id is not enough to produce a chunk.
/// This typically happens with `byte_fallback` options where some tokens do not
/// represent valid UTF-8, and only follow-up token_ids will help produce
/// a valid chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let string = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
if string.len() > self.prefix.len() && !string.ends_with('�') {
if !(string.starts_with(&self.prefix)) {
anyhow::bail!("Detokenizer failure: invalid prefix");
}
let new_text = string[self.prefix.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix = self
.tokenizer
.decode(self.ids.as_slice(), self.skip_special_tokens)?;
self.read_index = self.prefix_index;
self.prefix_index = new_prefix_index;
Ok(Some(new_text))
} else {
Ok(None)
}
}
}
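
// A minimal streaming sketch: feed token_ids one at a time and collect chunks
// as they become decodable. A `None` from `step` simply means "need more ids"
// (e.g. byte-fallback tokens that are not yet valid UTF-8).
#[allow(dead_code)]
fn stream_to_string(tokenizer: &Tokenizer, token_ids: &[u32]) -> Result<String> {
    let mut stream = tokenizer.decode_stream(false);
    let mut output = String::new();
    for &id in token_ids {
        if let Some(chunk) = stream.step(id)? {
            output.push_str(&chunk);
        }
    }
    Ok(output)
}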
/// Maintains state for an ongoing sequence of tokens and their decoded text
pub struct Sequence {
/// Encodes text -> token_ids
tokenizer: Tokenizer,
/// The current sequence of token ids
token_ids: Vec<TokenIdType>,
/// The position in the current sequence the last decoded token completed
prefix_offset: usize,
/// Current position in the sequence
read_offset: usize,
}
impl std::fmt::Debug for Sequence {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Sequence")
.field("tokenizer", &"Arc<dyn Tokenizer>")
.field(
"token_ids",
&format_args!("{}", {
if self.token_ids.len() <= 20 {
format!("{:?}", self.token_ids)
} else {
let first_ten = &self.token_ids[..10];
let last_ten = &self.token_ids[self.token_ids.len() - 10..];
format!("{:?} ... {:?}", first_ten, last_ten)
}
}),
)
.field("prefix_offset", &self.prefix_offset)
.field("read_offset", &self.read_offset)
.field("token count", &self.token_ids.len())
.finish()
}
}
impl Sequence {
pub fn new(tokenizer: Tokenizer) -> Self {
Self {
tokenizer,
token_ids: Vec::new(),
prefix_offset: 0,
read_offset: 0,
}
}
pub fn is_empty(&self) -> bool {
self.token_ids.is_empty()
}
pub fn len(&self) -> usize {
self.token_ids.len()
}
pub fn clear(&mut self) {
self.token_ids.clear();
self.prefix_offset = 0;
self.read_offset = 0;
}
pub fn append_text(&mut self, input: &str) -> Result<()> {
// let tokenizer = self.tokenizer.read().map_err(|err| {
// Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
// })?;
let encoding = self.tokenizer.encode(input)?;
self.token_ids.extend(encoding.token_ids);
Ok(())
}
// Based on
// https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
// under Apache 2.0 license
pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
self.token_ids.push(token_id);
// log::trace!("pushed token_id: {}", token_id);
let prefix_text = self
.tokenizer
.decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?;
let new_text = self
.tokenizer
.decode(&self.token_ids[self.prefix_offset..], false)?;
// if the end character of the previous returned sequence is a multi-byte character
// then we can not split the text on that byte offset, so we roll back to the byte offset
// of the start of that character
let mut prefix_text_len = prefix_text.len();
while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
prefix_text_len -= 1;
}
let prefix_text_len = prefix_text_len;
if new_text.len() > prefix_text.len() {
if new_text.ends_with("�") {
return Ok("".to_string());
} else {
// shift and update the state
let new_text = new_text[prefix_text_len..].to_string().replace("�", "");
self.prefix_offset = self.read_offset;
self.read_offset = self.token_ids.len();
return Ok(new_text);
}
}
Ok("".to_string())
}
pub fn tokenizer(&self) -> Tokenizer {
self.tokenizer.clone()
}
pub fn token_ids(&self) -> &[TokenIdType] {
&self.token_ids
}
pub fn text(&self) -> Result<String> {
// let tokenizer = self.tokenizer.read().map_err(|err| {
// Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
// })?;
self.tokenizer.decode(&self.token_ids, false)
}
}
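
// A minimal usage sketch: append token_ids one at a time and collect the
// incremental text deltas; an empty delta means the pending bytes do not yet
// form a complete character.
#[allow(dead_code)]
fn sequence_to_string(tokenizer: Tokenizer, token_ids: &[TokenIdType]) -> Result<String> {
    let mut sequence = Sequence::new(tokenizer);
    let mut text = String::new();
    for &id in token_ids {
        text.push_str(&sequence.append_token_id(id)?);
    }
    Ok(text)
}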
/// The result of a `StopSequenceDecoder::append_token_id` operation, indicating
/// whether text was produced or a stop condition was met
pub enum SequenceDecoderOutput {
/// The text for the appended token_id
Text(String),
/// The decoded text partially matches a stop sequence, so the text is held
/// until either a full match or a divergence
Held,
/// Indicates that a stop sequence has been matched and the decoder is stopped.
/// Subsequent calls to append_token_id will return an error
Stopped,
/// Indicates that a stop token_id has been matched and the decoder is stopped.
/// Subsequent calls to append_token_id will return an error
/// The text for the stop token_id is returned
StoppedWithText(String),
}
/// A Sequence for decoding a stream of token ids into text and detecting stop sequences.
/// A stop sequence is either a matching token_id or a matching text sequence.
/// Matches happen first at the token-level, then at the sequence-level. Hidden takes precedence
/// over visible. For example, if you put the same token_id in both `stop_token_ids_visible` and
/// `stop_token_ids_hidden`, the token_id will be treated as hidden.
#[derive(Debug)]
pub struct StopSequenceDecoder {
// The current sequence of token ids
sequence: Sequence,
// Stop Tokens - the presence of any one of these should trigger a stop
// If found, the text for the matched token will be returned
stop_token_ids_visible: Vec<TokenIdType>,
// Stop Tokens - the presence of any one of these should trigger a stop
// If found, the text for the matched token will NOT be returned
stop_token_ids_hidden: Vec<TokenIdType>,
// Stop Words - the presence of any one of these should trigger a stop
// If found, the text for the matched token will be returned
#[allow(dead_code)]
stop_sequences_visible: Vec<String>,
// Stop Words - the presence of any one of these should trigger a stop
// If found, the text for the matched token will NOT be returned
stop_sequences_hidden: Vec<String>,
// If the decoder has observed and returned a stop SequenceDecoderOutput,
// further calls to append_token_id will return an error
stopped: bool,
// text jail - if a partial stop sequence is being observed, we hold/jail the text
// until either the stop sequence is matched or the sequence is reset by a divergence
state: String,
}
impl StopSequenceDecoder {
/// Builder object for configuring a StopSequenceDecoder
pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
StopSequenceDecoderBuilder::new(tokenizer)
}
/// Add a token_id to the sequence and return the SequenceDecoderOutput
pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
if self.stopped {
return Err(Error::msg("Decoder is stopped"));
}
// update the sequence
let text = self.sequence.append_token_id(token_id)?;
// append the text to the state
self.state.push_str(text.as_str());
let mut stop: bool = false;
let mut visible: bool = false;
if self.stop_token_ids_visible.contains(&token_id) {
stop = true;
visible = true;
}
if self.stop_token_ids_hidden.contains(&token_id) {
stop = true;
visible = false;
}
if stop {
self.stopped = true;
let state = std::mem::take(&mut self.state);
if visible {
return Ok(SequenceDecoderOutput::StoppedWithText(state));
}
return Ok(SequenceDecoderOutput::Stopped);
}
// determine if state matches any of the stop sequences
for stop_sequence in self.stop_sequences_hidden.iter() {
if stop_sequence.starts_with(&self.state) {
if stop_sequence == &self.state {
// on matched stop sequence, we do NOT return the jailed stop sequence
self.stopped = true;
return Ok(SequenceDecoderOutput::Stopped);
} else {
return Ok(SequenceDecoderOutput::Held);
}
}
}
let state = std::mem::take(&mut self.state);
Ok(SequenceDecoderOutput::Text(state))
}
pub fn is_empty(&self) -> bool {
self.sequence.token_ids.is_empty()
}
pub fn len(&self) -> usize {
self.sequence.token_ids.len()
}
pub fn is_complete(&self) -> bool {
self.stopped
}
pub fn close(&mut self) {
self.stopped = true;
}
}
pub struct StopSequenceDecoderBuilder {
tokenizer: Tokenizer,
stop_token_ids_visible: Vec<TokenIdType>,
stop_token_ids_hidden: Vec<TokenIdType>,
stop_sequences_visible: Vec<String>,
stop_sequences_hidden: Vec<String>,
}
impl StopSequenceDecoderBuilder {
pub fn new(tokenizer: Tokenizer) -> Self {
Self {
tokenizer,
stop_token_ids_visible: Vec::new(),
stop_token_ids_hidden: Vec::new(),
stop_sequences_visible: Vec::new(),
stop_sequences_hidden: Vec::new(),
}
}
/// Adds a visible stop token id to the StopSequenceDecoder
pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
self.stop_token_ids_visible.push(token_id);
self
}
/// Adds a list of visible stop token ids to the StopSequenceDecoder
/// Each token_id is added as an individual match
pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
self.stop_token_ids_visible.extend(token_ids);
self
}
/// Adds a hidden stop token id to the StopSequenceDecoder
pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
self.stop_token_ids_hidden.push(token_id);
self
}
/// Adds a list of hidden stop token ids to the StopSequenceDecoder
/// Each token_id is added as an individual match
pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
self.stop_token_ids_hidden.extend(token_ids);
self
}
pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
self.stop_sequences_visible.push(text.to_string());
self
}
pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
self.stop_sequences_visible
.extend(strings.iter().map(|text| text.to_string()));
self
}
pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
self.stop_sequences_hidden.push(text.to_string());
self
}
pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
self.stop_sequences_hidden
.extend(strings.iter().map(|text| text.to_string()));
self
}
pub fn build(self) -> Result<StopSequenceDecoder> {
Ok(StopSequenceDecoder {
sequence: Sequence::new(self.tokenizer.clone()),
stop_token_ids_visible: self.stop_token_ids_visible,
stop_token_ids_hidden: self.stop_token_ids_hidden,
stop_sequences_visible: self.stop_sequences_visible,
stop_sequences_hidden: self.stop_sequences_hidden,
stopped: false,
state: String::new(),
})
}
}
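
// A minimal usage sketch: stop on a hidden stop token_id or a hidden stop
// string, handling every SequenceDecoderOutput variant. Token id 2 (`</s>` in
// a Llama-style vocab) and the "</answer>" stop string are assumptions for
// illustration only.
#[allow(dead_code)]
fn decode_until_stop(tokenizer: Tokenizer, token_ids: &[TokenIdType]) -> Result<String> {
    let mut decoder = StopSequenceDecoder::builder(tokenizer)
        .add_stop_token_id_hidden(2)
        .add_stop_sequences_hidden(&["</answer>"])
        .build()?;
    let mut text = String::new();
    for &id in token_ids {
        match decoder.append_token_id(id)? {
            SequenceDecoderOutput::Text(t) => text.push_str(&t),
            SequenceDecoderOutput::Held => {} // partial stop-sequence match; text is jailed
            SequenceDecoderOutput::Stopped => break,
            SequenceDecoderOutput::StoppedWithText(t) => {
                text.push_str(&t);
                break;
            }
        }
    }
    Ok(text)
}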
# Tokenizers
## Introduction
`tokenizers` is designed for efficient and versatile tokenization in natural language processing. It supports both HuggingFace and SentencePiece models, offering a streamlined API for text encoding and decoding.
## Features
- **Support for HuggingFace and SentencePiece Tokenizers**: Easily integrate popular tokenization models into your NLP projects.
- **Hash Verification**: Ensures tokenization consistency and accuracy across different models.
- **Simple Encoding and Decoding**: Facilitates the conversion of text to token IDs and back.
- **Sequence Management**: Manage sequences of tokens for complex NLP tasks effectively.
## Quick Start
### HuggingFace Tokenizer
```rust
use triton_llm::tokenizers::hf::HuggingFaceTokenizer;
let hf_tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer");
```
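### SentencePiece Tokenizer
A SentencePiece model can be loaded the same way when the crate is built with the `sentencepiece` feature. This is a minimal sketch; the `tokenizer.model` path is illustrative:
```rust
use triton_llm::tokenizers::SentencePieceTokenizer;
let sp_tokenizer = SentencePieceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.model")
    .expect("Failed to load SentencePiece tokenizer");
```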
### Encoding and Decoding Text
```rust
use triton_llm::tokenizers::{HuggingFaceTokenizer, traits::{Encoder, Decoder}};
let tokenizer = HuggingFaceTokenizer::from_file("tests/data/sample-models/TinyLlama_v1.1/tokenizer.json")
.expect("Failed to load HuggingFace tokenizer");
let text = "Your sample text here";
let encoding = tokenizer.encode(text)
.expect("Failed to encode text");
println!("Encoding: {:?}", encoding);
let decoded_text = tokenizer.decode(&encoding.token_ids, false)
.expect("Failed to decode token IDs");
assert_eq!(text, decoded_text);
// Using the Sequence object for encoding and decoding
use triton_llm::tokenizers::{Sequence, Tokenizer};
use std::sync::Arc;
let tokenizer = Tokenizer::from(Arc::new(tokenizer));
let mut sequence = Sequence::new(tokenizer.clone());
sequence.append_text("Your sample text here")
.expect("Failed to append text");
let delta = sequence.append_token_id(1337)
.expect("Failed to append token_id");
```
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::{
traits::{Decoder, Encoder, Tokenizer},
Encoding, Error, Result, TokenIdType,
};
pub struct HuggingFaceTokenizer {
tokenizer: HfTokenizer,
}
impl HuggingFaceTokenizer {
pub fn from_file(model_name: &str) -> Result<Self> {
let tokenizer = HfTokenizer::from_file(model_name)
.map_err(|err| Error::msg(format!("Error loading tokenizer: {}", err)))?;
Ok(HuggingFaceTokenizer { tokenizer })
}
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
HuggingFaceTokenizer { tokenizer }
}
}
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
let encoding = self
.tokenizer
.encode(input, false)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
let token_ids = encoding.get_ids().to_vec();
let tokens = encoding.get_tokens().to_vec();
let spans = encoding.get_offsets().to_vec();
Ok(Encoding {
token_ids,
tokens,
spans,
})
}
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
let text = self
.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?;
Ok(text)
}
}
impl Tokenizer for HuggingFaceTokenizer {}
impl From<HfTokenizer> for HuggingFaceTokenizer {
fn from(tokenizer: HfTokenizer) -> Self {
HuggingFaceTokenizer { tokenizer }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::tokenizers::{
traits::{Decoder, Encoder, Tokenizer},
Encoding, Error, Result, TokenIdType,
};
use sentencepiece::SentencePieceProcessor;
/// A tokenizer implementation using the SentencePiece tokenization algorithm.
/// This tokenizer can encode text into tokens and decode tokens back into text.
pub struct SentencePieceTokenizer {
/// The underlying SentencePiece processor instance
spp: SentencePieceProcessor,
}
impl SentencePieceTokenizer {
/// Creates a new SentencePieceTokenizer from a model file.
///
/// # Arguments
/// * `tokenizer_name` - Path to the SentencePiece model file
///
/// # Returns
/// * `Result<Self>` - A new tokenizer instance or an error if loading fails
pub fn from_file(tokenizer_name: &str) -> Result<Self> {
let spp = SentencePieceProcessor::open(tokenizer_name)
.map_err(|err| Error::msg(format!("Error loading tokenizer: {}", err)))?;
Ok(Self { spp })
}
}
impl Encoder for SentencePieceTokenizer {
/// Encodes a string input into tokens using the SentencePiece model.
///
/// # Arguments
/// * `input` - The text to encode
///
/// # Returns
/// * `Result<Encoding>` - The encoded tokens, including IDs, text, and character spans
fn encode(&self, input: &str) -> Result<Encoding> {
let encoding = self
.spp
.encode(input)
.map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?;
let mut token_ids = Vec::new();
let mut tokens = Vec::new();
let mut spans = Vec::new();
for piece in encoding {
token_ids.push(piece.id);
tokens.push(piece.piece);
spans.push((piece.span.0 as usize, piece.span.1 as usize));
}
Ok(Encoding {
token_ids,
tokens,
spans,
})
}
}
impl Decoder for SentencePieceTokenizer {
/// Decodes a sequence of token IDs back into text.
///
/// # Arguments
/// * `token_ids` - The sequence of token IDs to decode
/// * `skip_special_tokens` - Currently unsupported in SentencePieceTokenizer and
/// it will return an error if true
///
/// # Returns
/// * `Result<String>` - The decoded text
///
/// # Errors
/// * Returns an error if skip_special_tokens is true
/// * Returns an error if the decoding process fails
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
if skip_special_tokens {
return Err(Error::msg(
"SentencePieceTokenizer does not support skip_special_tokens=true.",
));
}
let text = self
.spp
.decode_piece_ids(token_ids)
.map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?;
Ok(text)
}
}
/// Implement the Tokenizer trait for SentencePieceTokenizer
impl Tokenizer for SentencePieceTokenizer {}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use triton_llm::backend::Backend;
use triton_llm::model_card::model::ModelDeploymentCard;
#[tokio::test]
async fn test_sequence_factory() {
let mdc = ModelDeploymentCard::from_local_path("tests/data/sample-models/TinyLlama_v1.1", None)
.await
.unwrap();
let operator = Backend::from_mdc(mdc).await.unwrap();
let mut decode_stream = operator.tokenizer.decode_stream(false);
let output = decode_stream.step(1).unwrap();
assert_eq!(output, Some("<s>".to_string()));
let mut decode_stream = operator.tokenizer.decode_stream(true);
let output = decode_stream.step(1).unwrap();
assert_eq!(output, None);
}
@@ -2,26 +2,23 @@
   "architectures": [
     "LlamaForCausalLM"
   ],
-  "attention_bias": false,
-  "attention_dropout": 0.0,
-  "bos_token_id": 128000,
-  "eos_token_id": 128009,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
   "hidden_act": "silu",
-  "hidden_size": 4096,
+  "hidden_size": 2048,
   "initializer_range": 0.02,
-  "intermediate_size": 14336,
-  "max_position_embeddings": 8192,
+  "intermediate_size": 5632,
+  "max_position_embeddings": 2048,
   "model_type": "llama",
   "num_attention_heads": 32,
-  "num_hidden_layers": 32,
-  "num_key_value_heads": 8,
+  "num_hidden_layers": 22,
+  "num_key_value_heads": 4,
   "pretraining_tp": 1,
   "rms_norm_eps": 1e-05,
   "rope_scaling": null,
-  "rope_theta": 500000.0,
   "tie_word_embeddings": false,
-  "torch_dtype": "bfloat16",
-  "transformers_version": "4.40.0.dev0",
+  "torch_dtype": "float32",
+  "transformers_version": "4.31.0.dev0",
   "use_cache": true,
-  "vocab_size": 128256
+  "vocab_size": 32000
 }
{
"bos_token_id": 1,
"eos_token_id": 2,
"pad_token_id": 0,
"max_length": 2048,
"transformers_version": "4.31.0.dev0"
}
{
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}
This source diff could not be displayed because it is too large.
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"legacy": false,
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"padding_side": "right",
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}
{
"bos_token_id": 128000,
"eos_token_id": [128001, 128009],
"do_sample": true,
"temperature": 0.6,
"max_length": 4096,
"top_p": 0.9,
"transformers_version": "4.40.0.dev0"
}
{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 128000,
"content": "<|begin_of_text|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128001,
"content": "<|end_of_text|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128002,
"content": "<|reserved_special_token_0|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128003,
"content": "<|reserved_special_token_1|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128004,
"content": "<|reserved_special_token_2|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128005,
"content": "<|reserved_special_token_3|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128006,
"content": "<|start_header_id|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128007,
"content": "<|end_header_id|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128008,
"content": "<|reserved_special_token_4|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128009,
"content": "<|eot_id|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 128010,
"content": "<|reserved_special_token_5|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
"pre_tokenizer": {
"type": "Sequence",
"pretokenizers": [
{
"type": "Split",
"pattern": {
"Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
},
"behavior": "Isolated",
"invert": false
},
{
"type": "ByteLevel",
"add_prefix_space": false,
"trim_offsets": true,
"use_regex": false
}
]
},
"post_processor": {
"type": "Sequence",
"processors": [
{
"type": "ByteLevel",
"add_prefix_space": true,
"trim_offsets": false,
"use_regex": true
},
{
"type": "TemplateProcessing",
"single": [
{
"SpecialToken": {
"id": "<|begin_of_text|>",
"type_id": 0
}
},
{
"Sequence": {
"id": "A",
"type_id": 0
}
}
],
"pair": [
{
"SpecialToken": {
"id": "<|begin_of_text|>",
"type_id": 0
}
},
{
"Sequence": {
"id": "A",
"type_id": 0
}
},
{
"SpecialToken": {
"id": "<|begin_of_text|>",
"type_id": 1
}
},
{
"Sequence": {
"id": "B",
"type_id": 1
}
}
],
"special_tokens": {
"<|begin_of_text|>": {
"id": "<|begin_of_text|>",
"ids": [
128000
],
"tokens": [
"<|begin_of_text|>"
]
}
}
}
]
},
"decoder": {
"type": "ByteLevel",
"add_prefix_space": true,
"trim_offsets": true,
"use_regex": true
},
"model": {
"type": "BPE",
"dropout": null,
"unk_token": null,
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"byte_fallback": false,
"ignore_merges": true,
"vocab": {},
"merges": []
}
}
{
"added_tokens_decoder": {
"128000": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128001": {
"content": "<|end_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128002": {
"content": "<|reserved_special_token_0|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128003": {
"content": "<|reserved_special_token_1|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128004": {
"content": "<|reserved_special_token_2|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128005": {
"content": "<|reserved_special_token_3|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128006": {
"content": "<|start_header_id|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128007": {
"content": "<|end_header_id|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128008": {
"content": "<|reserved_special_token_4|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"128009": {
"content": "<|eot_id|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|begin_of_text|>",
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim %}{% if loop.first %}{% set content = bos_token + content %}{% endif %}{% if not loop.last %}{% set content = content + '<|eot_id|>'%}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"clean_up_tokenization_spaces": true,
"eos_token": "<|eot_id|>",
"model_input_names": [
"input_ids",
"attention_mask"
],
"model_max_length": 1000000000000000019884624838656,
"tokenizer_class": "PreTrainedTokenizerFast"
}
@@ -14,33 +14,35 @@
 // limitations under the License.
 
 use tempfile::tempdir;
-use triton_llm::model_card::model::{
-    ModelDeploymentCard, ModelInfoType, PromptFormatterArtifact, TokenizerKind,
-};
+use triton_llm::model_card::model::{ModelDeploymentCard, PromptFormatterArtifact, TokenizerKind};
+
+const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
 
 #[tokio::test]
 async fn test_model_info_from_hf_like_local_repo() {
-    let path = "tests/data/sample-models/mock-llama-3.1-8b-instruct";
-    let mdc = ModelDeploymentCard::from_local_path(path).await.unwrap();
+    let mdc = ModelDeploymentCard::from_local_path(HF_PATH, None)
+        .await
+        .unwrap();
 
     let info = mdc.model_info.get_model_info().await.unwrap();
     assert_eq!(info.model_type(), "llama");
-    assert_eq!(info.bos_token_id(), 128000);
-    assert_eq!(info.eos_token_ids(), vec![128009]);
-    assert_eq!(info.max_position_embeddings(), 8192);
-    assert_eq!(info.vocab_size(), 128256);
+    assert_eq!(info.bos_token_id(), 1);
+    assert_eq!(info.eos_token_ids(), vec![2]);
+    assert_eq!(info.max_position_embeddings(), 2048);
+    assert_eq!(info.vocab_size(), 32000);
 }
 
 #[tokio::test]
 async fn test_model_info_from_non_existent_local_repo() {
     let path = "tests/data/sample-models/this-model-does-not-exist";
-    let result = ModelDeploymentCard::from_local_path(path).await;
+    let result = ModelDeploymentCard::from_local_path(path, None).await;
     assert!(result.is_err());
 }
 
 #[tokio::test]
 async fn test_tokenizer_from_hf_like_local_repo() {
-    let path = "tests/data/sample-models/mock-llama-3.1-8b-instruct";
-    let mdc = ModelDeploymentCard::from_local_path(path).await.unwrap();
+    let mdc = ModelDeploymentCard::from_local_path(HF_PATH, None)
+        .await
+        .unwrap();
 
     // Verify tokenizer file was found
     match mdc.tokenizer {
         TokenizerKind::HfTokenizerJson(_) => (),
@@ -50,8 +52,9 @@ async fn test_tokenizer_from_hf_like_local_repo() {
 #[tokio::test]
 async fn test_prompt_formatter_from_hf_like_local_repo() {
-    let path = "tests/data/sample-models/mock-llama-3.1-8b-instruct";
-    let mdc = ModelDeploymentCard::from_local_path(path).await.unwrap();
+    let mdc = ModelDeploymentCard::from_local_path(HF_PATH, None)
+        .await
+        .unwrap();
 
     // Verify prompt formatter was found
     match mdc.prompt_formatter {
         Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
@@ -63,7 +66,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
 async fn test_missing_required_files() {
     // Create empty temp directory
     let temp_dir = tempdir().unwrap();
-    let result = ModelDeploymentCard::from_local_path(temp_dir.path()).await;
+    let result = ModelDeploymentCard::from_local_path(temp_dir.path(), None).await;
     assert!(result.is_err());
     let err = result.unwrap_err().to_string();
     // Should fail because config.json is missing
...
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Ok;
use serde::{Deserialize, Serialize};
use triton_llm::model_card::model::{ModelDeploymentCard, PromptContextMixin};
use triton_llm::preprocessor::prompt::PromptFormatter;
use triton_llm::protocols::openai::chat_completions::{
ChatCompletionMessage, ChatCompletionRequest, Tool, ToolChoiceType,
};
use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};
use std::path::PathBuf;
/// ----------------- NOTE ---------------
/// Currently ModelDeploymentCard does not support downloading models using nim-hub.
/// As a temporary workaround, we download the models from Hugging Face and cache
/// them in the `tests/data/sample-models` directory. These tests require a Hugging
/// Face token to be set in the environment variable `HF_TOKEN`.
/// Make sure the token has access to the `meta-llama/Llama-3.1-70B-Instruct` model.
fn check_hf_token() -> bool {
    std::env::var("HF_TOKEN").is_ok()
}
async fn make_mdc_from_repo(
local_path: &str,
hf_repo: &str,
hf_revision: &str,
mixins: Option<Vec<PromptContextMixin>>,
) -> ModelDeploymentCard {
//TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::from_local_path(downloaded_path, Some(display_name))
.await
.unwrap();
mdc.prompt_context = mixins;
mdc
}
async fn maybe_download_model(local_path: &str, model: &str, revision: &str) -> String {
let cache = Cache::new(PathBuf::from(local_path));
let api = ApiBuilder::from_cache(cache)
.with_progress(false)
.with_token(Some(std::env::var("HF_TOKEN").unwrap()))
.build()
.unwrap();
let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision));
let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"];
let repo_builder = api.repo(repo);
let mut downloaded_path = PathBuf::new();
for file in &files_to_download {
downloaded_path = repo_builder.get(file).await.unwrap();
}
downloaded_path.parent().unwrap().display().to_string()
}
async fn make_mdcs() -> Vec<ModelDeploymentCard> {
vec![
make_mdc_from_repo(
"tests/data/sample-models",
"meta-llama/Llama-3.1-70B-Instruct",
"1605565",
Some(vec![PromptContextMixin::Llama3DateTime]),
)
.await,
]
}
// fn load_nim_mdcs() -> Vec<ModelDeploymentCard> {
// // get all .json files from test/data/model_deployment_cards/nim
// std::fs::read_dir("tests/data/model_deployment_cards/nim")
// .unwrap()
// .map(|res| res.map(|e| e.path()).unwrap().clone())
// .filter(|path| path.extension().unwrap() == "json")
// .map(|path| ModelDeploymentCard::load_from_json_file(path).unwrap())
// .collect::<Vec<_>>()
// }
// #[ignore]
// #[tokio::test]
// async fn create_mdc_from_repo() {
// for repo in NGC_MODEL_REPOS.iter() {
// println!("Creating MDC for {}", repo);
// let mdc = make_mdc_from_repo(repo).await;
// mdc.save_to_json_file(&format!(
// "tests/data/model_deployment_cards/nim/{}.json",
// Slug::slugify(repo)
// ))
// .unwrap();
// }
// }
const SINGLE_CHAT_MESSAGE: &str = r#"
[
{
"role": "user",
"content": "What is deep learning?"
}
]
"#;
/// Sample Message with `user` and `assistant`, no `system`
const THREE_TURN_CHAT_MESSAGE: &str = r#"
[
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a string in Python using slicing:\n\n```python\nreversed_string = your_string[::-1]\n```\n\nAlternatively, you can use `reversed()` with `join()`:\n\n```python\nreversed_string = ''.join(reversed(your_string))\n```\n"
},
{
"role": "user",
"content": "What if I want to reverse each word in a sentence but keep their order?"
}
]"#;
/// Sample Message with `system`, `user`, and `assistant`
const THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM: &str = r#"
[
{
"role": "system",
"content": "You are a very helpful assistant!"
},
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a string in Python using slicing:\n\n```python\nreversed_string = your_string[::-1]\n```\n\nAlternatively, you can use `reversed()` with `join()`:\n\n```python\nreversed_string = ''.join(reversed(your_string))\n```\n"
},
{
"role": "user",
"content": "What if I want to reverse each word in a sentence but keep their order?"
}
]"#;
/// Sample Message ending with a partial `assistant` turn to be continued
const MULTI_TURN_WITH_CONTINUATION: &str = r#"
[
{
"role": "system",
"content": "You are a very helpful assistant!"
},
{
"role": "user",
"content": "How do I reverse a string in Python?"
},
{
"role": "assistant",
"content": "You can reverse a "
}
]"#;
const TOOLS: &str = r#"
[
{
"type": "function",
"function": {
"name": "get_current_temperature",
"description": "Get the current temperature for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g., San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["Celsius", "Fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": ["location", "unit"]
}
}
},
{
"type": "function",
"function": {
"name": "get_rain_probability",
"description": "Get the probability of rain for a specific location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g., San Francisco, CA"
}
},
"required": ["location"]
}
}
}
]
"#;
#[derive(Serialize, Deserialize)]
struct Request {
messages: Vec<ChatCompletionMessage>,
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolChoiceType>,
}
impl Request {
fn from(
messages: &str,
tools: Option<&str>,
tool_choice: Option<ToolChoiceType>,
model: String,
) -> ChatCompletionRequest {
let messages: Vec<ChatCompletionMessage> = serde_json::from_str(messages).unwrap();
let tools: Option<Vec<Tool>> = tools.map(|x| serde_json::from_str(x).unwrap());
ChatCompletionRequest::builder()
.model(model)
.messages(messages)
.tools(tools)
.tool_choice(tool_choice)
.build()
.unwrap()
}
}
#[tokio::test]
async fn test_single_turn() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(SINGLE_CHAT_MESSAGE, None, None, mdc.slug().to_string());
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_single_turn_with_tools() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(
SINGLE_CHAT_MESSAGE,
Some(TOOLS),
Some(ToolChoiceType::Auto),
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_without_system() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(THREE_TURN_CHAT_MESSAGE, None, None, mdc.slug().to_string());
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
#[tokio::test]
async fn test_multi_turn_with_system() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(
THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
None,
None,
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
/// Test the prompt formatter with a multi-turn conversation that includes system message and tools
#[tokio::test]
async fn test_multi_turn_with_system_with_tools() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdcs = make_mdcs().await;
for mdc in mdcs.iter() {
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(
THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
Some(TOOLS),
Some(ToolChoiceType::Auto),
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
}
/// Test the prompt formatter with a multi-turn conversation that includes a continuation
#[tokio::test]
async fn test_multi_turn_with_continuation() {
if !check_hf_token() {
println!("HF_TOKEN is not set, skipping test");
return;
}
let mdc = make_mdc_from_repo(
"tests/data/sample-models",
"meta-llama/Llama-3.1-70B-Instruct",
"1605565",
Some(vec![PromptContextMixin::Llama3DateTime]),
)
.await;
let formatter = PromptFormatter::from_mdc(mdc.clone()).await.unwrap();
// assert it's an OAI formatter
let formatter = match formatter {
PromptFormatter::OAI(formatter) => Ok(formatter),
}
.unwrap();
let request = Request::from(
MULTI_TURN_WITH_CONTINUATION,
None,
None,
mdc.slug().to_string(),
);
let formatted_prompt = formatter.render(&request).unwrap();
insta::with_settings!({
info => &request,
snapshot_suffix => mdc.slug().to_string(),
filters => vec![
(r"Today Date: .*", "Today Date: <redacted>"),
]
}, {
insta::assert_snapshot!(formatted_prompt);
});
}
---
source: triton-llm/tests/preprocessor.rs
expression: formatted_prompt
info:
messages:
- role: system
content: You are a very helpful assistant!
- role: user
content: How do I reverse a string in Python?
- role: assistant
content: "You can reverse a "
model: meta_llama_llama_3_1_70b_instruct__1605565_e45e5991
---
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
Today Date: <redacted>
You are a very helpful assistant!<|eot_id|><|start_header_id|>user<|end_header_id|>
How do I reverse a string in Python?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
You can reverse a<|eot_id|>