Unverified commit 262f879a authored by Lintang Sutawika, committed by GitHub

Seq2seq fix (#1604)



* fix on --task list

* add fixes to tokenization

* differentiate encoding for seq2seq and decoder-only models

* return token setting

* format for pre-commit

* Seq2seq fix, pt2 (#1630)

* getting model class only when defined

* encode_pair handles None, add_special_tokens turned into dict with default value

---------
Co-authored-by: achervyakov <77295913+artemorloff@users.noreply.github.com>
parent 8e72f267
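The heart of the change: `encode_pair` now encodes context and continuation independently for encoder-decoder models, instead of encoding the concatenation and slicing at the context boundary. A minimal standalone sketch of the two strategies this commit differentiates (the tokenizer name is illustrative and not part of the diff):

```python
from transformers import AutoTokenizer

def encode_pair_decoder_only(tokenizer, context: str, continuation: str):
    # Decoder-only models score the continuation as a suffix of one token
    # stream, so the pair is encoded jointly and split at the context boundary.
    whole_enc = tokenizer.encode(context + continuation, add_special_tokens=False)
    context_enc = tokenizer.encode(context, add_special_tokens=False)
    return context_enc, whole_enc[len(context_enc):]

def encode_pair_seq2seq(tokenizer, context: str, continuation: str):
    # Encoder-decoder models feed the context to the encoder and the
    # continuation to the decoder, so the two pieces are encoded independently.
    context_enc = tokenizer.encode(context)
    continuation_enc = tokenizer.encode(continuation, add_special_tokens=False)
    return context_enc, continuation_enc

tok = AutoTokenizer.from_pretrained("google/flan-t5-small")
print(encode_pair_seq2seq(tok, "Translate to German: Hello", " Hallo"))
```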
@@ -5,6 +5,7 @@
 import logging
 import os
 from typing import List, Optional, Tuple, Type, TypeVar

+import transformers
 from sqlitedict import SqliteDict
 from tqdm import tqdm
@@ -302,11 +303,17 @@ class TemplateLM(LM):
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]

-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
+        model_class = getattr(self, "AUTO_MODEL_CLASS", None)
+
+        if model_class == transformers.AutoModelForSeq2SeqLM:
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
+        else:
+            whole_enc = self.tok_encode(context + continuation)
+            context_enc = self.tok_encode(context)
+            context_enc_len = len(context_enc)
+            continuation_enc = whole_enc[context_enc_len:]

         return context_enc, continuation_enc
...
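Why the old slice-based approach breaks for seq2seq tokenizers: T5-style tokenizers append `</s>` when special tokens are enabled, so the context encoding is not a prefix of the joint encoding and the slice is misaligned. A quick check, assuming the `t5-small` tokenizer is available and that the boundary tokenization of the concatenation matches the standalone pieces:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
context, continuation = "A cat sat on", " the mat"

whole_enc = tok.encode(context + continuation)  # [...tokens..., </s>]
context_enc = tok.encode(context)               # also ends with </s>

# The slice drops a real continuation token because of the stray </s> offset:
print(whole_enc[len(context_enc):])
# Encoding the continuation on its own, as the new seq2seq branch does:
print(tok.encode(continuation, add_special_tokens=False))
```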
@@ -681,14 +681,21 @@ class HFLM(TemplateLM):
         self, string: str, left_truncate_len=None, add_special_tokens=None
     ) -> List[int]:
         """ """
+        # default for None - empty dict, use predefined tokenizer param
+        # used for all models except for CausalLM or predefined value
+        special_tokens_kwargs = {}
+
+        # by default for CausalLM - false or self.add_bos_token is set
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False or self.add_bos_token
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-                # TODO: investigate best practices for enc-dec models + special tokens
-                add_special_tokens = True
+                special_tokens_kwargs = {
+                    "add_special_tokens": False or self.add_bos_token
+                }
+        # otherwise the method explicitly defines the value
+        else:
+            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}

-        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        encoding = self.tokenizer.encode(string, **special_tokens_kwargs)

         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
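The new control flow in `tok_encode`, restated as a hypothetical standalone helper (`is_causal` and `add_bos_token` are stand-ins for the class attributes): `None` now means "defer to the tokenizer's own default" for non-causal models, rather than hard-coding `True` for seq2seq:

```python
from transformers import AutoTokenizer

def tok_encode(tokenizer, string, add_special_tokens=None,
               is_causal=False, add_bos_token=False):
    special_tokens_kwargs = {}
    if add_special_tokens is None:
        if is_causal:
            # causal LMs: no special tokens unless a BOS token was requested
            special_tokens_kwargs = {"add_special_tokens": False or add_bos_token}
        # non-causal models: pass nothing, so the tokenizer default applies
    else:
        # an explicit caller value always wins
        special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
    return tokenizer.encode(string, **special_tokens_kwargs)

tok = AutoTokenizer.from_pretrained("t5-small")
print(tok_encode(tok, "hello"))                            # default: ends with </s>
print(tok_encode(tok, "hello", add_special_tokens=False))  # explicit override
```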
@@ -707,17 +714,16 @@ class HFLM(TemplateLM):
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side

+        add_special_tokens = {}
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False or self.add_bos_token
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            add_special_tokens = True
+            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

         encoding = self.tokenizer(
             strings,
             truncation=truncation,
             padding="longest",
             return_tensors="pt",
-            add_special_tokens=add_special_tokens,
+            **add_special_tokens,
         )
         if left_truncate_len:
             encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
@@ -728,11 +734,8 @@ class HFLM(TemplateLM):
         return encoding["input_ids"], encoding["attention_mask"]

-    def tok_decode(self, tokens):
-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            return self.tokenizer.decode(tokens)
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            return self.tokenizer.decode(tokens, skip_special_tokens=True)
+    def tok_decode(self, tokens, skip_special_tokens=True):
+        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

     def _model_call(self, inps, attn_mask=None, labels=None):
         """
@@ -1175,7 +1178,7 @@ class HFLM(TemplateLM):
                 f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
             )
         # add EOS token to stop sequences
-        eos = self.tok_decode(self.eot_token_id)
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         if not until:
             until = [eos]
         else:
...
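The one caller that must not skip special tokens is the stop-sequence setup: with the new `tok_decode` default, decoding the EOS id would return an empty string, which is useless as a stop marker, hence the explicit `skip_special_tokens=False`. Illustrated with GPT-2 as an assumed example model:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
eot = tok.eos_token_id

print(repr(tok.decode([eot], skip_special_tokens=True)))   # '' - unusable stop string
print(repr(tok.decode([eot], skip_special_tokens=False)))  # '<|endoftext|>'
```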