lib.rs 24 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Biswa Panda's avatar
Biswa Panda committed
2
3
// SPDX-License-Identifier: Apache-2.0

4
pub mod fastokens;
Biswa Panda's avatar
Biswa Panda committed
5
pub mod hf;
Nikita's avatar
Nikita committed
6
pub mod tiktoken;
Biswa Panda's avatar
Biswa Panda committed
7
8
9
10
11
12
13

// TODO: Add tokenizer benchmarks
// TODO: Enable README.md as a module doc
// #[doc = include_str!("../README.md")]

use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
14
use std::{fs::File, io::BufReader, ops::Deref, path::Path};
Biswa Panda's avatar
Biswa Panda committed
15

16
use anyhow::Context as _;
Biswa Panda's avatar
Biswa Panda committed
17
18
pub use anyhow::{Error, Result};

19
pub use fastokens::FastTokenizer;
Biswa Panda's avatar
Biswa Panda committed
20
pub use hf::HuggingFaceTokenizer;
Nikita's avatar
Nikita committed
21
pub use tiktoken::TikTokenTokenizer;
22
pub use traits::DecodeResult;
Biswa Panda's avatar
Biswa Panda committed
23

24
25
/// Integer type used for token ids throughout this module.
pub type TokenIdType = u32;

Biswa Panda's avatar
Biswa Panda committed
26
27
28
29
/// Represents the type of tokenizer being used.
///
/// NOTE(review): each variant carries a `String`; presumably a path or
/// identifier for the tokenizer resource — confirm against callers.
#[derive(Debug)]
pub enum TokenizerType {
    /// A HuggingFace tokenizer.
    HuggingFace(String),
    /// A tiktoken tokenizer.
    TikToken(String),
}

/// Character offsets in the original text, stored as a pair of `usize` values.
///
/// NOTE(review): presumably `(start, end)` — confirm against the code that produces it.
pub type Offsets = (usize, usize);

/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Sentence Piece
    Sp(Vec<TokenIdType>),
}

impl Encoding {
    pub fn token_ids(&self) -> &[u32] {
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
        }
    }
}

impl Hash for Encoding {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.token_ids().hash(state);
    }
Biswa Panda's avatar
Biswa Panda committed
58
59
60
61
62
63
64
}

pub mod traits {
    use super::*;

    pub trait Encoder: Send + Sync {
        fn encode(&self, input: &str) -> Result<Encoding>;
65
        fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
Biswa Panda's avatar
Biswa Panda committed
66
67
    }

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
    /// Result of decoding token IDs to text.
    ///
    /// Distinguishes between fully valid UTF-8 output and output that contains
    /// trailing incomplete multi-byte sequences (represented as U+FFFD).
    /// This lets callers like `DecodeStream::step()` decide whether to emit or
    /// buffer without resorting to hardcoded replacement-character string checks.
    #[derive(Debug, Clone, PartialEq, Eq, strum::EnumIs)]
    pub enum DecodeResult {
        /// No trailing incomplete multi-byte sequences (text does not end with U+FFFD).
        /// Note: the string may still contain *interior* U+FFFD characters from
        /// mid-stream invalid byte sequences; only trailing status is tracked here.
        Complete(String),
        /// The decoded string ends with U+FFFD, indicating incomplete trailing
        /// multi-byte bytes that may be completed by subsequent tokens.
        Partial(String),
    }

    impl DecodeResult {
        /// Returns a reference to the inner string.
        pub fn as_str(&self) -> &str {
            match self {
                DecodeResult::Complete(s) | DecodeResult::Partial(s) => s,
            }
        }

        /// Construct from a decoded string: `Partial` if it ends with U+FFFD, else `Complete`.
        pub fn from_decoded(text: String) -> Self {
            if text.ends_with('\u{FFFD}') {
                DecodeResult::Partial(text)
            } else {
                DecodeResult::Complete(text)
            }
        }
    }

    impl From<String> for DecodeResult {
        fn from(text: String) -> Self {
            DecodeResult::from_decoded(text)
        }
    }

    impl From<DecodeResult> for String {
        fn from(result: DecodeResult) -> Self {
            match result {
                DecodeResult::Complete(s) | DecodeResult::Partial(s) => s,
            }
        }
    }

    /// Implementations must ensure that partial multi-byte sequences produce U+FFFD
    /// (`\u{FFFD}`) in the output rather than returning `Err`. This is commonly achieved
    /// via `String::from_utf8_lossy` (tiktoken) or library-internal byte-fallback handling
    /// (HuggingFace). `DecodeStream::step()` relies on `DecodeResult::Partial` to detect
    /// incomplete sequences and buffer tokens until the full character arrives.
Biswa Panda's avatar
Biswa Panda committed
122
    pub trait Decoder: Send + Sync {
123
124
125
126
127
        fn decode(
            &self,
            token_ids: &[TokenIdType],
            skip_special_tokens: bool,
        ) -> Result<DecodeResult>;
Biswa Panda's avatar
Biswa Panda committed
128
129
130
131
132
133
134
135
    }

    pub trait Tokenizer: Encoder + Decoder {
        // fn get_vocab_size(&self) -> usize;
        // fn make_unique_clone(&self) -> Box<dyn Tokenizer>;
    }
}

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
/// Reads a JSON file and deserializes a single top-level field into `T`.
///
/// # Errors
///
/// Returns an error (with the file path in its context) when:
/// - the file cannot be opened,
/// - the content is not valid JSON,
/// - the JSON root is not an object,
/// - `field_name` is not present in the root object, or
/// - the field value cannot be deserialized into `T`.
pub fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    let file = File::open(json_file_path)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;
    let reader = BufReader::new(file);

    let json_data: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    let map = json_data.as_object().ok_or_else(|| {
        anyhow::anyhow!("JSON root is not an object in file: {:?}", json_file_path)
    })?;

    let field_value = map.get(field_name).ok_or_else(|| {
        anyhow::anyhow!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        )
    })?;

    // `&serde_json::Value` is itself a `Deserializer`, so the value can be
    // deserialized by reference instead of deep-cloning it first. The borrow
    // also stays available for the error message below.
    <T as serde::Deserialize>::deserialize(field_value).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, field_value, json_file_path
        )
    })
}

/// Logs a JSON parse error together with a short excerpt of the offending input.
///
/// Only syntax and data errors are logged; other error categories return
/// without logging. The excerpt covers up to two lines before and two lines
/// after the failing line. The failing line is prefixed with `>>` and followed
/// by an underscore ruler ending in `^` at the reported column.
pub fn log_json_err(filename: &str, json: &str, err: &serde_json::Error) {
    const ERROR_PREFIX: &str = ">>     ";

    let has_position = err.is_syntax() || err.is_data();
    if !has_position {
        return;
    }

    // The error reports 1-based positions; convert to 0-based indices.
    let err_row = err.line().saturating_sub(1);
    let err_col = err.column().saturating_sub(1);

    let source_lines: Vec<&str> = json.lines().collect();
    if source_lines.is_empty() {
        tracing::error!("JSON parsing error in {filename}: File is empty.");
        return;
    }

    let window_start = err_row.saturating_sub(2);
    let window_end = err_row.saturating_add(3).min(source_lines.len());
    let marker = format!("{}^", "_".repeat(err_col + ERROR_PREFIX.len()));

    let mut snippet: Vec<String> = Vec::new();
    let window = source_lines
        .iter()
        .enumerate()
        .skip(window_start)
        .take(window_end.saturating_sub(window_start));
    for (row, text) in window {
        if row == err_row {
            snippet.push(format!("{ERROR_PREFIX}{}", text));
            // The column marker goes directly beneath the failing line.
            snippet.push(marker.clone());
        } else {
            snippet.push(format!("{:06} {}", row + 1, text));
        }
    }

    tracing::error!(
        "JSON parsing error in {filename}: Line {}, column {}:\n{}",
        err.line(),
        err.column(),
        snippet.join("\n")
    );
}

Biswa Panda's avatar
Biswa Panda committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
impl Encoding {
    /// Computes a 64-bit hash of this encoding by feeding its `Hash`
    /// implementation into a fresh `DefaultHasher`.
    pub fn get_hash(&self) -> u64 {
        let mut state = DefaultHasher::new();
        Hash::hash(self, &mut state);
        state.finish()
    }
}

/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations.
///
/// Holds the implementation behind an `Arc<dyn traits::Tokenizer>`, so cloning
/// a `Tokenizer` clones the `Arc` rather than the underlying tokenizer.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
        Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
    }

    /// Create a stateful sequence object for decoding token_ids into text
228
229
230
231
232
233
    pub fn decode_stream(
        &self,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> DecodeStream {
        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
Biswa Panda's avatar
Biswa Panda committed
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
    }
}

/// Exposes the wrapped `Arc<dyn traits::Tokenizer>` so trait methods can be
/// called directly on a `Tokenizer`.
impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

/// Wraps an already type-erased tokenizer.
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Self(tokenizer)
    }
}

/// Wraps a concrete tokenizer implementation, erasing its type.
impl<T> From<Arc<T>> for Tokenizer
where
    // `'static` is needed because the `Arc<T>` is coerced to
    // `Arc<dyn traits::Tokenizer>`, whose trait object carries an implicit
    // `'static` lifetime bound.
    T: traits::Tokenizer + 'static,
{
    fn from(tokenizer: Arc<T>) -> Self {
        Self(tokenizer)
    }
}

/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
Nikita's avatar
Nikita committed
264
265
/// - model, tiktoken: tiktoken BPE tokenizer (requires `config.json` with a supported
///   `model_type` in the same directory; currently: kimi, kimi_k2, kimi_k25)
Biswa Panda's avatar
Biswa Panda committed
266
267
268
269
270
271
272
273
274
275
276
277
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let path = Path::new(file_path);
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .ok_or_else(|| Error::msg("Failed to read file extension".to_string()))?;

    match extension {
        "json" => {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            Ok(Arc::new(tokenizer))
        }
Nikita's avatar
Nikita committed
278
279
280
281
282
283
284
        "model" | "tiktoken" => {
            let tokenizer = TikTokenTokenizer::from_file_auto(file_path)?;
            Ok(Arc::new(tokenizer))
        }
        _ => Err(Error::msg(format!(
            "Unsupported tokenizer file type: .{extension}"
        ))),
Biswa Panda's avatar
Biswa Panda committed
285
286
287
    }
}

288
289
290
291
292
293
294
// With incremental detokenization, we need to consider the final context tokens when handling the initial decode tokens.
// This is the initial offset, counted in tokens from the end of the context, that we start decoding from.
// Both Hugging Face TGI and vLLM use this same value.
// See: https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/mamba.py#L169
// and https://github.com/vllm-project/vllm/blob/da2705198fa19030a25d0bea437f7be6547d47d4/vllm/transformers_utils/detokenizer_utils.py#L51
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;

Biswa Panda's avatar
Biswa Panda committed
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids.
///
/// This is necessary because decoding in general cannot achieve that since strings
/// depend on surrounding ids to provide a valid string. Typically stripping extra spaces.
pub struct DecodeStream {
    /// The tokenizer used to decode token_ids
    tokenizer: Arc<dyn traits::Tokenizer>,

    skip_special_tokens: bool,
    /// A temporary buffer of the necessary token_ids needed
    /// to produce valid string chunks.
    /// This typically contains 3 parts:
    ///  - read
    ///  - prefix
    ///  - rest
    ///
    /// Read is the bit necessary to surround the prefix
    /// so decoding the whole ids produces a valid prefix.
    /// Prefix is the previously produced string, kept around to trim off of
    /// the next valid chunk
316
    all_token_ids: Vec<u32>,
Biswa Panda's avatar
Biswa Panda committed
317

318
    prefix_offset: usize,
Biswa Panda's avatar
Biswa Panda committed
319

320
    read_offset: usize,
Biswa Panda's avatar
Biswa Panda committed
321
322
323
}

impl DecodeStream {
324
325
326
327
328
329
330
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        prompt_token_ids: &[TokenIdType],
        skip_special_tokens: bool,
    ) -> Self {
        let num_input_tokens = prompt_token_ids.len();
        let prompt_token_ids = prompt_token_ids.to_vec();
Biswa Panda's avatar
Biswa Panda committed
331
332
333
        Self {
            tokenizer,
            skip_special_tokens,
334
335
336
337
            all_token_ids: prompt_token_ids,
            prefix_offset: num_input_tokens
                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
            read_offset: num_input_tokens,
Biswa Panda's avatar
Biswa Panda committed
338
339
340
341
342
        }
    }

    /// Step appends a token_id to the internal state and tries to produce a text chunk.
    ///
343
344
    /// Implementation directly copied from Huggingface's TGI:
    /// https://github.com/huggingface/text-generation-inference/blob/24c2bff65924801ddf90fa24fcc72752d4f45538/server/text_generation_server/models/model.py#L144
Biswa Panda's avatar
Biswa Panda committed
345
346
347
348
349
350
    ///
    /// Returning `None` means the given id is not enough to produce a chunk.
    /// This typically happens with `byte_fallback` options where some tokens do not
    /// represent valid UTF-8, and only follow-up token_ids will help produce
    /// a valid chunk.
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
351
        self.all_token_ids.push(id);
Biswa Panda's avatar
Biswa Panda committed
352

353
354
355
356
357
358
359
        let prefix_text: String = self
            .tokenizer
            .decode(
                &self.all_token_ids[self.prefix_offset..self.read_offset],
                self.skip_special_tokens,
            )?
            .into();
360

361
        let new_result = self.tokenizer.decode(
362
363
364
            &self.all_token_ids[self.prefix_offset..],
            self.skip_special_tokens,
        )?;
365

366
367
368
        let new_text = new_result.as_str();
        if new_text.len() > prefix_text.len() && !new_result.is_partial() {
            let emitted = new_text[prefix_text.len()..].to_string();
369

370
371
372
            self.prefix_offset = self.read_offset;
            self.read_offset = self.all_token_ids.len();

373
            Ok(Some(emitted))
374
375
376
        } else {
            Ok(None)
        }
Biswa Panda's avatar
Biswa Panda committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
    }
}

/// Maintains state for an ongoing sequence of tokens and their decoded text
pub struct Sequence {
    /// Tokenizer used to encode appended text and to decode token ids
    tokenizer: Tokenizer,

    /// The current sequence of token ids
    token_ids: Vec<TokenIdType>,

    /// Start of the token window decoded by `append_token_id`; advanced to the
    /// previous `read_offset` each time new text is returned
    prefix_offset: usize,

    /// End of the tokens whose text has already been returned; set to the
    /// sequence length each time new text is returned
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
402
403
404
                    let token_ids = self.token_ids();
                    if token_ids.len() <= 20 {
                        format!("{:?}", token_ids)
Biswa Panda's avatar
Biswa Panda committed
405
                    } else {
406
407
                        let first_ten = &token_ids[..10];
                        let last_ten = &token_ids[token_ids.len() - 10..];
Biswa Panda's avatar
Biswa Panda committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    /// Creates an empty sequence that uses `tokenizer` for encoding and decoding.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    /// Returns `true` when the sequence holds no token ids.
    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    /// Returns the number of token ids in the sequence.
    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    /// Removes all token ids and resets both decode offsets to zero.
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Encodes `input` and appends the resulting token ids to the sequence.
    ///
    /// The decode offsets are not changed.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tokenizer's `encode`.
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids());
        Ok(())
    }

    /// Appends `token_id` and returns the text it newly contributes.
    ///
    /// Returns an empty string when the token does not yet yield new text:
    /// either the decoded text did not grow, or it ends in an incomplete
    /// multi-byte sequence that later tokens may complete. Any U+FFFD left
    /// in the returned text is stripped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tokenizer's `decode`.
    // Based on
    // https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
    // under Apache 2.0 license
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<String> {
        self.token_ids.push(token_id);

        let prefix_text: String = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?
            .into();

        let new_result = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;
        let new_text = new_result.as_str();

        // Nothing new to emit, or the tail is still an incomplete character.
        if new_text.len() <= prefix_text.len() || new_result.is_partial() {
            return Ok(String::new());
        }

        // if the end character of the previous returned sequence is a multi-byte character
        // then we can not split the text on that byte offset, so we roll back to the byte offset
        // of the start of that character (index 0 is always a boundary, so this terminates)
        let mut split_at = prefix_text.len();
        while !new_text.is_char_boundary(split_at) {
            split_at -= 1;
        }

        let emitted = new_text[split_at..].replace('\u{FFFD}', "");

        // shift and update the state
        self.prefix_offset = self.read_offset;
        self.read_offset = self.token_ids.len();
        Ok(emitted)
    }

    /// Returns a clone of the tokenizer used by this sequence.
    pub fn tokenizer(&self) -> Tokenizer {
        self.tokenizer.clone()
    }

    /// Returns the token ids accumulated so far.
    pub fn token_ids(&self) -> &[TokenIdType] {
        &self.token_ids
    }

    /// Decodes the whole sequence into text (special tokens are not skipped).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the tokenizer's `decode`.
    pub fn text(&self) -> Result<String> {
        Ok(self.tokenizer.decode(&self.token_ids, false)?.into())
    }
}

/// The output conditions/values of a `StopSequenceDecoder::append_token_id` operation.
/// Result of decoding a token, indicating whether text was produced or a stop condition was met
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SequenceDecoderOutput {
    /// The text for the appended token_id
    Text(String),

    /// The text so far partially matches a stop sequence, so it is held
    /// until either a match or a divergence
    Held,

    /// Indicates that a stop sequence has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error
    Stopped,

    /// Indicates that a stop token_id has been matched and the decoder is stopped.
    /// Subsequent calls to append_token_id will return an error
    /// The text for the stop token_id is returned
    StoppedWithText(String),
}

/// A Sequence for decoding a stream of token ids into text and detecting stop sequences.
/// A stop sequence is either a matching token_id or a sequence of texts/strings which match.
/// Matches happen first at the token-level, then at the sequence-level. Hidden takes precedence
/// over visible. For example, if you put the same token_id in both `stop_token_ids_visible` and
/// `stop_token_ids_hidden`, the token_id will be treated as hidden.
#[derive(Debug)]
pub struct StopSequenceDecoder {
    // The current sequence of token ids
    sequence: Sequence,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    stop_token_ids_visible: Vec<TokenIdType>,

    // Stop Tokens - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_token_ids_hidden: Vec<TokenIdType>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will be returned
    // NOTE: stored but not currently consulted by `append_token_id`, hence
    // the `dead_code` allowance.
    #[allow(dead_code)]
    stop_sequences_visible: Vec<String>,

    // Stop Words - the presence of any one of these should trigger a stop
    // If found, the text for the matched token will NOT be returned
    stop_sequences_hidden: Vec<String>,

    // If the decoder has observed and returned a stop SequenceDecoderOutput,
    // further calls to append_token_id will return an error
    stopped: bool,

    // text jail - if a partial stop sequence is being observed, we hold/jail the text
    // until either the stop sequence is matched or the sequence is reset by a divergence
    state: String,
}

impl StopSequenceDecoder {
    /// Builder object for configuring a StopSequenceDecoder
    pub fn builder(tokenizer: Tokenizer) -> StopSequenceDecoderBuilder {
        StopSequenceDecoderBuilder::new(tokenizer)
    }

    /// Add a token_id to the sequence and return the SequenceDecoderOutput
    ///
    /// Checks run in this order:
    /// 1. stop token ids (hidden takes precedence over visible),
    /// 2. an exact match of the held text against any hidden stop sequence,
    /// 3. the held text being a prefix of any hidden stop sequence (text is held),
    /// 4. otherwise the held text is released.
    ///
    /// # Errors
    ///
    /// Returns an error if the decoder is already stopped, or if decoding the
    /// token fails.
    pub fn append_token_id(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Err(Error::msg("Decoder is stopped"));
        }

        // update the sequence and append the new text to the jailed state
        let text = self.sequence.append_token_id(token_id)?;
        self.state.push_str(&text);

        // token-level stops; a token listed as both hidden and visible is hidden
        let hidden_stop = self.stop_token_ids_hidden.contains(&token_id);
        let visible_stop = self.stop_token_ids_visible.contains(&token_id);
        if hidden_stop || visible_stop {
            self.stopped = true;
            let state = std::mem::take(&mut self.state);
            return Ok(if hidden_stop {
                SequenceDecoderOutput::Stopped
            } else {
                SequenceDecoderOutput::StoppedWithText(state)
            });
        }

        // An exact match must be checked against ALL stop sequences before
        // deciding to hold. Otherwise a longer sequence listed earlier
        // (e.g. "abc" before "ab") would hide the exact match for "ab".
        // On a matched stop sequence, we do NOT return the jailed text.
        if self
            .stop_sequences_hidden
            .iter()
            .any(|stop_sequence| stop_sequence == &self.state)
        {
            self.stopped = true;
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // the state is a proper prefix of some stop sequence: keep holding it
        if self
            .stop_sequences_hidden
            .iter()
            .any(|stop_sequence| stop_sequence.starts_with(&self.state))
        {
            return Ok(SequenceDecoderOutput::Held);
        }

        // divergence: release everything that was held
        let state = std::mem::take(&mut self.state);
        Ok(SequenceDecoderOutput::Text(state))
    }

    /// Returns `true` when no token ids have been appended.
    pub fn is_empty(&self) -> bool {
        self.sequence.token_ids.is_empty()
    }

    /// Returns the number of token ids appended so far.
    pub fn len(&self) -> usize {
        self.sequence.token_ids.len()
    }

    /// Returns `true` once the decoder has stopped.
    pub fn is_complete(&self) -> bool {
        self.stopped
    }

    /// Marks the decoder as stopped; further `append_token_id` calls return an error.
    pub fn close(&mut self) {
        self.stopped = true;
    }
}

/// Builder that collects the tokenizer and stop conditions used to create a
/// `StopSequenceDecoder`.
pub struct StopSequenceDecoderBuilder {
    // Tokenizer handed to the decoder's internal sequence on `build`
    tokenizer: Tokenizer,
    // Stop token ids added through the `*_visible` token id methods
    stop_token_ids_visible: Vec<TokenIdType>,
    // Stop token ids added through the `*_hidden` token id methods
    stop_token_ids_hidden: Vec<TokenIdType>,
    // Stop strings added through the `*_visible` sequence methods
    stop_sequences_visible: Vec<String>,
    // Stop strings added through the `*_hidden` sequence methods
    stop_sequences_hidden: Vec<String>,
}

impl StopSequenceDecoderBuilder {
    /// Creates a builder with no stop conditions configured.
    pub fn new(tokenizer: Tokenizer) -> Self {
        Self {
            tokenizer,
            stop_token_ids_visible: Vec::new(),
            stop_token_ids_hidden: Vec::new(),
            stop_sequences_visible: Vec::new(),
            stop_sequences_hidden: Vec::new(),
        }
    }

    /// Adds a visible stop token id to the StopSequenceDecoder
    pub fn add_stop_token_id_visible(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_visible.push(token_id);
        self
    }

    /// Adds a list of visible stop token ids to the StopSequenceDecoder
    /// Each token_id is added as for an individual match
    pub fn add_stop_token_ids_visible(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_visible.extend(token_ids);
        self
    }

    /// Adds a hidden stop token id to the StopSequenceDecoder
    pub fn add_stop_token_id_hidden(mut self, token_id: TokenIdType) -> Self {
        self.stop_token_ids_hidden.push(token_id);
        self
    }

    /// Adds a list of hidden stop token ids to the StopSequenceDecoder
    /// Each token_id is added as for an individual match
    pub fn add_stop_token_ids_hidden(mut self, token_ids: &[TokenIdType]) -> Self {
        self.stop_token_ids_hidden.extend(token_ids);
        self
    }

    /// Adds a visible stop sequence (string) to the StopSequenceDecoder
    pub fn add_stop_sequence_visible(mut self, text: &str) -> Self {
        self.stop_sequences_visible.push(text.to_string());
        self
    }

    /// Adds a list of visible stop sequences to the StopSequenceDecoder
    pub fn add_stop_sequences_visible(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_visible
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    /// Adds a hidden stop sequence (string) to the StopSequenceDecoder
    pub fn add_stop_sequence_hidden(mut self, text: &str) -> Self {
        self.stop_sequences_hidden.push(text.to_string());
        self
    }

    /// Adds a list of hidden stop sequences to the StopSequenceDecoder
    pub fn add_stop_sequences_hidden(mut self, strings: &[&str]) -> Self {
        self.stop_sequences_hidden
            .extend(strings.iter().map(|text| text.to_string()));
        self
    }

    /// Consumes the builder and creates a fresh, not-yet-stopped decoder with
    /// an empty sequence and empty held text.
    pub fn build(self) -> Result<StopSequenceDecoder> {
        // The builder is consumed, so its fields can be moved out directly
        // without cloning the tokenizer.
        let Self {
            tokenizer,
            stop_token_ids_visible,
            stop_token_ids_hidden,
            stop_sequences_visible,
            stop_sequences_hidden,
        } = self;

        Ok(StopSequenceDecoder {
            sequence: Sequence::new(tokenizer),
            stop_token_ids_visible,
            stop_token_ids_hidden,
            stop_sequences_visible,
            stop_sequences_hidden,
            stopped: false,
            state: String::new(),
        })
    }
}