Commit d701d50f authored by Baber
Browse files

fix bos token handling

parent aab23be4
...@@ -258,15 +258,7 @@ class HFLM(TemplateLM): ...@@ -258,15 +258,7 @@ class HFLM(TemplateLM):
else {} else {}
) )
self.add_bos_token = add_bos_token self.add_bos_token = add_bos_token if add_bos_token is not None else None
if self.add_bos_token is None:
if getattr(self.tokenizer, "add_bos_token", False):
self.add_bos_token = True
eval_logger.info(
f"Tokenizer has 'add_bos_token' attribute set -- using BOS token based on tokenizer configuration for model type '{self.config.model_type}'. To control explicitly, set `add_bos_token=True|False`"
)
else:
self.add_bos_token = False
self._max_length = max_length self._max_length = max_length
self.pretrained = pretrained self.pretrained = pretrained
...@@ -748,7 +740,7 @@ class HFLM(TemplateLM): ...@@ -748,7 +740,7 @@ class HFLM(TemplateLM):
trust_remote_code: bool | None = False, trust_remote_code: bool | None = False,
use_fast_tokenizer: bool | None = True, use_fast_tokenizer: bool | None = True,
gguf_file: str | None = None, gguf_file: str | None = None,
add_bos_token: bool | None = False, add_bos_token: bool | None = None,
subfolder: str | None = "", subfolder: str | None = "",
) -> None: ) -> None:
"""Helper method during initialization. """Helper method during initialization.
...@@ -767,8 +759,8 @@ class HFLM(TemplateLM): ...@@ -767,8 +759,8 @@ class HFLM(TemplateLM):
else: else:
kwargs["use_fast"] = use_fast_tokenizer kwargs["use_fast"] = use_fast_tokenizer
if add_bos_token: if add_bos_token is not None:
kwargs["add_bos_token"] = True kwargs["add_bos_token"] = add_bos_token
if subfolder: if subfolder:
kwargs["subfolder"] = subfolder kwargs["subfolder"] = subfolder
...@@ -868,16 +860,12 @@ class HFLM(TemplateLM): ...@@ -868,16 +860,12 @@ class HFLM(TemplateLM):
) -> list[int]: ) -> list[int]:
# default for None - empty dict, use predefined tokenizer param # default for None - empty dict, use predefined tokenizer param
# used for all models except for CausalLM or predefined value # used for all models except for CausalLM or predefined value
special_tokens_kwargs: dict = (
{ special_tokens_kwargs = (
"add_special_tokens": self.add_bos_token {"add_special_tokens": add_special_tokens}
if add_special_tokens is None if (isinstance(add_special_tokens, bool))
else add_special_tokens else {"add_special_tokens": self.add_bos_token}
} if self.add_bos_token is not None
if self.backend == "causal"
# otherwise the method explicitly defines the value
else {"add_special_tokens": add_special_tokens}
if isinstance(add_special_tokens, bool)
else {} else {}
) )
...@@ -906,8 +894,10 @@ class HFLM(TemplateLM): ...@@ -906,8 +894,10 @@ class HFLM(TemplateLM):
strings[0], getattr(self.tokenizer, "bos_token", None) strings[0], getattr(self.tokenizer, "bos_token", None)
): ):
add_special_tokens = {"add_special_tokens": False} add_special_tokens = {"add_special_tokens": False}
elif self.add_bos_token is not None:
add_special_tokens = {"add_special_tokens": self.add_bos_token}
else: else:
add_special_tokens = {"add_special_tokens": False or self.add_bos_token} add_special_tokens = {"add_special_tokens": True}
encoding = self.tokenizer( encoding = self.tokenizer(
strings, strings,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment