Commit 5a85f9bb authored by lintangsutawika
Browse files

fixed encoding for seq2seq models

parent a85e0150
...@@ -4,6 +4,7 @@ import json ...@@ -4,6 +4,7 @@ import json
import logging import logging
import os import os
from typing import List, Optional, Tuple, Type, TypeVar from typing import List, Optional, Tuple, Type, TypeVar
import transformers
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
from tqdm import tqdm from tqdm import tqdm
...@@ -296,12 +297,17 @@ class TemplateLM(LM): ...@@ -296,12 +297,17 @@ class TemplateLM(LM):
continuation = context[-n_spaces:] + continuation continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces] context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation) if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
context_enc = self.tok_encode(context) whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc) context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:] continuation_enc = whole_enc[context_enc_len:]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context, add_special_tokens=True)
continuation_enc = self.tok_encode(continuation, add_special_tokens=True)
return context_enc, continuation_enc return context_enc, continuation_enc
def loglikelihood( def loglikelihood(
......
...@@ -707,8 +707,6 @@ class HFLM(TemplateLM): ...@@ -707,8 +707,6 @@ class HFLM(TemplateLM):
encoding["attention_mask"] = encoding["attention_mask"][ encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len: :, -left_truncate_len:
] ]
# print(encoding["input_ids"][0])
# import sys; sys.exit()
self.tokenizer.padding_side = old_padding_side self.tokenizer.padding_side = old_padding_side
return encoding["input_ids"], encoding["attention_mask"] return encoding["input_ids"], encoding["attention_mask"]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment