// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use crate::tokenizers::{ traits::{Decoder, Encoder, Tokenizer}, Encoding, Error, Result, TokenIdType, }; use sentencepiece::SentencePieceProcessor; /// A tokenizer implementation using the SentencePiece tokenization algorithm. /// This tokenizer can encode text into tokens and decode tokens back into text. pub struct SentencePieceTokenizer { /// The underlying SentencePiece processor instance spp: SentencePieceProcessor, } impl SentencePieceTokenizer { /// Creates a new SentencePieceTokenizer from a model file. /// /// # Arguments /// * `tokenizer_name` - Path to the SentencePiece model file /// /// # Returns /// * `Result` - A new tokenizer instance or an error if loading fails pub fn from_file(tokenizer_name: &str) -> Result { let spp = SentencePieceProcessor::open(tokenizer_name) .map_err(|err| Error::msg(format!("Error loading tokenizer: {}", err)))?; Ok(Self { spp }) } } impl Encoder for SentencePieceTokenizer { /// Encodes a string input into tokens using the SentencePiece model. /// /// # Arguments /// * `input` - The text to encode /// /// # Returns /// * `Result` - The encoded tokens, including IDs, text, and character spans fn encode(&self, input: &str) -> Result { let encoding = self .spp .encode(input) .map_err(|err| Error::msg(format!("Error encoding input: {}", err)))?; let mut token_ids = Vec::new(); let mut tokens = Vec::new(); let mut spans = Vec::new(); for piece in encoding { token_ids.push(piece.id); tokens.push(piece.piece); spans.push((piece.span.0 as usize, piece.span.1 as usize)); } Ok(Encoding { token_ids, tokens, spans, }) } } impl Decoder for SentencePieceTokenizer { /// Decodes a sequence of token IDs back into text. /// /// # Arguments /// * `token_ids` - The sequence of token IDs to decode /// * `skip_special_tokens` - Currently unsupported in SentencePieceTokenizer and /// it will return an error if true /// /// # Returns /// * `Result` - The decoded text /// /// # Errors /// * Returns an error if skip_special_tokens is true /// * Returns an error if the decoding process fails fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result { if skip_special_tokens { return Err(Error::msg( "SentencePieceTokenizer does not support skip_special_tokens=true.", )); } let text = self .spp .decode_piece_ids(token_ids) .map_err(|err| Error::msg(format!("Error decoding input: {}", err)))?; Ok(text) } } /// Implement the Tokenizer trait for SentencePieceTokenizer impl Tokenizer for SentencePieceTokenizer {}