Commit 86e78589 authored by lintangsutawika's avatar lintangsutawika
Browse files

modified changes to fix loglikelihood prediction for seq2seq

parent 0d195e90
...@@ -409,12 +409,13 @@ class HFLM(LM): ...@@ -409,12 +409,13 @@ class HFLM(LM):
utils.clear_torch_cache() utils.clear_torch_cache()
return batch_size return batch_size
def tok_encode(self, string: str, left_truncate_len=None): def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
""" """ """ """
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: if add_special_tokens is None:
add_special_tokens = False if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: add_special_tokens = False
add_special_tokens = True elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
...@@ -529,8 +530,12 @@ class HFLM(LM): ...@@ -529,8 +530,12 @@ class HFLM(LM):
if n_spaces > 0: if n_spaces > 0:
continuation = context[-n_spaces:] + continuation continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces] context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context) whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
context_enc = self.tok_encode(context, add_special_tokens=False)
# whole_enc = self.tok_encode(context + continuation)
# context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc) context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:] continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc return context_enc, continuation_enc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment