Commit aab23be4 authored by Baber

skip duplicate bos

parent f402411c
@@ -32,6 +32,7 @@ from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    bos_already_added,
    clear_torch_cache,
    configure_pad_token,
    get_dtype,
@@ -901,6 +902,11 @@ class HFLM(TemplateLM):
        add_special_tokens = {}
        if self.backend == "causal":
            if bos_already_added(
                strings[0], getattr(self.tokenizer, "bos_token", None)
            ):
                add_special_tokens = {"add_special_tokens": False}
            else:
                add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
        encoding = self.tokenizer(
@@ -881,3 +881,7 @@ def postprocess_generated_text(
        generation = generation.split(think_end_token)[-1].lstrip()
    return generation


def bos_already_added(sequence: str, bos_string: Optional[str]) -> bool:
    # True when the prompt already starts with the tokenizer's BOS string,
    # in which case the tokenizer should not prepend another BOS.
    return sequence.startswith(bos_string) if bos_string else False
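
For context, a minimal sketch (not part of the commit) of the behaviour the new check guards against: when a prompt string already starts with the tokenizer's BOS token, encoding it with add_special_tokens=True makes the tokenizer prepend a second BOS. The checkpoint name below is only illustrative; any tokenizer that adds BOS by default behaves the same way.

from transformers import AutoTokenizer

# Illustrative tokenizer with bos_token "<s>".
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

prompt = tok.bos_token + "The capital of France is"

# Default encoding prepends BOS even though the text already starts with it,
# so the ids begin with two copies of bos_token_id.
doubled = tok(prompt, add_special_tokens=True)["input_ids"]
assert doubled[:2] == [tok.bos_token_id, tok.bos_token_id]

# With the check added in this commit, encoding falls back to
# add_special_tokens=False and a single BOS is kept.
single = tok(prompt, add_special_tokens=False)["input_ids"]
assert single[0] == tok.bos_token_id and single[1] != tok.bos_token_id

The new branch inspects only strings[0], presumably on the assumption that every prompt in a batch is built the same way, so one representative string is enough to decide whether BOS is already present.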