Unverified Commit 39956efb authored by Qiong Zhou Huang's avatar Qiong Zhou Huang Committed by GitHub
Browse files

[Bugfix] Fix bad words for Mistral models (#17753)


Signed-off-by: default avatarQiong Zhou Huang <qiong@phonic.co>
parent 597051e5
...@@ -4,11 +4,12 @@ from typing import Callable, Union ...@@ -4,11 +4,12 @@ from typing import Callable, Union
import torch import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor], LogitsProcessor = Union[
Callable[[list[int], list[int], torch.Tensor], Callable[[list[int], torch.Tensor], torch.Tensor],
torch.Tensor]] Callable[[list[int], list[int], torch.Tensor], torch.Tensor],
]
"""LogitsProcessor is a function that takes a list """LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a for the next token and, optionally, prompt tokens as a
...@@ -29,12 +30,8 @@ def get_bad_words_logits_processors( ...@@ -29,12 +30,8 @@ def get_bad_words_logits_processors(
prefix = " " if add_prefix_space else "" prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip() prompt = prefix + bad_word.lstrip()
if isinstance(tokenizer, MistralTokenizer): prompt_token_ids = tokenizer.encode(text=prompt,
# Mistral tokenizers should not add special tokens add_special_tokens=False)
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)
# If no space at the beginning # If no space at the beginning
# or if prefix space produces a new word token # or if prefix space produces a new word token
......
...@@ -13,7 +13,6 @@ from typing_extensions import deprecated ...@@ -13,7 +13,6 @@ from typing_extensions import deprecated
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -491,13 +490,8 @@ class SamplingParams( ...@@ -491,13 +490,8 @@ class SamplingParams(
for add_prefix_space in [False, True]: for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else "" prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip() prompt = prefix + bad_word.lstrip()
prompt_token_ids = tokenizer.encode(text=prompt,
if isinstance(tokenizer, MistralTokenizer): add_special_tokens=False)
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(
text=prompt, add_special_tokens=False)
# If no space at the beginning # If no space at the beginning
# or if prefix space produces a new word token # or if prefix space produces a new word token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment