"docs/vscode:/vscode.git/clone" did not exist on "1e55dfa7e552e0995630a2563aeae443945e2e81"
Unverified commit 39956efb, authored by Qiong Zhou Huang and committed by GitHub
Browse files

[Bugfix] Fix bad words for Mistral models (#17753)


Signed-off-by: Qiong Zhou Huang <qiong@phonic.co>
parent 597051e5
......@@ -4,11 +4,12 @@ from typing import Callable, Union
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor],
Callable[[list[int], list[int], torch.Tensor],
torch.Tensor]]
LogitsProcessor = Union[
Callable[[list[int], torch.Tensor], torch.Tensor],
Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
......@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)
# If no space at the beginning
# or if prefix space produces a new word token
......
......@@ -13,7 +13,6 @@ from typing_extensions import deprecated
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
logger = init_logger(__name__)
......@@ -491,13 +490,8 @@ class SamplingParams(
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()
if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(
text=prompt, add_special_tokens=False)
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)
# If no space at the beginning
# or if prefix space produces a new word token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment