# SPDX-License-Identifier: Apache-2.0
# mypy: ignore-errors
"""SentencePiece-backed tokenizer exposing vLLM's ``TokenizerBase`` API."""

import glob
import json
import os
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import sentencepiece

from vllm.transformers_utils.tokenizer_base import TokenizerBase

if TYPE_CHECKING:
    from vllm.entrypoints.chat_utils import ConversationMessage


@dataclass
class Encoding:
    """Minimal stand-in for an HF ``BatchEncoding``: only ``input_ids``."""

    input_ids: List[int]


class SentencePieceTokenizer(TokenizerBase):
    """Tokenizer wrapping a ``sentencepiece.SentencePieceProcessor``.

    Every control piece of the model, plus the unknown piece, is treated as
    a "special token" by this class.
    """

    def __init__(self, model_file):
        """Load the SentencePiece model at ``model_file`` and index its pieces.

        Args:
            model_file: Path to a SentencePiece ``.model`` file.
        """
        self.name = "SentencePieceTokenizer"
        self.sp_model = sentencepiece.SentencePieceProcessor(
            model_file=model_file)

        # piece -> id for every piece; control pieces are additionally
        # recorded in the three "special" containers.
        self._special_tokens = {}
        self._all_special_tokens = []
        self._all_special_ids = []
        self._vocab = {}
        for idx in range(self.sp_model.get_piece_size()):
            piece = self.sp_model.id_to_piece(idx)
            self._vocab[piece] = idx
            if not self.sp_model.is_control(idx):
                continue
            self._special_tokens[piece] = idx
            self._all_special_tokens.append(piece)
            self._all_special_ids.append(idx)

        # The unknown piece is always registered as special as well.
        unk_id = self.sp_model.unk_id()
        unk_piece = self.sp_model.id_to_piece(unk_id)
        self._special_tokens[unk_piece] = unk_id
        self._all_special_tokens.append(unk_piece)
        self._all_special_ids.append(unk_id)

        # One alternation over all (non-empty) special pieces, longest first:
        # Python's regex alternation takes the first alternative that matches
        # at a position, so ordering by length yields leftmost-longest,
        # non-overlapping matches in ``_split_on_special_tokens``.
        special_pieces = sorted((piece for piece in self._special_tokens
                                 if piece),
                                key=len,
                                reverse=True)
        self._special_token_pattern = (re.compile("|".join(
            map(re.escape, special_pieces))) if special_pieces else None)

        # FIXME: compatible for decode
        self.length = self.sp_model.get_piece_size()

    @property
    def all_special_tokens_extended(self) -> List[str]:
        """Special pieces (same list as ``all_special_tokens``)."""
        return self._all_special_tokens

    @property
    def all_special_tokens(self) -> List[str]:
        """Control pieces followed by the unknown piece."""
        return self._all_special_tokens

    @property
    def all_special_ids(self) -> List[int]:
        """Ids of ``all_special_tokens``, in the same order."""
        return self._all_special_ids

    @property
    def eos_token_id(self):
        return self.sp_model.eos_id()

    @property
    def eos_token(self):
        return self.sp_model.id_to_piece(self.eos_token_id)

    @property
    def bos_token_id(self):
        return self.sp_model.bos_id()

    @property
    def unk_token_id(self):
        return self.sp_model.unk_id()

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def vocab_size(self):
        return self.length

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def max_token_id(self) -> int:
        return self.sp_model.get_piece_size() - 1

    def get_vocab(self):
        """Return the ``piece -> id`` mapping for every piece in the model."""
        return self._vocab

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Encode ``text`` via ``encode`` and optionally truncate.

        Special-token strings inside ``text`` are NOT parsed, but ``encode``
        defaults to ``add_bos=True``, so a BOS id is prepended.
        """
        input_ids = self.encode(text)
        if truncation:
            input_ids = input_ids[:max_length]
        return input_ids

    def encode(self,
               text: str,
               add_special_tokens: bool = False,
               add_bos: bool = True) -> List[int]:
        """Encode ``text`` into token ids.

        Args:
            text: Input string.
            add_special_tokens: If True, occurrences of special pieces inside
                ``text`` are mapped straight to their ids instead of being
                segmented as ordinary text. Overlapping candidates are
                resolved leftmost-longest.
            add_bos: Prepend the BOS id.

        Returns:
            List of token ids.
        """
        if not add_special_tokens:
            return self.sp_model.encode(text, add_bos=add_bos)
        return self.encode_chatml(self._split_on_special_tokens(text),
                                  add_bos=add_bos)

    def _split_on_special_tokens(
            self, text: str) -> List[Union[str, Dict[str, str]]]:
        """Split ``text`` into plain chunks and ``{"token": piece}`` markers."""
        if self._special_token_pattern is None:
            return [text] if text else []
        parts: List[Union[str, Dict[str, str]]] = []
        last_end = 0
        for match in self._special_token_pattern.finditer(text):
            start, end = match.span()
            if start > last_end:
                parts.append(text[last_end:start])
            parts.append({"token": match.group()})
            last_end = end
        if last_end < len(text):
            parts.append(text[last_end:])
        return parts

    def _decode_preserving_specials(self, items, specials, render_special):
        """Decode ``items`` with SentencePiece, emitting ``render_special(x)``
        verbatim for every item ``x`` contained in ``specials``."""
        result = []
        pending = []
        for item in items:
            if item in specials:
                # Flush the ordinary run before emitting the special item.
                if pending:
                    result.append(self.sp_model.decode(pending))
                    pending = []
                result.append(render_special(item))
            else:
                pending.append(item)
        if pending:
            result.append(self.sp_model.decode(pending))
        return "".join(result)

    def decode(self,
               token_ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode token ids back to text.

        Args:
            token_ids: A single id or a list of ids (non-int list elements
                are coerced with ``int``).
            skip_special_tokens: If True, ids are passed to SentencePiece
                unchanged and no explicit filtering happens here.
                NOTE(review): this relies on SentencePiece decoding control
                pieces to empty text; if the unknown piece is not a control
                piece, SentencePiece still renders it. If False, special ids
                are rendered as their piece strings.

        Returns:
            The decoded string.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        elif (isinstance(token_ids, list) and token_ids
              and not isinstance(token_ids[0], int)):
            token_ids = [int(token) for token in token_ids]

        if skip_special_tokens:
            return self.sp_model.decode(token_ids)
        return self._decode_preserving_specials(token_ids,
                                                set(self._all_special_ids),
                                                self.convert_id_to_token)

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """HF-style call: encode ``text`` with BOS and wrap in ``Encoding``.

        NOTE(review): ``text_pair`` and ``add_special_tokens`` are accepted
        for signature compatibility but are not used.
        """
        input_ids = self.encode(text, add_bos=True)
        if truncation:
            input_ids = input_ids[:max_length]
        return Encoding(input_ids=input_ids)

    def convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def convert_tokens_to_ids(self, tokens):
        return self.sp_model.piece_to_id(tokens)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.id_to_piece(index)

    def convert_ids_to_tokens(
            self, ids, **kwargs):  # kwargs for compatibility of HF tokenizer
        return self.sp_model.id_to_piece(ids)

    def convert_tokens_to_string(self, tokens, skip_special_tokens=True):
        """Join pieces into text.

        With ``skip_special_tokens=False``, special pieces are copied through
        verbatim and only the ordinary pieces between them are decoded.
        """
        # FIXME(ys): hack for tool call tokens
        if skip_special_tokens:
            return self.sp_model.decode(tokens)
        # Dict membership is O(1); its keys are exactly the special pieces.
        return self._decode_preserving_specials(tokens, self._special_tokens,
                                                lambda token: token)

    @classmethod
    def from_pretrained(cls, model_path):
        """Build a tokenizer from a ``.model`` file or a directory holding
        exactly one ``*.model`` file.

        Raises:
            ValueError: if the directory does not contain exactly one
                ``.model`` file.
        """
        if model_path.endswith(".model"):
            model_file = model_path
        else:
            # Escape the directory so glob metacharacters in it (e.g. "[")
            # are matched literally rather than as patterns.
            possible_files = glob.glob(
                os.path.join(glob.escape(model_path), "*.model"))
            if len(possible_files) != 1:
                raise ValueError(
                    f"Expected exactly one .model file for tokenizer initialization in {model_path}, but found {possible_files}"
                )
            model_file = possible_files[0]
        return cls(model_file=model_file)

    def encode_chatml(self, input, add_bos=True):
        """Encode a list of plain strings and ``{"token": piece}`` markers.

        A bare string is treated as a one-element list. Dicts without a
        ``"token"`` key are ignored.
        """
        input_ids = [self.bos_token_id] if add_bos else []
        if isinstance(input, str):
            input = [input]
        # Compatible with the StepChat ChatML Protocol.
        for subprompt in input:
            if isinstance(subprompt, str):
                input_ids += self.encode(subprompt, add_bos=False)
            elif isinstance(subprompt, dict) and "token" in subprompt:
                input_ids += [self.convert_token_to_id(subprompt["token"])]
        return input_ids

    def get_added_vocab(self):
        # NOTE(review): HF tokenizers return a dict here; callers that
        # iterate the result would fail on None — confirm none do.
        return None

    def __len__(self):
        return self.length

    def apply_chat_template(self,
                            conversation: List["ConversationMessage"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        """Convert chat messages to token IDs sequence.

        Each message is rendered as ``<|BOT|>`` + ``"{role}\\n{content}"``
        + any tool calls + ``<|EOT|>``. A ``user`` role is rendered as
        ``human``. A ``tool`` message whose ``tool_call_id`` matches an
        earlier call uses ``"{type}_output\\n{name}"`` as its role.

        Args:
            conversation: List of chat messages.
            tools: Tool configurations (optional). They are JSON-serialized
                into a ``tool_json_schemas`` message placed after a leading
                system message, or first if there is none.
            **kwargs: Only ``continue_final_message`` is read; other keys
                are ignored.

        Returns:
            List[int]: Sequence of token IDs. An assistant prompt is appended
            unless the last message is from the assistant or
            ``continue_final_message`` is set; otherwise the trailing
            ``<|EOT|>`` is removed so the final message can be continued.

        Raises:
            ValueError: on a malformed or unknown tool-call entry.
        """
        ret = [self.bos_token_id]
        continue_final_message = kwargs.get("continue_final_message", False)

        processed_conversation = self._insert_tool_schemas(conversation, tools)

        history_tool_calls_map: Dict[Any, Dict[str, Any]] = {}
        for message in processed_conversation:
            ret.append(self._special_tokens["<|BOT|>"])
            # The role is resolved before this message's own tool calls are
            # registered, so a message never refers to its own calls.
            role = self._message_role(message, history_tool_calls_map)
            ret.extend(
                self._encode_message_content(role,
                                             message.get("content") or ""))
            for tool_call in message.get("tool_calls") or []:
                ret.extend(
                    self._encode_tool_call(tool_call, history_tool_calls_map))
            ret.append(self._special_tokens["<|EOT|>"])

        # Guard against an empty conversation (previously an IndexError).
        last_role = (processed_conversation[-1]["role"]
                     if processed_conversation else None)
        if last_role != "assistant" and not continue_final_message:
            # Open a new assistant turn for generation.
            ret.append(self._special_tokens["<|BOT|>"])
            ret.extend(self.encode("assistant\n", add_bos=False))
        elif ret[-1] == self._special_tokens["<|EOT|>"]:
            # Leave the final message open so it can be continued.
            ret.pop()
        return ret

    @staticmethod
    def _insert_tool_schemas(conversation, tools):
        """Return a new message list with a ``tool_json_schemas`` message
        inserted after a leading system message (or first), if ``tools``."""
        if not tools:
            return list(conversation)
        tools_message = {
            "role": "tool_json_schemas",
            # tools should already be in the right format
            "content": json.dumps(tools, ensure_ascii=False),
        }
        if conversation and conversation[0]["role"] == "system":
            return [conversation[0], tools_message, *conversation[1:]]
        return [tools_message, *conversation]

    @staticmethod
    def _message_role(message, history_tool_calls_map):
        """Map a message's role to the role string rendered in the prompt."""
        if (message["role"] == "tool"
                and message.get("tool_call_id") in history_tool_calls_map):
            target_tool_call = history_tool_calls_map[message["tool_call_id"]]
            return (f"{target_tool_call['type']}_output\n"
                    f"{target_tool_call['name']}")
        return "human" if message["role"] == "user" else message["role"]

    def _encode_message_content(self, role, content) -> List[int]:
        """Encode ``"{role}\\n"`` followed by the message content.

        ``content`` may be a string or a list of typed items: ``text``,
        ``image``, ``audio`` and ``audio_token`` items are encoded. Items of
        any other type are skipped. Content of any other type yields nothing,
        not even the role header.
        """
        if isinstance(content, str):
            return self.encode(f"{role}\n{content}", add_bos=False)
        if not isinstance(content, list):
            return []

        ids: List[int] = []
        # Pending text (initially the role header) is merged into the next
        # text item, or flushed on its own before any non-text item.
        text = f"{role}\n"
        for item in content:
            item_type = item.get("type") if isinstance(item, dict) else None
            if item_type == "text":
                ids.extend(self.encode(text + item["text"], add_bos=False))
                text = ""
                continue
            if text:
                ids.extend(self.encode(text, add_bos=False))
                text = ""
            # NOTE(review): the special-token keys below are empty strings,
            # which looks like tag names were lost (e.g. stripped as markup);
            # a lookup of "" will raise KeyError unless "" is a special piece.
            # Confirm the intended tokens.
            if item_type == "image":
                ids.append(self._special_tokens[""])
            elif item_type == "audio":
                ids.append(self._special_tokens[""])
            elif item_type == "audio_token":
                ids.append(self._special_tokens[""])
                ids.extend(self.encode(item["audio_token"], add_bos=False))
                ids.append(self._special_tokens[""])
            # Other multimodal content types are skipped.
        # Only reachable for an empty list: keep the role header.
        if text:
            ids.extend(self.encode(text, add_bos=False))
        return ids

    def _encode_tool_call(self, tool_call,
                          history_tool_calls_map) -> List[int]:
        """Encode one tool call and record it for later tool responses.

        Raises:
            ValueError: if the call type is unknown or its payload is unset.
        """
        call_type = tool_call.get("type")
        if call_type == "function":
            function = tool_call.get("function")
            if function is None:
                raise ValueError(
                    f"Function is not set for tool call {tool_call.get('id')}: {tool_call}"
                )
            name = function.get("name")
            arguments = function.get("arguments")
            # Arguments may already be a JSON string; otherwise serialize.
            call_body = (arguments if isinstance(arguments, str) else
                         json.dumps(arguments, ensure_ascii=False))
        elif call_type == "code_interpreter":
            code_interpreter = tool_call.get("code_interpreter")
            if code_interpreter is None:
                raise ValueError(
                    f"Code interpreter is not set for tool call {tool_call.get('id')}: {tool_call}"
                )
            name = code_interpreter.get("language")
            call_body = code_interpreter.get("code")
        else:
            raise ValueError(
                f"Unknown tool call type {tool_call.get('type')}, must be either 'function' or 'code_interpreter': {tool_call}"
            )

        # Store tool call info for mapping responses
        history_tool_calls_map[tool_call.get("id")] = {
            "type": call_type,
            "name": name,
            "content": call_body,
        }
        return [
            self._special_tokens["<|CALL_START|>"],
            *self.encode(f"{call_type}\n{name}\n{call_body}", add_bos=False),
            self._special_tokens["<|CALL_END|>"],
        ]