Commit d5c234ce authored by Baber

check tokenizer.add_bos_token for bos control

parent c2a07d9a
@@ -84,7 +84,7 @@ class HFLM(TemplateLM):
         max_batch_size: int | None = 64,
         trust_remote_code: bool | None = False,
         use_fast_tokenizer: bool | None = True,
-        add_bos_token: bool | None = False,
+        add_bos_token: bool | None = None,
         prefix_token_id: int | None = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
@@ -258,11 +258,14 @@ class HFLM(TemplateLM):
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in getattr(self.config, "model_type", ""):
-            self.add_bos_token = True
-            eval_logger.info(
-                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
-            )
+        if self.add_bos_token is None:
+            if getattr(self.tokenizer, "add_bos_token", False):
+                self.add_bos_token = True
+                eval_logger.info(
+                    f"Tokenizer has 'add_bos_token' attribute set -- using BOS token based on tokenizer configuration for model type '{self.config.model_type}'. To control explicitly, set `add_bos_token=True|False`"
+                )
+            else:
+                self.add_bos_token = False
         self._max_length = max_length
         self.pretrained = pretrained
@@ -858,17 +861,23 @@ class HFLM(TemplateLM):
     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int | None = None,
         add_special_tokens: bool | None = None,
+        left_truncate_len: int | None = None,
+        **kwargs,
     ) -> list[int]:
         """ """
         # default for None - empty dict, use predefined tokenizer param
         # used for all models except for CausalLM or predefined value
-        special_tokens_kwargs = (
-            {"add_special_tokens": False or self.add_bos_token}
+        special_tokens_kwargs: dict = (
+            {
+                "add_special_tokens": self.add_bos_token
+                if add_special_tokens is None
+                else add_special_tokens
+            }
             if self.backend == "causal"
             # otherwise the method explicitly defines the value
             else {"add_special_tokens": add_special_tokens}
+            if isinstance(add_special_tokens, bool)
+            else {}
         )
         encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
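The nested conditional expression above groups as `A if causal else (B if isinstance(add_special_tokens, bool) else {})`. Read as a plain if/else chain it behaves like the sketch below; the function name and the asserts are illustrative only, not code in the repository.

def build_special_tokens_kwargs(
    backend: str, add_bos_token: bool, add_special_tokens: bool | None
) -> dict:
    # Causal backend: an explicit caller choice wins, otherwise fall back to add_bos_token.
    if backend == "causal":
        return {
            "add_special_tokens": add_bos_token
            if add_special_tokens is None
            else add_special_tokens
        }
    # Other backends: only forward the flag when the caller set it explicitly.
    if isinstance(add_special_tokens, bool):
        return {"add_special_tokens": add_special_tokens}
    return {}


assert build_special_tokens_kwargs("causal", True, None) == {"add_special_tokens": True}
assert build_special_tokens_kwargs("causal", True, False) == {"add_special_tokens": False}
assert build_special_tokens_kwargs("seq2seq", False, None) == {}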
@@ -966,7 +975,7 @@ class HFLM(TemplateLM):
         context,
         max_length: int,
         stop: list[str],
-        **generation_kwargs: dict[str, Any],
+        **generation_kwargs,
     ) -> torch.Tensor:
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:
...
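One note on the last hunk: an annotation on `**generation_kwargs` describes the type of each keyword value, so `**generation_kwargs: dict[str, Any]` told type checkers that every generation argument had to be a dict. Dropping the annotation restores the intended "any keyword arguments" meaning. A small illustration of the difference; the function names here are made up for the example.

from typing import Any

# Old-style signature: each keyword VALUE is typed as dict[str, Any], so a type checker
# flags ordinary arguments such as temperature=0.0 or do_sample=False.
def generate_old(**generation_kwargs: dict[str, Any]) -> None: ...

# New-style signature: values are unconstrained, matching how the kwargs are actually used.
def generate_new(**generation_kwargs: Any) -> None: ...

generate_new(temperature=0.0, do_sample=False, max_new_tokens=32)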