from typing import List, Optional, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
# NOTE(review): presumably this star import is what supplies
# `BaichuanTokenizer`, used by the fallback in `get_tokenizer` -- confirm
# before narrowing it to explicit names.
from vllm.transformers_utils.tokenizers import *
from vllm.utils import make_async

logger = init_logger(__name__)


def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    The values are captured once, when this function is called:

    * ``all_special_ids`` and ``all_special_tokens`` are served as ``set``
      objects built from the tokenizer's values.
    * ``all_special_tokens_extended`` is served as the object the tokenizer
      returned at patch time.
    * ``len(tokenizer)`` returns the length measured at patch time.

    Args:
        tokenizer: The tokenizer instance to patch.

    Returns:
        The same ``tokenizer`` object, whose class has been replaced by a
        dynamically created ``Cached<OriginalClassName>`` subclass.
    """
    # Snapshot the expensive properties once; the subclass below closes
    # over these locals instead of recomputing them on every access.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    class CachedTokenizer(tokenizer.__class__):

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        def __len__(self):
            return tokenizer_len

    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Swap the class on the existing instance so all of its state is kept.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    Args:
        tokenizer_name: Name or path handed to ``from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        tokenizer_mode: If ``"slow"``, ``use_fast=False`` is forced. Any
            other value leaves ``kwargs`` untouched.
        trust_remote_code: Forwarded to ``from_pretrained``.
        tokenizer_revision: Revision of the tokenizer to load. It is
            forwarded to ``from_pretrained`` as ``revision``. An explicit
            ``revision`` entry in ``kwargs`` takes precedence.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer, patched by ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast`` is
            truthy in ``kwargs``, or re-raised from ``from_pretrained``.
        RuntimeError: If loading failed with a message indicating a custom
            tokenizer and ``trust_remote_code`` is False.
        AttributeError: Re-raised from ``from_pretrained`` unless the
            message mentions ``BaichuanTokenizer``.
    """
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # `from_pretrained` selects the revision through its `revision` keyword.
    # `setdefault` keeps an explicit `revision` passed by the caller and
    # avoids a duplicate-keyword error.
    if tokenizer_revision is not None:
        kwargs.setdefault("revision", tokenizer_revision)

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code
        # flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise
    except AttributeError as e:
        if "BaichuanTokenizer" not in str(e):
            raise
        # This is for the error "'BaichuanTokenizer' object has no
        # attribute 'sp_model'".
        tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return get_cached_tokenizer(tokenizer)


def get_lora_tokenizer(lora_request: Optional[LoRARequest], *args,
                       **kwargs) -> Optional[PreTrainedTokenizer]:
    """Load the tokenizer stored at a LoRA request's local path, if any.

    Args:
        lora_request: The LoRA request, or None.
        *args: Forwarded to ``get_tokenizer``.
        **kwargs: Forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_local_path``, or None
        when ``lora_request`` is None or when loading raised ``OSError``
        (a warning is logged in the latter case).
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_local_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            f"No tokenizer found in {lora_request.lora_local_path}, "
            "using base model tokenizer instead. "
            f"(Exception: {e})")
        return None


get_lora_tokenizer_async = make_async(get_lora_tokenizer)


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Convert tokens to a string, emitting added-vocab tokens verbatim.

    Runs of ordinary tokens are converted with
    ``tokenizer.convert_tokens_to_string``; each token found in the
    tokenizer's added vocabulary is appended as-is as its own piece.

    Args:
        tokenizer: The tokenizer to use.
        output_tokens: The tokens to convert.
        skip_special_tokens: If True, tokens contained in
            ``tokenizer.all_special_tokens`` are dropped.
        spaces_between_special_tokens: If True the pieces are joined with a
            single space, otherwise with the empty string.

    Returns:
        The joined text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # Loop-invariant: look the added vocabulary up once rather than once
    # per token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of ordinary tokens before the added one.
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)


# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Converts the prompt ids to tokens and returns the tokens and offsets
    for incremental detokenization.

    Note that not all tokens are converted to strings. Only the tokens that
    are necessary for incremental detokenization are converted to strings.

    Args:
        tokenizer: The tokenizer to use.
        prompt_ids: The prompt token ids.
        skip_special_tokens: Forwarded to ``convert_ids_to_tokens``.

    Returns:
        A tuple ``(new_tokens, prefix_offset, read_offset)``. Both offsets
        are indices into ``new_tokens`` (the converted tail of the prompt),
        not into ``prompt_ids``.
    """
    # Offset a little more in case we have special tokens.
    prefix_offset = max(
        len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
    # We do not need to convert the whole prompt to tokens.
    new_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
    # Recompute the prefix offset relative to the converted tail.
    prefix_offset = max(
        len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    read_offset = len(new_tokens)
    return new_tokens, prefix_offset, read_offset


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return
    the new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode
    which decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset. Ignored (recomputed) when
            `prev_tokens` is None.
        read_offset: The read offset. Ignored (recomputed) when
            `prev_tokens` is None.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        A tuple ``(new_tokens, new_text, prefix_offset, read_offset)``.
        When new text is emitted, the returned prefix offset is the old
        read offset and the returned read offset is the total token count.
        Otherwise ``new_text`` is ``""`` and the offsets are the ones used
        in this call.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)

    # If the new token id is out of bounds, return an empty string.
    if new_token_id >= len(tokenizer):
        new_tokens = [""]
    else:
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # A trailing U+FFFD replacement char means it's a potential
        # unfinished byte sequence from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        new_text = new_text[len(prefix_text):]
        return new_tokens, new_text, read_offset, len(output_tokens)
    else:
        # Nothing to emit yet; keep the offsets so the next call retries
        # with one more token.
        return new_tokens, "", prefix_offset, read_offset