"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "ea656c84435401c483468d5e4f1ed0f04a6d7c6d"
Commit 86e78589 authored by lintangsutawika

Modified tok_encode to fix loglikelihood prediction for seq2seq models: callers can now pass add_special_tokens explicitly, and the context/continuation split encodes both pieces with add_special_tokens=False so the continuation token ids line up.

parent 0d195e90
@@ -409,12 +409,13 @@ class HFLM(LM):
         utils.clear_torch_cache()
         return batch_size

-    def tok_encode(self, string: str, left_truncate_len=None):
+    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
         """ """
-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            add_special_tokens = True
+        if add_special_tokens is None:
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                add_special_tokens = False
+            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                add_special_tokens = True

         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
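For context, a minimal sketch of the dispatch this hunk introduces. The two classes and the resolve_add_special_tokens helper below are hypothetical stand-ins (the real code compares self.AUTO_MODEL_CLASS against the transformers auto-classes); the point is that an explicit add_special_tokens from the caller is respected, and only None falls back to the per-architecture default.

    # Hypothetical stand-ins for transformers.AutoModelForCausalLM and
    # transformers.AutoModelForSeq2SeqLM, used only to exercise the branch.
    class AutoModelForCausalLM: ...
    class AutoModelForSeq2SeqLM: ...

    def resolve_add_special_tokens(auto_model_class, add_special_tokens=None):
        # Mirrors the new tok_encode default: an explicit True/False from
        # the caller wins; None falls back to the architecture default.
        if add_special_tokens is None:
            if auto_model_class is AutoModelForCausalLM:
                add_special_tokens = False
            elif auto_model_class is AutoModelForSeq2SeqLM:
                add_special_tokens = True
        return add_special_tokens

    assert resolve_add_special_tokens(AutoModelForCausalLM) is False
    assert resolve_add_special_tokens(AutoModelForSeq2SeqLM) is True
    assert resolve_add_special_tokens(AutoModelForSeq2SeqLM, add_special_tokens=False) is False

This preserves the old behavior for callers that pass nothing, while letting the loglikelihood path in the next hunk force add_special_tokens=False.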
@@ -529,8 +530,12 @@ class HFLM(LM):
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
+        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
+        context_enc = self.tok_encode(context, add_special_tokens=False)
+        # whole_enc = self.tok_encode(context + continuation)
+        # context_enc = self.tok_encode(context, add_special_tokens=False)

         context_enc_len = len(context_enc)
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc
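To see why both calls force add_special_tokens=False, here is a toy sketch under an assumed tokenizer behavior (hypothetical integer ids, not a real tokenizer): if a seq2seq tokenizer appends an EOS id, encoding the context alone leaves that EOS inside context_enc, so slicing whole_enc at len(context_enc) misaligns and drops real continuation tokens.

    def split_enc(encode, context, continuation):
        # Mirrors the committed logic: encode the concatenation once, then
        # recover the continuation ids by slicing off the context prefix.
        whole_enc = encode(context + continuation)
        context_enc = encode(context)
        return context_enc, whole_enc[len(context_enc):]

    EOS_ID = 1  # hypothetical EOS id

    def encode_plain(s):
        # Stand-in for tok_encode(..., add_special_tokens=False).
        return [ord(ch) for ch in s]

    def encode_with_eos(s):
        # Stand-in for a seq2seq tokenizer default that appends EOS.
        return [ord(ch) for ch in s] + [EOS_ID]

    print(split_enc(encode_plain, "ab", "cd"))     # ([97, 98], [99, 100]) -- clean split
    print(split_enc(encode_with_eos, "ab", "cd"))  # ([97, 98, 1], [100, 1]) -- id 99 is lost

With special tokens disabled on both calls, context_enc is an exact prefix of whole_enc and the slice recovers the continuation ids intact.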