Unverified Commit df07391e authored by Fengzhe Zhou's avatar Fengzhe Zhou Committed by GitHub
Browse files

[Fix] Enforce `do_sample=False` in HF model (#506)



* update hf model wrapper

* patch llama

---------
Co-authored-by: default avatarbot <bot@bot.com>
parent b6284233
......@@ -100,25 +100,33 @@ class HuggingFace(BaseModel):
if self.pad_token_id < 0:
self.pad_token_id += self.tokenizer.vocab_size
if self.tokenizer.pad_token_id is None:
self.logger.warning(
f'Using {self.pad_token_id} as pad_token_id')
self.logger.debug(f'Using {self.pad_token_id} as pad_token_id')
elif self.tokenizer.pad_token_id != self.pad_token_id:
self.logger.warning(
f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id' # noqa
)
'pad_token_id is not consistent with the tokenizer. Using '
f'{self.pad_token_id} as pad_token_id')
self.tokenizer.pad_token_id = self.pad_token_id
elif self.tokenizer.pad_token_id is None:
self.logger.warning('pad_token_id is not set for the tokenizer.')
if self.tokenizer.eos_token is not None:
self.logger.warning('Using eos_token_id as pad_token_id.')
self.logger.warning(
f'{self.tokenizer.eos_token} la {self.tokenizer.eos_token is None}' # noqa
)
f'Using eos_token_id {self.tokenizer.eos_token} '
'as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation.json`' # noqa
)
from transformers.generation import GenerationConfig
gcfg = GenerationConfig.from_pretrained(path)
if gcfg.pad_token_id is not None:
self.logger.warning(
f'Using pad_token_id {gcfg.pad_token_id} '
'as pad_token_id.')
self.tokenizer.pad_token_id = gcfg.pad_token_id
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to '
'set pad_token_id via passing '
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
# A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path or \
......@@ -165,6 +173,7 @@ class HuggingFace(BaseModel):
peft_path,
is_trainable=False)
self.model.eval()
self.model.generation_config.do_sample = False
# A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path:
......@@ -432,3 +441,4 @@ class HuggingFaceCausalLM(HuggingFace):
peft_path,
is_trainable=False)
self.model.eval()
self.model.generation_config.do_sample = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment