Commit 22197e78 authored by Baber
Browse files

fix params

parent d5c234ce
......@@ -336,7 +336,7 @@ class TemplateLM(LM):
return self.eot_token_id
@abc.abstractmethod
def tok_encode(self, string: str, **kwargs) -> list[int]:
def tok_encode(self, string: str, add_special_tokens=False, **kwargs) -> list[int]:
"""
Tokenize a string using the model's tokenizer and return a list of token IDs.
"""
......@@ -377,6 +377,7 @@ class TemplateLM(LM):
This method does NOT handle empty context. The caller should
handle empty context (see loglikelihood method).
"""
assert context, "Context cannot be empty!"
import transformers
n_spaces = len(context) - len(context.rstrip())
......@@ -429,7 +430,9 @@ class TemplateLM(LM):
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
continuation_enc = self.tok_encode(continuation)
continuation_enc = self.tok_encode(
continuation, add_special_tokens=False
)
# BOS or EOS as context
context_enc, continuation_enc = (
([self.prefix_token_id], continuation_enc)
......
......@@ -150,7 +150,7 @@ class Grouper:
def pad_and_concat(
max_length: int,
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
padding_side: Literal["right", "left"] = "right",
):
"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment