Unverified Commit b7923a84 authored by Lintang Sutawika, committed by GitHub

Patch for Seq2Seq Model predictions (#1584)



* Differentiate _encode_pair setting for decoder and enc-dec models

* tok_decode to not skip special tokens so that eos doesn't become an empty string

* Update model.py

* Update model.py

* Update huggingface.py

* Update lm_eval/models/huggingface.py
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update model.py

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 92f30afd
@@ -5,6 +5,7 @@ import logging
 import os
 from typing import List, Optional, Tuple, Type, TypeVar
 
+import transformers
 from sqlitedict import SqliteDict
 from tqdm import tqdm
@@ -296,11 +297,16 @@ class TemplateLM(LM):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
 
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
+        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            whole_enc = self.tok_encode(context + continuation)
+            context_enc = self.tok_encode(context)
+            context_enc_len = len(context_enc)
+            continuation_enc = whole_enc[context_enc_len:]
+        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation)
 
         return context_enc, continuation_enc
...
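For reference, a minimal standalone sketch of the two encoding strategies in the hunk above, using a GPT-2 tokenizer as a stand-in for self.tok_encode (the model name and example strings are illustrative, not part of the patch):

# Sketch: how causal vs. seq2seq models split (context, continuation) into tokens.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in for self.tok_encode
context, continuation = "The capital of France is", " Paris"

# Causal (decoder-only): encode the concatenation, then slice at the context
# length, so tokens that merge across the boundary count toward the continuation.
whole_enc = tok.encode(context + continuation)
context_enc = tok.encode(context)
continuation_enc_causal = whole_enc[len(context_enc):]

# Seq2seq (encoder-decoder): encode the halves independently, since the context
# feeds the encoder and the continuation feeds the decoder.
continuation_enc_seq2seq = tok.encode(continuation)

print(continuation_enc_causal, continuation_enc_seq2seq)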
@@ -711,11 +711,15 @@ class HFLM(TemplateLM):
         return encoding["input_ids"], encoding["attention_mask"]
 
-    def tok_decode(self, tokens):
+    def tok_decode(self, tokens, skip_special_tokens=True):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            return self.tokenizer.decode(tokens)
+            return self.tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            )
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            return self.tokenizer.decode(tokens, skip_special_tokens=True)
+            return self.tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            )
 
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
@@ -1158,7 +1162,7 @@ class HFLM(TemplateLM):
                 f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
             )
             # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
+            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
             if not until:
                 until = [eos]
             else:
...
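Net effect on stop-sequence handling, as a hedged sketch (the body of the else branch is elided in the diff above, so appending eos there is an assumption):

# Sketch: a non-empty decoded EOS can now serve as a generation stop string.
eos = "</s>"            # e.g. tok_decode(eot_token_id, skip_special_tokens=False) for T5
until = None            # or a user-supplied list of stop strings
if not until:
    until = [eos]
else:
    until.append(eos)   # assumed behavior of the elided else branch
print(until)            # ['</s>']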