Unverified Commit 8c40bbb0 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

fix: Add detokenize stream (#2413)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent c0ed76da
......@@ -45,21 +45,50 @@ pub fn decode(c: &mut Criterion) {
29889, 8413, 266, 5062, 364, 25443, 29889, 2296, 3708, 1127, 4964, 368, 29892, 11223, 9109,
472, 3271, 29889,
];
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, false);
let mut decoder = Decoder::new(ds, StopConditions::default());
let mut group = c.benchmark_group("decode-group");
group.throughput(Throughput::Bytes(TEST_TOKS.len() as u64));
group.throughput(Throughput::Elements(TEST_TOKS.len() as u64));
group.bench_function("tokenizer_decoder", |b| {
b.iter(|| {
b.iter_with_setup(
|| {
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default())
},
|mut decoder| {
for tok in black_box(TEST_TOKS) {
let _ = decoder.step(tok).unwrap();
}
})
},
)
});
group.finish();
}
pub fn decode_big(c: &mut Criterion) {
const NUM_TOKENS: usize = 2048;
const BIG_TEST_TOKS: [TokenIdType; NUM_TOKENS] = [450; NUM_TOKENS];
let mut group = c.benchmark_group("decode-big-group");
group.throughput(Throughput::Elements(NUM_TOKENS as u64));
group.bench_function("tokenizer_decoder_big", |b| {
b.iter_with_setup(
|| {
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default())
},
|mut decoder| {
for tok in black_box(&BIG_TEST_TOKS) {
let _ = decoder.step(*tok).unwrap();
}
},
)
});
group.finish();
}
criterion_group!(benches, encode, decode);
criterion_group!(benches, encode, decode, decode_big);
criterion_main!(benches);
......@@ -95,12 +95,16 @@ impl Backend {
fn decoder(
&self,
stream: ManyOut<ExecutionOutputStream>,
prompt_token_ids: &[TokenIdType],
stop_conditions: StopConditions,
) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
};
let decoder = Decoder::new(tokenizer.decode_stream(false), stop_conditions);
let decoder = Decoder::new(
tokenizer.decode_stream(prompt_token_ids, false),
stop_conditions,
);
Ok(DecoderUnfoldState {
stream,
......@@ -125,10 +129,13 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let stop_conditions = request.stop_conditions.clone();
let prompt_token_ids = request.token_ids.clone();
let next_stream = next.generate(request).await?;
let context = next_stream.context();
let state = self.decoder(next_stream, stop_conditions)?;
let state = self.decoder(next_stream, &prompt_token_ids, stop_conditions)?;
let processed_stream = stream::unfold(state, |mut state| async move {
match state.stream.next().await {
......
......@@ -105,8 +105,12 @@ impl Tokenizer {
}
/// Create a stateful sequence object for decoding token_ids into text
pub fn decode_stream(&self, skip_special_tokens: bool) -> DecodeStream {
DecodeStream::new(self.0.clone(), skip_special_tokens)
pub fn decode_stream(
&self,
prompt_token_ids: &[TokenIdType],
skip_special_tokens: bool,
) -> DecodeStream {
DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
}
}
......@@ -167,6 +171,13 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
}
}
// With incremental detokenization, we need to consider the final context tokens when handling the initial decode tokens.
// This is the initial offset from the end of the context that we start decoding from.
// Both Huggingface TGI and vLLM use this same value.
// See: https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/mamba.py#L169
// and https://github.com/vllm-project/vllm/blob/da2705198fa19030a25d0bea437f7be6547d47d4/vllm/transformers_utils/detokenizer_utils.py#L51
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
......@@ -188,62 +199,63 @@ pub struct DecodeStream {
/// so decoding the whole ids produces a valid prefix.
/// Prefix is the previously produced string, kept around to trim off of
/// the next valid chunk
ids: Vec<u32>,
all_token_ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctlyk
prefix_index: usize,
prefix_offset: usize,
/// We need to keep 2 prefixes.
/// Prefix is the second one that was already emitted to discard the part
/// of the text of all the ids
/// read is the prefix kept only for starting side effects of the prefix
read_index: usize,
read_offset: usize,
}
impl DecodeStream {
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>, skip_special_tokens: bool) -> Self {
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
prompt_token_ids: &[TokenIdType],
skip_special_tokens: bool,
) -> Self {
let num_input_tokens = prompt_token_ids.len();
let prompt_token_ids = prompt_token_ids.to_vec();
Self {
tokenizer,
skip_special_tokens,
ids: Vec::with_capacity(64),
prefix: String::with_capacity(64),
prefix_index: 0,
read_index: 0,
all_token_ids: prompt_token_ids,
prefix_offset: num_input_tokens
.saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
read_offset: num_input_tokens,
}
}
/// Step appends a token_id to the internal state and tries to produce a text chunk.
///
/// The method only fails if the internal state is corrupted.
/// Implementation directly copied from Huggingface's TGI:
/// https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/model.py#L144
///
/// Returning `None` means the given id is not enough to produce a chunk.
/// This typically happens with `byte_fallback` options where some tokens do not
/// represent valid UTF-8, and only follow-up token_ids will help produce
/// a valid chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.ids.push(id);
let decoded = self.tokenizer.decode(&self.ids, self.skip_special_tokens)?;
self.all_token_ids.push(id);
if decoded.len() <= self.prefix.len() || decoded.ends_with('�') {
return Ok(None);
}
if !decoded.starts_with(&self.prefix) {
anyhow::bail!("Detokenizer failure: invalid prefix");
}
let new_text = decoded[self.prefix.len()..].to_string();
let prefix_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?;
let new_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..],
self.skip_special_tokens,
)?;
self.prefix = decoded;
self.read_index = self.prefix_index;
if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
let new_text = new_text[prefix_text.len()..].to_string();
let new_prefix_index = self.ids.len() - self.prefix_index;
self.prefix_index = new_prefix_index;
self.prefix_offset = self.read_offset;
self.read_offset = self.all_token_ids.len();
Ok(Some(new_text))
} else {
Ok(None)
}
}
}
......
......@@ -12,11 +12,19 @@ async fn test_sequence_factory() {
let operator = Backend::from_mdc(mdc).await.unwrap();
let mut decode_stream = operator.tokenizer.as_ref().unwrap().decode_stream(false);
let mut decode_stream = operator
.tokenizer
.as_ref()
.unwrap()
.decode_stream(&[], false);
let output = decode_stream.step(1).unwrap();
assert_eq!(output, Some("<s>".to_string()));
let mut decode_stream = operator.tokenizer.as_ref().unwrap().decode_stream(true);
let mut decode_stream = operator
.tokenizer
.as_ref()
.unwrap()
.decode_stream(&[], true);
let output = decode_stream.step(1).unwrap();
assert_eq!(output, None);
}
......@@ -37,6 +37,22 @@ const TEST_PROMPTS: [&str; 4] = [
"another prompt",
];
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
("Tell me about the following text.", "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."),
("Tell me about the following text.", "Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt."),
("Tell me about the following text.", "Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem."),
// Note(jthomson04): Ishan asked me to add this one.
("Tell me about the following text.", "In the ancient realm of Tennisia, the very magic of the land is drawn from the sport itself. Forehands light the skies, backhands carve the earth, and serves rumble like thunder across kingdoms. At the center of this balance lie four sacred Grand Slam relics: the Sapphire Trophy of Melbourne, the Emerald Chalice of Paris, the Ruby Crown of London, and the Diamond Orb of New York. Together, they keep the game's spirit alive.
But the relics are scattered, guarded by champions of legendary skill. The first is the Fire King of Clay, ruler of the crimson courts, whose topspin arcs blaze high and heavy, scorching all who dare stand across from him. The second is the Tempest Trickster, master of the baseline fortress, whose footwork and precision can turn back any storm, and whose returns arrive as if pulled by invisible strings. The third is the Shadow-Dancer of the Highlands, a tactician who thrives in the long rallies of twilight, changing pace and spin until opponents lose their rhythm. The fourth and final guardian is a towering Diamond Titan, a net-charging colossus whose volleys shatter the air itself.
Into this arena of gods steps the Silver-Wristed Knight — a player of impossible grace, whose game is an art form. His quest: to claim each relic not for glory, but to restore harmony to the rankings of the realm.
He travels across the Kingdom of Clay, where the points stretch like marathons and the air tastes of iron; through the Grasslands of London, where the ball skids low and the margins are razor-thin; over the Hard Courts of the East, where rallies turn into duels of endurance; and finally to the Cathedral of Lights in New York, where night matches burn with fevered energy.
Each battle is played under enchanted floodlights, the lines patrolled by spectral line judges whose calls are final. The crowd's roar swells with every break point, and the Silver-Wristed Knight's racket glows brightest when the match teeters at deuce. There are moments when doubt grips him — when his serve falters or his touch deserts him — but each challenge teaches a new stroke, culminating in the legendary Forehand of Dawn.
When the last relic is claimed, he stands not as a conqueror but as a custodian of the game, knowing that rivalries forge the very magic he protects. The balance is restored — until the next season begins."),
// Emoji stress test
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
];
const TINYLLAMA_TOKENIZER_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1/tokenizer.json";
const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
......@@ -133,7 +149,7 @@ fn test_sequence() {
assert_eq!(decoder.token_ids(), sequence.token_ids());
assert_eq!(output, TEST_PROMPTS[0]);
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), false);
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), &[], false);
let mut output = String::new();
for token_id in encoding.token_ids() {
let text = decoder.step(*token_id).expect("Failed to decode token_id");
......@@ -143,3 +159,34 @@ fn test_sequence() {
}
assert_eq!(output, TEST_PROMPTS[0]);
}
#[test]
fn test_long_sequence_incremental_decode_with_prefill() {
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
.expect("Failed to load remote HuggingFace tokenizer");
let shared_tokenizer = Arc::new(tokenizer);
for (input_text, output_text) in LONG_TEST_PROMPTS.iter() {
let input_encoding = shared_tokenizer
.encode(input_text)
.expect("Failed to encode prompt");
let output_encoding = shared_tokenizer
.encode(output_text)
.expect("Failed to encode prompt");
let mut decoder =
DecodeStream::new(shared_tokenizer.clone(), input_encoding.token_ids(), false);
let mut output = String::new();
for token_id in output_encoding.token_ids() {
let text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(text) = text {
output.push_str(text.as_str());
}
}
assert_eq!(output.trim(), output_text.to_string());
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment