Unverified Commit f10dd48f authored by Fengzhe Zhou's avatar Fengzhe Zhou Committed by GitHub
Browse files

[Fix] Update stop_words in huggingface_above_v4_33 (#1160)

parent 80f831b4
...@@ -156,7 +156,7 @@ class HuggingFacewithChatTemplate(BaseModel): ...@@ -156,7 +156,7 @@ class HuggingFacewithChatTemplate(BaseModel):
self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs) self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
self.generation_kwargs = generation_kwargs self.generation_kwargs = generation_kwargs
self.fastchat_template = fastchat_template self.fastchat_template = fastchat_template
self.stop_words = stop_words self.stop_words = list(set(stop_words + self._get_potential_stop_words(path)))
for k, v in other_kwargs.items(): for k, v in other_kwargs.items():
if v is not None: if v is not None:
...@@ -213,6 +213,19 @@ class HuggingFacewithChatTemplate(BaseModel): ...@@ -213,6 +213,19 @@ class HuggingFacewithChatTemplate(BaseModel):
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False self.model.generation_config.do_sample = False
def _get_potential_stop_words(self, path: Optional[str]):
from transformers import GenerationConfig
potential_stop_words = []
try:
generation_config = GenerationConfig.from_pretrained(path)
for token_id in generation_config.eos_token_id:
potential_stop_words.append(self.tokenizer.decode(token_id))
except:
pass
potential_stop_words.append(self.tokenizer.eos_token)
potential_stop_words = list(set(potential_stop_words))
return potential_stop_words
def generate(self, def generate(self,
inputs: List[str], inputs: List[str],
max_out_len: int, max_out_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment