Unverified commit df07391e, authored by Fengzhe Zhou, committed by GitHub
Browse files

[Fix] Enforce `do_sample=False` in HF model (#506)



* update hf model wrapper

* patch llama

---------
Co-authored-by: bot <bot@bot.com>
parent b6284233
...@@ -100,25 +100,33 @@ class HuggingFace(BaseModel): ...@@ -100,25 +100,33 @@ class HuggingFace(BaseModel):
if self.pad_token_id < 0: if self.pad_token_id < 0:
self.pad_token_id += self.tokenizer.vocab_size self.pad_token_id += self.tokenizer.vocab_size
if self.tokenizer.pad_token_id is None: if self.tokenizer.pad_token_id is None:
self.logger.warning( self.logger.debug(f'Using {self.pad_token_id} as pad_token_id')
f'Using {self.pad_token_id} as pad_token_id')
elif self.tokenizer.pad_token_id != self.pad_token_id: elif self.tokenizer.pad_token_id != self.pad_token_id:
self.logger.warning( self.logger.warning(
f'pad_token_id is not consistent with the tokenizer. Using {self.pad_token_id} as pad_token_id' # noqa 'pad_token_id is not consistent with the tokenizer. Using '
) f'{self.pad_token_id} as pad_token_id')
self.tokenizer.pad_token_id = self.pad_token_id self.tokenizer.pad_token_id = self.pad_token_id
elif self.tokenizer.pad_token_id is None: elif self.tokenizer.pad_token_id is None:
self.logger.warning('pad_token_id is not set for the tokenizer.') self.logger.warning('pad_token_id is not set for the tokenizer.')
if self.tokenizer.eos_token is not None: if self.tokenizer.eos_token is not None:
self.logger.warning('Using eos_token_id as pad_token_id.')
self.logger.warning( self.logger.warning(
f'{self.tokenizer.eos_token} la {self.tokenizer.eos_token is None}' # noqa f'Using eos_token_id {self.tokenizer.eos_token} '
) 'as pad_token_id.')
self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token = self.tokenizer.eos_token
else: else:
raise ValueError( from transformers.generation import GenerationConfig
'pad_token_id is not set for this tokenizer. Try to set pad_token_id via passing `pad_token_id={PAD_TOKEN_ID}` in model_cfg. You may find pad_token_id in `generation.json`' # noqa gcfg = GenerationConfig.from_pretrained(path)
)
if gcfg.pad_token_id is not None:
self.logger.warning(
f'Using pad_token_id {gcfg.pad_token_id} '
'as pad_token_id.')
self.tokenizer.pad_token_id = gcfg.pad_token_id
else:
raise ValueError(
'pad_token_id is not set for this tokenizer. Try to '
'set pad_token_id via passing '
'`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
# A patch for llama when batch_padding = True # A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path or \ if 'decapoda-research/llama' in path or \
...@@ -165,6 +173,7 @@ class HuggingFace(BaseModel): ...@@ -165,6 +173,7 @@ class HuggingFace(BaseModel):
peft_path, peft_path,
is_trainable=False) is_trainable=False)
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False
# A patch for llama when batch_padding = True # A patch for llama when batch_padding = True
if 'decapoda-research/llama' in path: if 'decapoda-research/llama' in path:
...@@ -432,3 +441,4 @@ class HuggingFaceCausalLM(HuggingFace): ...@@ -432,3 +441,4 @@ class HuggingFaceCausalLM(HuggingFace):
peft_path, peft_path,
is_trainable=False) is_trainable=False)
self.model.eval() self.model.eval()
self.model.generation_config.do_sample = False
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment