Unverified Commit 262f879a authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Seq2seq fix (#1604)



* fix on --task list

* add fixes to tokeniation

* differentiate encoding for seq2seq and decoder

* return token setting

* format for pre-commit

* Seq2seq fix, pt2 (#1630)

* getting model class only when defined

* encode_pair handles None, add_special_tokens turned into dict with default value

---------
Co-authored-by: default avatarachervyakov <77295913+artemorloff@users.noreply.github.com>
parent 8e72f267
......@@ -5,6 +5,7 @@ import logging
import os
from typing import List, Optional, Tuple, Type, TypeVar
import transformers
from sqlitedict import SqliteDict
from tqdm import tqdm
......@@ -302,11 +303,17 @@ class TemplateLM(LM):
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
model_class = getattr(self, "AUTO_MODEL_CLASS", None)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
if model_class == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
else:
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
......
......@@ -681,14 +681,21 @@ class HFLM(TemplateLM):
self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]:
""" """
# default for None - empty dict, use predefined tokenizer param
# used for all models except for CausalLM or predefined value
special_tokens_kwargs = {}
# by default for CausalLM - false or self.add_bos_token is set
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: investigate best practices for enc-dec models + special tokens
add_special_tokens = True
special_tokens_kwargs = {
"add_special_tokens": False or self.add_bos_token
}
# otherwise the method explicitly defines the value
else:
special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
......@@ -707,17 +714,16 @@ class HFLM(TemplateLM):
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
encoding = self.tokenizer(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
**add_special_tokens,
)
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
......@@ -728,11 +734,8 @@ class HFLM(TemplateLM):
return encoding["input_ids"], encoding["attention_mask"]
def tok_decode(self, tokens):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
return self.tokenizer.decode(tokens)
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
return self.tokenizer.decode(tokens, skip_special_tokens=True)
def tok_decode(self, tokens, skip_special_tokens=True):
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def _model_call(self, inps, attn_mask=None, labels=None):
"""
......@@ -1175,7 +1178,7 @@ class HFLM(TemplateLM):
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
if not until:
until = [eos]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment