Unverified commit 6fc4e6e0, authored by Patrick von Platen and committed by GitHub
Browse files

[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)

parent 9606c719
......@@ -11,4 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
......@@ -26,3 +26,4 @@ librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
......@@ -30,9 +30,11 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
......
......@@ -61,7 +61,8 @@ class ModelConfig:
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
......@@ -246,10 +247,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
if tokenizer_mode not in ["auto", "slow", "mistral"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
......
......@@ -198,10 +198,11 @@ class EngineArgs:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
......
......@@ -267,7 +267,7 @@ def apply_chat_template(
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
) -> Union[str, List[int]]:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
......@@ -280,6 +280,4 @@ def apply_chat_template(
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)
return prompt
......@@ -390,15 +390,21 @@ class LLM:
conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)
prompts = apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt)
inputs: PromptInputs
if isinstance(prompt, list) and isinstance(prompt[0], int):
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
inputs = TextPrompt(prompt=prompt)
return self.generate(
prompts,
sampling_params,
inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
......
......@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
PromptAdapterPath,
TextTokensPrompt)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
......@@ -130,6 +131,7 @@ class OpenAIServingChat(OpenAIServing):
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
......@@ -137,6 +139,14 @@ class OpenAIServingChat(OpenAIServing):
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
tokenizer,
......
......@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset
......
import os
import warnings
from pathlib import Path
from typing import Optional, Union
......@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
MistralTokenizer)
from vllm.utils import make_async
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
......@@ -99,24 +102,40 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
# if tokenizer is from official mistral org
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
'It is strongly recommended to run mistral models with '
'`--tokenizer_mode "mistral"` to ensure correct '
'encoding and decoding.',
FutureWarning,
stacklevel=2)
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
**kwargs,
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
# currently being imported,
# suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)):
err_msg = ("Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
......@@ -129,7 +148,8 @@ def get_tokenizer(
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
**kwargs,
)
else:
raise e
......@@ -137,7 +157,9 @@ def get_tokenizer(
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return get_cached_tokenizer(tokenizer)
tokenizer = get_cached_tokenizer(tokenizer)
return tokenizer
def get_lora_tokenizer(lora_request: LoRARequest, *args,
......
from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
__all__ = [
"BaichuanTokenizer",
]
__all__ = ["BaichuanTokenizer", "MistralTokenizer"]
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import HfApi, hf_hub_download
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
# yapf: enable
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ConversationMessage
@dataclass
class Encoding:
    """Minimal result object returned when the tokenizer is called directly.

    It carries the encoded prompt under the ``input_ids`` attribute.
    """

    # Token ids of the encoded prompt.
    input_ids: List[int]
def find_tokenizer_file(files: List[str]) -> str:
    """Return the single Mistral tokenizer file name found in ``files``.

    A file name qualifies when it matches ``tokenizer.model.v*`` or is
    exactly ``tekken.json``.

    Args:
        files: Candidate file names (e.g. a directory listing or the file
            list of a Hub repository).

    Returns:
        The one matching file name.

    Raises:
        OSError: If no file or more than one file matches.
    """
    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")

    matched_files = [file for file in files if file_pattern.match(file)]
    if len(matched_files) > 1:
        # Every fragment that interpolates a value must be an f-string;
        # otherwise the placeholder is emitted literally.
        raise OSError(
            f"Found {len(matched_files)} files matching the "
            f"pattern: {matched_files}. Make sure only one Mistral "
            "tokenizer is present in the repository or directory.")
    if not matched_files:
        raise OSError(
            "Found 0 files matching the "
            f"pattern: {file_pattern.pattern}. Make sure that a Mistral "
            "tokenizer is present in the repository or directory.")

    return matched_files[0]
class MistralTokenizer:
    """Wrapper around a `mistral_common` tokenizer.

    It exposes the attributes and methods (``encode``, ``decode``,
    ``apply_chat_template``, ``convert_ids_to_tokens``, ...) that the rest
    of the code base calls on a tokenizer object.
    """

    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
        """Wrap an already-constructed `mistral_common` tokenizer.

        Args:
            tokenizer: The `mistral_common` tokenizer to wrap. Its
                ``instruct_tokenizer.tokenizer`` must be a ``Tekkenizer``
                or a ``SentencePieceTokenizer``.
        """
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer
        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer

        self.vocab_size = len(self.tokenizer.vocab())

        assert isinstance(self.tokenizer,
                          (Tekkenizer, SentencePieceTokenizer)), type(
                              self.tokenizer)

        self._is_tekken = isinstance(self.tokenizer, Tekkenizer)

        if self._is_tekken:
            # Make sure special tokens will not raise
            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE

        # The following attributes are set to fit vLLM's design.
        self.is_fast = True
        # Any non-None value works here: callers only test
        # `chat_template is None` to decide whether a template exists.
        self.chat_template = True
        self.all_special_ids: List[Any] = []
        self.all_special_tokens: List[Any] = []
        self.all_special_tokens_extended: List[Any] = []

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
        """Build a tokenizer from a local path or a HF Hub repo id.

        Args:
            path_or_repo_id: One of
                * an existing tokenizer file,
                * an existing directory containing exactly one Mistral
                  tokenizer file, or
                * a Hub repo id of the form ``"org/name"``.
            revision: Hub revision to download from; only used when
                ``path_or_repo_id`` is a repo id.

        Raises:
            ValueError: If ``path_or_repo_id`` is neither an existing
                file/directory nor a two-part repo id.
            OSError: If no unique tokenizer file can be found.
        """
        path = Path(path_or_repo_id)

        tokenizer_file: str
        if not path.exists():
            # Not on disk: treat it as a Hub repo id ("org/name").
            if len(path_or_repo_id.split("/")) != 2:
                raise ValueError(
                    "You have either provided a non-existent path: "
                    f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif path.is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(path / tokenizer_file_name)
        elif path.is_file():
            # A direct path to the tokenizer file itself.
            tokenizer_file = str(path)
        else:
            raise ValueError(f"Invalid path: {path_or_repo_id}")

        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
        """Download the tokenizer file of a Hub repo and return its path.

        Args:
            tokenizer_name: Hub repo id.
            revision: Revision to inspect and download from, or ``None``.

        Returns:
            The path returned by ``hf_hub_download``.
        """
        api = HfApi()
        # List files at the same revision that will be downloaded, so the
        # chosen file name is guaranteed to exist there.
        repo_info = api.model_info(tokenizer_name, revision=revision)
        files = [s.rfilename for s in (repo_info.siblings or [])]

        filename = find_tokenizer_file(files)

        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    def __call__(
        self,
        prompt: str,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Encode ``prompt`` and return an ``Encoding``.

        Args:
            prompt: Text to encode.
            add_special_tokens: Accepted for interface compatibility; it is
                not used.
            truncation: If true, keep only the first ``max_length`` ids.
            max_length: Truncation length; ``None`` keeps all ids.
        """
        # Mistral Tokenizers should not add special tokens
        input_ids = self.encode(prompt)

        if truncation:
            input_ids = input_ids[:max_length]

        return Encoding(input_ids=input_ids)

    def get_added_vocab(self) -> List[str]:
        """Return an empty list."""
        # Mistral tokenizers have no added vocabulary
        return []

    def encode(self, prompt: str) -> List[int]:
        """Encode ``prompt`` with ``bos=True`` and ``eos=False``."""
        # `encode` should only be used for prompt completion;
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`.
        return self.tokenizer.encode(prompt, bos=True, eos=False)

    def apply_chat_template(self,
                            conversation: List["ConversationMessage"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        """Encode a chat conversation into token ids.

        Args:
            conversation: The chat messages to encode.
            tools: Must be ``None``; tools are not yet supported.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The token ids of the encoded chat completion request.
        """
        assert tools is None, "`tools` are not yet supported."

        request = ChatCompletionRequest(
            messages=conversation)  # type: ignore[type-var]
        encoded = self.mistral.encode_chat_completion(request)

        # Return token ids directly rather than a prompt string.
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Turn a list of token pieces back into a single string."""
        if self._is_tekken:
            return "".join(tokens)
        else:
            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]

    def decode(self, ids: Union[List[int], int]) -> str:
        """Decode one token id or a list of token ids into text."""
        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    @property
    def eos_token_id(self):
        """The ``eos_id`` of the underlying tokenizer."""
        return self.tokenizer.eos_id

    def convert_ids_to_tokens(
            self,
            ids: List[int],
            skip_special_tokens: Optional[bool] = True) -> List[str]:
        """Map each token id to its piece via ``id_to_piece``.

        Args:
            ids: Token ids to convert.
            skip_special_tokens: Must be truthy; keeping special tokens is
                not supported yet.
        """
        # TODO(Patrick) - potentially allow special tokens to not be skipped
        assert (
            skip_special_tokens
        ), "Not skipping special tokens is not supported for Mistral tokenizers."

        assert isinstance(self.tokenizer,
                          (Tekkenizer, SentencePieceTokenizer)), type(
                              self.tokenizer)

        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids]
        return tokens

    def __len__(self):
        """Return the vocabulary size."""
        return self.vocab_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment