Commit 51afaca2 authored by lintangsutawika's avatar lintangsutawika
Browse files

seq2seq

parent 5a85f9bb
......@@ -305,8 +305,9 @@ class TemplateLM(LM):
continuation_enc = whole_enc[context_enc_len:]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context, add_special_tokens=True)
continuation_enc = self.tok_encode(continuation, add_special_tokens=True)
# The encoder may require context end with special tokens
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
return context_enc, continuation_enc
......
......@@ -664,14 +664,14 @@ class HFLM(TemplateLM):
self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]:
""" """
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: investigate best practices for enc-dec models + special tokens
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = {
"add_special_tokens": False or self.add_bos_token
}
encoding = self.tokenizer.encode(string, **add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
......@@ -690,17 +690,18 @@ class HFLM(TemplateLM):
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
add_special_tokens = {
"add_special_tokens": False or self.add_bos_token
}
encoding = self.tokenizer(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
**add_special_tokens,
)
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment