"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "633062639bfd6be15abc072aaf7e18bce355f426"
Unverified Commit 063d8d27 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Refactor `prepare_seq2seq_batch` (#9524)

* Add target contextmanager and rework prepare_seq2seq_batch

* Fix tests, treat BART and Barthez

* Add last tokenizers

* Fix test

* Set src token before calling the superclass

* Remove special behavior for T5

* Remove needless imports

* Remove needless asserts
parent e6ecef71
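
In short: every per-model copy of `prepare_seq2seq_batch` moves into `PreTrainedTokenizerBase`, and a new `as_target_tokenizer` context manager gives tokenizers a single hook for target-side state (language codes for mBART, the target SentencePiece model for Marian). A minimal usage sketch of the new pattern, assuming any seq2seq checkpoint (the BART checkpoint and the summary text here are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")  # illustrative checkpoint
src_texts = [" UN Chief Says There Is No Military Solution in Syria"]
tgt_texts = ["No military solution in Syria, says UN chief."]  # hypothetical summary

# sources are encoded as usual
batch = tokenizer(src_texts, return_tensors="pt", padding=True, truncation=True)
# targets are encoded inside the context manager so tokenizers that need
# different state for labels (language codes, a second vocabulary) can switch it;
# for BART the base implementation simply yields, so nothing changes
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, return_tensors="pt", padding=True, truncation=True)
batch["labels"] = labels["input_ids"]

For tokenizers with no target-side state, the base `as_target_tokenizer` is a no-op, which is why the BART, Barthez, FSMT, Pegasus, ProphetNet, and T5 overrides below can be deleted outright.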

src/transformers/models/bart/tokenization_bart.py
@@ -13,10 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
-
-from ...file_utils import add_start_docstrings
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta import RobertaTokenizer
@@ -54,45 +50,3 @@ class BartTokenizer(RobertaTokenizer):
         "vocab_file": {m: vocab_url for m in _all_bart_models},
         "merges_file": {m: merges_url for m in _all_bart_models},
     }
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = None,
-        truncation=True,
-        **kwargs,
-    ) -> BatchEncoding:
-        kwargs.pop("src_lang", None)
-        kwargs.pop("tgt_lang", None)
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs

src/transformers/models/bart/tokenization_bart_fast.py
@@ -13,10 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
-
-from ...file_utils import add_start_docstrings
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
 from .tokenization_bart import BartTokenizer
@@ -49,43 +45,3 @@ class BartTokenizerFast(RobertaTokenizerFast):
         "tokenizer_file": {m: tokenizer_url for m in _all_bart_models},
     }
     slow_tokenizer_class = BartTokenizer
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: Optional[str] = None,
-        truncation=True,
-        **kwargs,
-    ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs

src/transformers/models/barthez/tokenization_barthez.py
@@ -21,9 +21,7 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm

-from ...file_utils import add_start_docstrings
 from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
@@ -264,45 +262,3 @@ class BarthezTokenizer(PreTrainedTokenizer):
         copyfile(self.vocab_file, out_vocab_file)

         return (out_vocab_file,)
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = "None",
-        truncation=True,
-        **kwargs,
-    ) -> BatchEncoding:
-        kwargs.pop("src_lang", None)
-        kwargs.pop("tgt_lang", None)
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs

src/transformers/models/barthez/tokenization_barthez_fast.py
@@ -19,8 +19,7 @@ import os
 from shutil import copyfile
 from typing import List, Optional, Tuple

-from ...file_utils import add_start_docstrings, is_sentencepiece_available
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...file_utils import is_sentencepiece_available
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -228,45 +227,3 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         copyfile(self.vocab_file, out_vocab_file)

         return (out_vocab_file,)
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = "None",
-        truncation=True,
-        **kwargs,
-    ) -> BatchEncoding:
-        kwargs.pop("src_lang", None)
-        kwargs.pop("tgt_lang", None)
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs

src/transformers/models/fsmt/tokenization_fsmt.py
@@ -23,9 +23,7 @@ from typing import Dict, List, Optional, Tuple
 import sacremoses as sm

-from ...file_utils import add_start_docstrings
-from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
@@ -484,40 +482,6 @@ class FSMTTokenizer(PreTrainedTokenizer):
             return len(token_ids_0 + sep) * [0]
         return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        return_tensors: Optional[str] = None,
-        truncation=True,
-        padding="longest",
-        **unused,
-    ) -> BatchEncoding:
-        if type(src_texts) is not list:
-            raise ValueError("src_texts is expected to be a list")
-        if "" in src_texts:
-            raise ValueError(f"found empty string in src_texts: {src_texts}")
-        tokenizer_kwargs = dict(
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            truncation=truncation,
-            padding=padding,
-        )
-        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
-        if tgt_texts is None:
-            return model_inputs
-        if max_target_length is not None:
-            tokenizer_kwargs["max_length"] = max_target_length
-        model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
-        return model_inputs
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))

src/transformers/models/marian/tokenization_marian.py
@@ -15,15 +15,14 @@
 import json
 import re
 import warnings
+from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
 from typing import Dict, List, Optional, Tuple, Union

 import sentencepiece

-from ...file_utils import add_start_docstrings
-from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...tokenization_utils import PreTrainedTokenizer

 vocab_files_names = {
@@ -182,40 +181,15 @@ class MarianTokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        return_tensors: Optional[str] = None,
-        truncation=True,
-        padding="longest",
-        **unused,
-    ) -> BatchEncoding:
-        if "" in src_texts:
-            raise ValueError(f"found empty string in src_texts: {src_texts}")
-        self.current_spm = self.spm_source
-        src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
-        tokenizer_kwargs = dict(
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            truncation=truncation,
-            padding=padding,
-        )
-        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
-        if tgt_texts is None:
-            return model_inputs
-        if max_target_length is not None:
-            tokenizer_kwargs["max_length"] = max_target_length
-        self.current_spm = self.spm_target
-        model_inputs["labels"] = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
-        self.current_spm = self.spm_source
-        return model_inputs
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
+        sequence-to-sequence models that need a slightly different processing for the labels.
+        """
+        self.current_spm = self.spm_target
+        yield
+        self.current_spm = self.spm_source

     @property
     def vocab_size(self) -> int:
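
The Marian diff above is the template for stateful tokenizers: the inline `current_spm` flipping moves into `as_target_tokenizer`. A hedged sketch of the resulting call pattern, assuming a standard Marian checkpoint (model name and sentences illustrative):

from transformers import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
batch = tok(["I love Paris."], return_tensors="pt", padding=True)
# inside the context manager, encoding uses the target-language
# SentencePiece model (tok.spm_target) instead of the source one
with tok.as_target_tokenizer():
    labels = tok(["Ich liebe Paris."], return_tensors="pt", padding=True)
batch["labels"] = labels["input_ids"]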

src/transformers/models/mbart/tokenization_mbart.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from typing import List, Optional

-from ...file_utils import add_start_docstrings
 from ...tokenization_utils import BatchEncoding
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
 from ...utils import logging
 from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
@@ -172,52 +171,28 @@ class MBartTokenizer(XLMRobertaTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         src_lang: str = "en_XX",
         tgt_texts: Optional[List[str]] = None,
         tgt_lang: str = "ro_RO",
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        truncation: bool = True,
-        padding: str = "longest",
-        return_tensors: Optional[str] = None,
-        add_prefix_space: bool = False,  # ignored
         **kwargs,
     ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        self.set_src_lang_special_tokens(src_lang)
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        self.set_tgt_lang_special_tokens(tgt_lang)
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=True,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        self.set_src_lang_special_tokens(src_lang)  # sets to src_lang
-        return model_inputs
+        self.src_lang = src_lang
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self.src_lang)
+        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
+        sequence-to-sequence models that need a slightly different processing for the labels.
+        """
+        self.set_tgt_lang_special_tokens(self.tgt_lang)
+        yield
+        self.set_src_lang_special_tokens(self.src_lang)

     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
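
The mBART override keeps its language-pair arguments but now just stores them and defers to the shared base implementation, which re-enters `as_target_tokenizer` to switch the special tokens to `tgt_lang` and back. A sketch, assuming the usual en-ro checkpoint (texts are the standard mBART doc example, illustrative here):

from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
batch = tok.prepare_seq2seq_batch(
    src_texts=[" UN Chief Says There Is No Military Solution in Syria"],
    src_lang="en_XX",
    tgt_texts=["Şeful ONU declară că nu există o soluţie militară în Siria"],
    tgt_lang="ro_RO",
    return_tensors="pt",  # forwarded through **kwargs to the base implementation
)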

src/transformers/models/mbart/tokenization_mbart_fast.py
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from typing import List, Optional

 from tokenizers import processors

-from ...file_utils import add_start_docstrings, is_sentencepiece_available
+from ...file_utils import is_sentencepiece_available
 from ...tokenization_utils import BatchEncoding
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
 from ...utils import logging
 from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
@@ -171,51 +171,28 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         src_lang: str = "en_XX",
         tgt_texts: Optional[List[str]] = None,
         tgt_lang: str = "ro_RO",
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        truncation: bool = True,
-        padding: str = "longest",
-        return_tensors: str = None,
         **kwargs,
     ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        self.set_src_lang_special_tokens(src_lang)
-        model_inputs: BatchEncoding = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        self.set_tgt_lang_special_tokens(tgt_lang)
-        labels = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=True,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        self.set_src_lang_special_tokens(src_lang)  # sets to src_lang
-        return model_inputs
+        self.src_lang = src_lang
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self.src_lang)
+        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
+        sequence-to-sequence models that need a slightly different processing for the labels.
+        """
+        self.set_tgt_lang_special_tokens(self.tgt_lang)
+        yield
+        self.set_src_lang_special_tokens(self.src_lang)

     def set_src_lang_special_tokens(self, src_lang) -> None:
         """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""

src/transformers/models/pegasus/tokenization_pegasus.py
@@ -18,9 +18,7 @@ from typing import Dict, List, Optional, Tuple
 import sentencepiece as spm

-from ...file_utils import add_start_docstrings
 from ...tokenization_utils import PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
@@ -250,36 +248,6 @@ class PegasusTokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        return_tensors: str = None,
-        truncation=True,
-        padding="longest",
-        **unused,
-    ) -> BatchEncoding:
-        if "" in src_texts:
-            raise ValueError(f"found empty string in src_texts: {src_texts}")
-        tokenizer_kwargs = dict(
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            truncation=truncation,
-            padding=padding,
-        )
-        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
-        if tgt_texts is None:
-            return model_inputs
-        if max_target_length is not None:
-            tokenizer_kwargs["max_length"] = max_target_length
-        labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))

src/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -19,8 +19,7 @@ import os
 from shutil import copyfile
 from typing import List, Optional, Tuple

-from ...file_utils import add_start_docstrings, is_sentencepiece_available
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...file_utils import is_sentencepiece_available
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -188,36 +187,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        return_tensors: str = None,
-        truncation=True,
-        padding="longest",
-        **unused,
-    ) -> BatchEncoding:
-        if "" in src_texts:
-            raise ValueError(f"found empty string in src_texts: {src_texts}")
-        tokenizer_kwargs = dict(
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            truncation=truncation,
-            padding=padding,
-        )
-        model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
-        if tgt_texts is None:
-            return model_inputs
-        if max_target_length is not None:
-            tokenizer_kwargs["max_length"] = max_target_length
-        labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs
-
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))

src/transformers/models/prophetnet/tokenization_prophetnet.py
@@ -17,9 +17,7 @@ import collections
 import os
 from typing import List, Optional, Tuple

-from ...file_utils import add_start_docstrings
-from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
 from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer
@@ -288,43 +286,3 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
             return token_ids_0 + [self.sep_token_id]
         sep = [self.sep_token_id]
         return token_ids_0 + sep + token_ids_1 + sep
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = None,
-        truncation: bool = True,
-        **kwargs,
-    ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels_and_decoder_mask = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )
-        model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
-        return model_inputs

src/transformers/models/rag/tokenization_rag.py
@@ -16,8 +16,7 @@
 import os
 from typing import List, Optional

-from ...file_utils import add_start_docstrings
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
+from ...tokenization_utils_base import BatchEncoding
 from ...utils import logging
 from .configuration_rag import RagConfig
@@ -63,42 +62,18 @@ class RagTokenizer:
     def batch_decode(self, *args, **kwargs):
         return self.generator.batch_decode(*args, **kwargs)

-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = None,
-        truncation=True,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
             max_length = self.question_encoder.model_max_length
-        model_inputs: BatchEncoding = self.question_encoder(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
         if max_target_length is None:
             max_target_length = self.generator.model_max_length
-        labels = self.generator(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )["input_ids"]
-        model_inputs["labels"] = labels
-        return model_inputs
+        return super().prepare_seq2seq_batch(
+            src_texts, tgt_texts, max_length=max_length, max_target_length=max_target_length, **kwargs
+        )

src/transformers/models/t5/tokenization_t5.py
@@ -23,9 +23,7 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm

-from ...file_utils import add_start_docstrings
-from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...tokenization_utils import PreTrainedTokenizer
 from ...utils import logging
@@ -295,43 +293,3 @@ class T5Tokenizer(PreTrainedTokenizer):
         copyfile(self.vocab_file, out_vocab_file)

         return (out_vocab_file,)
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = None,
-        truncation: bool = True,
-        **kwargs,
-    ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        model_inputs = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        labels_and_decoder_mask = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )
-        model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
-        return model_inputs

src/transformers/models/t5/tokenization_t5_fast.py
@@ -19,9 +19,7 @@ import os
 from shutil import copyfile
 from typing import List, Optional, Tuple

-from ...file_utils import add_start_docstrings, is_sentencepiece_available
-from ...tokenization_utils import BatchEncoding
-from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
+from ...file_utils import is_sentencepiece_available
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -212,47 +210,3 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
         if token_ids_1 is None:
             return len(token_ids_0 + eos) * [0]
         return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
-
-    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = None,
-        truncation: bool = True,
-        **kwargs,
-    ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-        self.prefix_tokens = []
-        model_inputs = self(
-            src_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            max_length=max_length,
-            padding=padding,
-            truncation=truncation,
-            **kwargs,
-        )
-        if tgt_texts is None:
-            return model_inputs
-        # Process tgt_texts
-        if max_target_length is None:
-            max_target_length = max_length
-        # set prefix_tokens for target text
-        self.prefix_tokens = [self.pad_token_id]
-        labels_and_decoder_mask = self(
-            tgt_texts,
-            add_special_tokens=True,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_target_length,
-            truncation=truncation,
-            **kwargs,
-        )
-        model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
-        self.prefix_tokens = []
-        return model_inputs
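
This is the "Remove special behavior for T5" bullet from the commit message in action: the fast tokenizer no longer prepends `pad_token_id` to targets via `prefix_tokens`, so labels become plain encodings of `tgt_texts` (T5's model builds decoder inputs by shifting the labels right itself). A sketch of the equivalent post-refactor call, checkpoint illustrative:

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
batch = tok.prepare_seq2seq_batch(
    src_texts=["translate English to German: The house is wonderful."],
    tgt_texts=["Das Haus ist wunderbar."],
    return_tensors="pt",
)
# batch["labels"] is now just the encoded target text; the old behavior of
# prepending the pad token to targets through prefix_tokens is gone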

src/transformers/tokenization_utils.py
@@ -738,80 +738,3 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             return clean_text
         else:
             return text
-
-    def prepare_seq2seq_batch(
-        self,
-        src_texts: List[str],
-        tgt_texts: Optional[List[str]] = None,
-        max_length: Optional[int] = None,
-        max_target_length: Optional[int] = None,
-        padding: str = "longest",
-        return_tensors: str = "None",
-        truncation=True,
-        **kwargs,
-    ) -> BatchEncoding:
-        r"""
-        Prepare a batch that can be passed directly to an instance of :class:`~transformers.AutoModelForSeq2SeqLM`.
-
-        Args:
-            src_texts: (:obj:`List[str]`):
-                List of documents to summarize or source language texts.
-            tgt_texts: (:obj:`List[str]`, `optional`):
-                List of summaries or target language texts.
-            max_length (:obj:`int`, `optional`):
-                Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
-                left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
-                is required by one of the truncation/padding parameters. If the model has no specific maximum input
-                length (like XLNet) truncation/padding to a maximum length will be deactivated.
-            max_target_length (:obj:`int`, `optional`):
-                Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
-                set to :obj:`None`, this will use the max_length value.
-            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
-                Activates and controls padding. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-                  single sequence if provided).
-                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-                  maximum acceptable input length for the model if that argument is not provided.
-                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-                  different lengths).
-            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
-                If set, will return tensors instead of list of python integers. Acceptable values are:
-
-                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
-                Activates and controls truncation. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
-                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
-                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
-                  if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
-                  the maximum acceptable input length for the model if that argument is not provided. This will only
-                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will only
-                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
-                  sequence lengths greater than the model maximum admissible input size).
-            **kwargs:
-                Additional keyword arguments passed along to :obj:`self.__call__`.
-
-        Returns:
-            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
-
-            - **input_ids** -- List of token ids to be fed to the encoder.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
-            - **labels** -- List of token ids for tgt_texts
-
-            The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
-            Otherwise, input_ids, attention_mask will be the only keys.
-        """
-        raise NotImplementedError(
-            "If your model requires more than input_ids for a typical forward pass, you should implement this method. "
-            "Returned keys should be [input_ids, attention_mask, labels]. See MarianTokenizer or T5Tokenizer for a "
-            "reference implementation."
-        )

src/transformers/tokenization_utils_base.py
@@ -23,6 +23,7 @@ import json
 import os
 import warnings
 from collections import OrderedDict, UserDict
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -1473,68 +1474,6 @@ INIT_TOKENIZER_DOCSTRING = r"""
 """

-PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
-    Prepare model inputs for translation. For best performance, translate one sentence at a time.
-
-    Arguments:
-        src_texts (:obj:`List[str]`):
-            List of documents to summarize or source language texts.
-        tgt_texts (:obj:`list`, `optional`):
-            List of summaries or target language texts.
-        max_length (:obj:`int`, `optional`):
-            Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
-            left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
-            is required by one of the truncation/padding parameters. If the model has no specific maximum input
-            length (like XLNet) truncation/padding to a maximum length will be deactivated.
-        max_target_length (:obj:`int`, `optional`):
-            Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
-            to :obj:`None`, this will use the max_length value.
-        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
-            Activates and controls padding. Accepts the following values:
-
-            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-              single sequence if provided).
-            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-              maximum acceptable input length for the model if that argument is not provided.
-            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-              different lengths).
-        return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
-            If set, will return tensors instead of list of python integers. Acceptable values are:
-
-            * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-            * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-            * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-        truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
-            Activates and controls truncation. Accepts the following values:
-
-            * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
-              :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
-              provided. This will truncate token by token, removing a token from the longest sequence in the pair
-              if a pair of sequences (or a batch of pairs) is provided.
-            * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
-              the maximum acceptable input length for the model if that argument is not provided. This will only
-              truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-            * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-              to the maximum acceptable input length for the model if that argument is not provided. This will only
-              truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-            * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
-              sequence lengths greater than the model maximum admissible input size).
-        **kwargs:
-            Additional keyword arguments passed along to :obj:`self.__call__`.
-
-    Return:
-        :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
-
-        - **input_ids** -- List of token ids to be fed to the encoder.
-        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
-        - **labels** -- List of token ids for tgt_texts.
-
-        The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
-        Otherwise, input_ids, attention_mask will be the only keys.
-"""
-

 @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
 class PreTrainedTokenizerBase(SpecialTokensMixin):
     """
@@ -3252,3 +3191,113 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 "indexing errors".format(len(ids), self.model_max_length)
             )
             self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
+        sequence-to-sequence models that need a slightly different processing for the labels.
+        """
+        yield
+
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: List[str],
+        tgt_texts: Optional[List[str]] = None,
+        max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: str = None,
+        truncation: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepare model inputs for translation. For best performance, translate one sentence at a time.
+
+        Arguments:
+            src_texts (:obj:`List[str]`):
+                List of documents to summarize or source language texts.
+            tgt_texts (:obj:`list`, `optional`):
+                List of summaries or target language texts.
+            max_length (:obj:`int`, `optional`):
+                Controls the maximum length for encoder inputs (documents to summarize or source language texts) If
+                left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum length
+                is required by one of the truncation/padding parameters. If the model has no specific maximum input
+                length (like XLNet) truncation/padding to a maximum length will be deactivated.
+            max_target_length (:obj:`int`, `optional`):
+                Controls the maximum length of decoder inputs (target language texts or summaries) If left unset or set
+                to :obj:`None`, this will use the max_length value.
+            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
+                Activates and controls padding. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+                  single sequence if provided).
+                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided.
+                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+                  different lengths).
+            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
+            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
+                Activates and controls truncation. Accepts the following values:
+
+                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
+                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
+                  if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
+                  the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
+                  sequence lengths greater than the model maximum admissible input size).
+            **kwargs:
+                Additional keyword arguments passed along to :obj:`self.__call__`.
+
+        Return:
+            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to the encoder.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+            - **labels** -- List of token ids for tgt_texts.
+
+            The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
+            Otherwise, input_ids, attention_mask will be the only keys.
+        """
+        # mBART-specific kwargs that should be ignored by other models.
+        kwargs.pop("src_lang", None)
+        kwargs.pop("tgt_lang", None)
+        if max_length is None:
+            max_length = self.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
+        with self.as_target_tokenizer():
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
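
The block above is now the single shared code path: encode sources, then encode targets inside `as_target_tokenizer()`. A subclass only has to override the context manager. A minimal standalone sketch of that contract, with a hypothetical class standing in for Marian's two-SentencePiece-model setup:

from contextlib import contextmanager

class TwoVocabTokenizer:
    """Hypothetical tokenizer with separate source/target state, mirroring MarianTokenizer."""

    def __init__(self, source_vocab, target_vocab):
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.current_vocab = source_vocab  # sources are encoded by default

    @contextmanager
    def as_target_tokenizer(self):
        # switch to target-side state for the duration of the with-block
        self.current_vocab = self.target_vocab
        yield
        self.current_vocab = self.source_vocab

Note the committed context managers yield bare; restoring state in a `finally` clause would also survive an exception raised while encoding targets, but the PR keeps the simpler form.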

tests/test_modeling_marian.py
@@ -508,12 +508,6 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
     def test_batch_generation_en_ROMANCE_multi(self):
         self._assert_generated_batch_equal_expected()

-    def test_tokenizer_handles_empty(self):
-        normalized = self.tokenizer.normalize("")
-        self.assertIsInstance(normalized, str)
-        with self.assertRaises(ValueError):
-            self.tokenizer.prepare_seq2seq_batch([""], return_tensors="pt")
-
     @slow
     def test_pipeline(self):
         device = 0 if torch_device == "cuda" else -1

tests/test_tokenization_common.py
@@ -83,6 +83,7 @@ class TokenizerTesterMixin:
     from_pretrained_kwargs = None
     from_pretrained_filter = None
     from_pretrained_vocab_key = "vocab_file"
+    test_seq2seq = True

     def setUp(self) -> None:
         # Tokenizer.filter makes it possible to filter which Tokenizer to case based on all the
@@ -1799,10 +1800,11 @@ class TokenizerTesterMixin:
     @require_torch
     def test_prepare_seq2seq_batch(self):
+        if not self.test_seq2seq:
+            return
         tokenizer = self.get_tokenizer()
-        if not hasattr(tokenizer, "prepare_seq2seq_batch"):
-            return
+
         # Longer text that will definitely require truncation.
         src_text = [
             " UN Chief Says There Is No Military Solution in Syria",

tests/test_tokenization_ctrl.py
@@ -26,6 +26,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = CTRLTokenizer
     test_rust_tokenizer = False
+    test_seq2seq = False

     def setUp(self):
         super().setUp()

tests/test_tokenization_gpt2.py
@@ -32,6 +32,7 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     rust_tokenizer_class = GPT2TokenizerFast
     test_rust_tokenizer = True
     from_pretrained_kwargs = {"add_prefix_space": True}
+    test_seq2seq = False

     def setUp(self):
         super().setUp()