Commit d5c234ce authored by Baber

check tokenizer.add_bos_token for bos control

parent c2a07d9a
@@ -84,7 +84,7 @@ class HFLM(TemplateLM):
         max_batch_size: int | None = 64,
         trust_remote_code: bool | None = False,
         use_fast_tokenizer: bool | None = True,
-        add_bos_token: bool | None = False,
+        add_bos_token: bool | None = None,
         prefix_token_id: int | None = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
@@ -258,11 +258,14 @@
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in getattr(self.config, "model_type", ""):
-            self.add_bos_token = True
-            eval_logger.info(
-                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
-            )
+        if self.add_bos_token is None:
+            if getattr(self.tokenizer, "add_bos_token", False):
+                self.add_bos_token = True
+                eval_logger.info(
+                    f"Tokenizer has 'add_bos_token' attribute set -- using BOS token based on tokenizer configuration for model type '{self.config.model_type}'. To control explicitly, set `add_bos_token=True|False`"
+                )
+            else:
+                self.add_bos_token = False
         self._max_length = max_length
         self.pretrained = pretrained
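
Flipping the constructor default from `False` to `None` makes `add_bos_token` tri-state: an explicit `True` or `False` from the caller always wins, and only an unset (`None`) value falls through to the tokenizer's own `add_bos_token` attribute. This subsumes the removed Gemma special case, since Gemma tokenizers ship with `add_bos_token=True`. A minimal sketch of the resolution order, with `SimpleNamespace` stand-ins for the tokenizer (illustrative, not part of the diff):

```python
from types import SimpleNamespace


def resolve_add_bos(explicit: bool | None, tokenizer) -> bool:
    """Mirror the precedence in the diff: explicit flag > tokenizer config > False."""
    if explicit is not None:
        return explicit  # caller said True or False: respect it
    # fall back to the tokenizer's own configuration, defaulting to False
    return bool(getattr(tokenizer, "add_bos_token", False))


# Gemma-style tokenizers carry add_bos_token=True, so they get a BOS token
# without needing a hardcoded model-family check.
gemma_like = SimpleNamespace(add_bos_token=True)
plain = SimpleNamespace()  # no add_bos_token attribute at all

assert resolve_add_bos(None, gemma_like) is True
assert resolve_add_bos(None, plain) is False
assert resolve_add_bos(False, gemma_like) is False  # explicit always wins
```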
@@ -858,17 +861,23 @@
     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int | None = None,
         add_special_tokens: bool | None = None,
+        left_truncate_len: int | None = None,
         **kwargs,
     ) -> list[int]:
         """ """
         # default for None - empty dict, use predefined tokenizer param
         # used for all models except for CausalLM or predefined value
-        special_tokens_kwargs = (
-            {"add_special_tokens": False or self.add_bos_token}
+        special_tokens_kwargs: dict = (
+            {
+                "add_special_tokens": self.add_bos_token
+                if add_special_tokens is None
+                else add_special_tokens
+            }
             if self.backend == "causal"
             # otherwise the method explicitly defines the value
             else {"add_special_tokens": add_special_tokens}
+            if isinstance(add_special_tokens, bool)
+            else {}
         )
         encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
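
In `tok_encode`, the old causal branch evaluated `False or self.add_bos_token`, which always reduces to the instance flag, so a per-call `add_special_tokens` argument was silently ignored for causal models. The rewritten conditional gives the per-call argument priority and falls back to `self.add_bos_token` only when it is `None`; non-causal backends still pass the flag through only when it is an explicit bool. A sketch of the same selection as a standalone function (the name `resolve_special_tokens_kwargs` is invented for illustration):

```python
def resolve_special_tokens_kwargs(
    backend: str,
    instance_add_bos: bool,
    add_special_tokens: bool | None = None,
) -> dict:
    """Reproduce the kwargs selection from the diff as a plain function."""
    if backend == "causal":
        # per-call override wins; otherwise fall back to the instance flag
        return {
            "add_special_tokens": instance_add_bos
            if add_special_tokens is None
            else add_special_tokens
        }
    # non-causal backends: only pass the flag through when explicitly given,
    # otherwise leave the tokenizer's own default untouched
    if isinstance(add_special_tokens, bool):
        return {"add_special_tokens": add_special_tokens}
    return {}


assert resolve_special_tokens_kwargs("causal", True) == {"add_special_tokens": True}
assert resolve_special_tokens_kwargs("causal", True, False) == {"add_special_tokens": False}
assert resolve_special_tokens_kwargs("seq2seq", True) == {}
```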
@@ -966,7 +975,7 @@
         context,
         max_length: int,
         stop: list[str],
-        **generation_kwargs: dict[str, Any],
+        **generation_kwargs,
     ) -> torch.Tensor:
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:
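
The final hunk removes a misleading annotation rather than changing behavior: under PEP 484, an annotation on `**kwargs` describes each individual keyword value, so `**generation_kwargs: dict[str, Any]` claimed that every generation argument (e.g. `temperature`) was itself a dict. A small illustration (both function names are invented for the example):

```python
from typing import Any


def wrong(**generation_kwargs: dict[str, Any]) -> None:
    # Type checkers read this as: every value must be a dict[str, Any],
    # so wrong(temperature=0.0) fails type checking.
    ...


def right(**generation_kwargs: Any) -> None:
    # Each value may be anything, which is what the call sites expect.
    ...


right(temperature=0.0, do_sample=False)  # fine
```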