# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional

from sentencepiece import SentencePieceProcessor


class Tokenizer:
    """Thin wrapper around a SentencePiece model.

    Loads the model from a file and exposes `encode` / `decode` helpers,
    together with the vocabulary size and the special-token IDs reported
    by the model.
    """

    def __init__(self, model_path: Optional[str]):
        """Load the SentencePiece model stored at `model_path`.

        Args:
            model_path: Path to the SentencePiece model file. It must point
                to an existing regular file.

        Raises:
            AssertionError: If `model_path` is not an existing file, or if
                the model's vocabulary size and piece count disagree.
        """
        # NOTE(review): the annotation allows None, but None cannot pass the
        # file check below -- callers must always supply a real path.
        # Reload tokenizer.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # Vocabulary size and BOS / EOS / PAD token IDs, as reported by the
        # loaded model.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        # Sanity check: both size accessors of the model must agree.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
        """Converts a string into a list of tokens.

        Args:
            s: Text to tokenize.
            bos: If true, prepend the BOS token ID to the result.
            eos: If true, append the EOS token ID to the result.

        Returns:
            The token IDs produced by the model for `s`, with the optional
            BOS / EOS IDs added.

        Raises:
            AssertionError: If `s` is not a `str`.
        """
        assert isinstance(s, str), f"expected str, got {type(s).__name__}"
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Converts a list of tokens into a string.

        Args:
            t: Token IDs to decode.

        Returns:
            The text produced by the underlying model's `decode` for `t`.
        """
        return self.sp_model.decode(t)