Commit 5c6e9b50 authored by Baber
Browse files

fix duplicate `bos` token when `context==""`

parent ad506a13
......@@ -378,11 +378,14 @@ class TemplateLM(LM):
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
continuation_enc = self.tok_encode(continuation)
# BOS or EOS as context
context_enc, continuation_enc = (
[self.prefix_token_id],
self.tok_encode(continuation),
([self.prefix_token_id], continuation_enc)
if self.prefix_token_id != continuation_enc[0]
else (continuation_enc[:1], continuation_enc[1:])
)
# BOS or EOS as context
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
......
......@@ -864,17 +864,12 @@ class HFLM(TemplateLM):
""" """
# default for None - empty dict, use predefined tokenizer param
# used for all models except for CausalLM or predefined value
special_tokens_kwargs = {}
# by default for CausalLM - false or self.add_bos_token is set
if add_special_tokens is None:
if self.backend == "causal":
special_tokens_kwargs = {
"add_special_tokens": False or self.add_bos_token
}
# otherwise the method explicitly defines the value
else:
special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
special_tokens_kwargs = (
{"add_special_tokens": False or self.add_bos_token}
if self.backend == "causal"
# otherwise the method explicitly defines the value
else {"add_special_tokens": add_special_tokens}
)
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment