Unverified commit f871646f, authored by Hailey Schoelkopf and committed by GitHub

Revert "Patch for Seq2Seq Model predictions (#1584)" (#1601)

This reverts commit b7923a84.
parent a4192489
@@ -5,7 +5,6 @@ import logging
 import os
 from typing import List, Optional, Tuple, Type, TypeVar

-import transformers
 from sqlitedict import SqliteDict
 from tqdm import tqdm

@@ -304,17 +303,12 @@ class TemplateLM(LM):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]

-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            whole_enc = self.tok_encode(context + continuation)
-            context_enc = self.tok_encode(context)
-            context_enc_len = len(context_enc)
-            continuation_enc = whole_enc[context_enc_len:]
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            context_enc = self.tok_encode(context)
-            continuation_enc = self.tok_encode(continuation)
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]

         return context_enc, continuation_enc

     def loglikelihood(
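For context, a minimal standalone sketch of how the restored encoding path behaves: the continuation tokens are recovered by tokenizing the full string and slicing off the context tokens, so tokenizer merges at the context/continuation boundary are preserved. The tokenizer choice ("gpt2") and the free-function wrapper are illustrative assumptions, not part of this commit.

    # Illustrative sketch (not the harness API): concatenate-then-slice encoding.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed example model

    def encode_pair(context: str, continuation: str):
        # Tokenize the concatenation, then take everything past the context tokens,
        # mirroring the restored (non-seq2seq-specific) code path above.
        whole_enc = tokenizer.encode(context + continuation)
        context_enc = tokenizer.encode(context)
        continuation_enc = whole_enc[len(context_enc):]
        return context_enc, continuation_enc

    ctx_ids, cont_ids = encode_pair("The capital of France is", " Paris")
    print(ctx_ids, cont_ids)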
@@ -726,15 +726,11 @@ class HFLM(TemplateLM):
         return encoding["input_ids"], encoding["attention_mask"]

-    def tok_decode(self, tokens, skip_special_tokens=True):
+    def tok_decode(self, tokens):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            return self.tokenizer.decode(
-                tokens, skip_special_tokens=skip_special_tokens
-            )
+            return self.tokenizer.decode(tokens)
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            return self.tokenizer.decode(
-                tokens, skip_special_tokens=skip_special_tokens
-            )
+            return self.tokenizer.decode(tokens, skip_special_tokens=True)

     def _model_call(self, inps, attn_mask=None, labels=None):
         """
@@ -1177,7 +1173,7 @@ class HFLM(TemplateLM):
                 f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
             )
         # add EOS token to stop sequences
-        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
+        eos = self.tok_decode(self.eot_token_id)
         if not until:
             until = [eos]
         else:
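As a rough illustration of why the decode keyword matters when building the EOS stop sequence: with the reverted tok_decode, the causal-LM path uses the tokenizer's default (special tokens kept), so decoding the end-of-text id still yields a usable stop string, while skipping special tokens would strip it. The snippet below is a hedged sketch using an assumed "gpt2" tokenizer, not the harness code itself.

    # Illustrative sketch: effect of skip_special_tokens when decoding the EOS id.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed example model
    eos_id = tokenizer.eos_token_id

    # Keeping special tokens yields the literal stop string ("<|endoftext|>");
    # skipping them drops the token and returns an empty string.
    print(tokenizer.decode([eos_id], skip_special_tokens=False))
    print(tokenizer.decode([eos_id], skip_special_tokens=True))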