Commit 51afaca2 authored by lintangsutawika's avatar lintangsutawika
Browse files

seq2seq

parent 5a85f9bb
...@@ -305,8 +305,9 @@ class TemplateLM(LM): ...@@ -305,8 +305,9 @@ class TemplateLM(LM):
continuation_enc = whole_enc[context_enc_len:] continuation_enc = whole_enc[context_enc_len:]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context, add_special_tokens=True) # The encoder may require context end with special tokens
continuation_enc = self.tok_encode(continuation, add_special_tokens=True) context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
return context_enc, continuation_enc return context_enc, continuation_enc
......
...@@ -664,14 +664,14 @@ class HFLM(TemplateLM): ...@@ -664,14 +664,14 @@ class HFLM(TemplateLM):
self, string: str, left_truncate_len=None, add_special_tokens=None self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]: ) -> List[int]:
""" """ """ """
if add_special_tokens is None:
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: investigate best practices for enc-dec models + special tokens
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens) add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = {
"add_special_tokens": False or self.add_bos_token
}
encoding = self.tokenizer.encode(string, **add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long # left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len: if left_truncate_len:
...@@ -690,17 +690,18 @@ class HFLM(TemplateLM): ...@@ -690,17 +690,18 @@ class HFLM(TemplateLM):
old_padding_side = self.tokenizer.padding_side old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side self.tokenizer.padding_side = padding_side
add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False or self.add_bos_token add_special_tokens = {
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: "add_special_tokens": False or self.add_bos_token
add_special_tokens = True }
encoding = self.tokenizer( encoding = self.tokenizer(
strings, strings,
truncation=truncation, truncation=truncation,
padding="longest", padding="longest",
return_tensors="pt", return_tensors="pt",
add_special_tokens=add_special_tokens, **add_special_tokens,
) )
if left_truncate_len: if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment