Unverified Commit 36434220 authored by Anthony MOI, committed by GitHub

[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests (#4510)

* Use tokenizers pre-tokenized pipeline

* failing pretokenized test

* Fix is_pretokenized in python

* add pretokenized tests

* style and quality

* better tests for batched pretokenized inputs

* tokenizers clean up - new padding_strategy - split the files

* [HUGE] refactoring tokenizers - padding - truncation - tests

* style and quality

* bump up required tokenizers version to 0.8.0-rc1

* switched padding/truncation API - simpler, better backward compat (usage sketch after the commit message below)

* updating tests for custom tokenizers

* style and quality - tests on pad

* fix QA pipeline

* fix backward compatibility for max_length only

* style and quality

* Various clean-ups - add verbose

* fix tests

* update docstrings

* Fix tests

* Docs reformatted

* __call__ method documented
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent ebba39e4
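
The bullets above mention the reworked padding/truncation arguments, the pre-tokenized input path, and the newly documented __call__ method. A minimal usage sketch of that surface as it looks after this change; the checkpoint name and example sentences are illustrative only, not taken from the diff:

    from transformers import AutoTokenizer

    # Illustrative checkpoint; any pretrained tokenizer exposes the same API.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # The documented __call__ entry point accepts the unified padding/truncation arguments.
    batch = tokenizer(
        ["Hello world", "A slightly longer second sentence"],
        padding="longest",   # or True / "max_length" / False
        truncation=True,     # or "longest_first" / "only_first" / "only_second"
        max_length=32,
        return_tensors="pt",
    )

    # Pre-tokenized (already word-split) inputs go through the same call.
    batch_pretok = tokenizer(
        [["Hello", "world"], ["Another", "example"]],
        is_pretokenized=True,
        padding=True,
    )
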
@@ -17,12 +17,14 @@ The base classes ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` implement
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.PreTrainedTokenizer
+    :special-members: __call__
     :members:
 
 ``PreTrainedTokenizerFast``
-~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.PreTrainedTokenizerFast
+    :special-members: __call__
    :members:
 
 ``BatchEncoding``
......
@@ -3,8 +3,6 @@ import os
 import torch
 from torch.utils.data import Dataset
 
-from transformers.tokenization_utils import trim_batch
 
 def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
     examples = []
@@ -17,6 +15,17 @@ def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
     return examples
 
+def trim_batch(
+    input_ids, pad_token_id, attention_mask=None,
+):
+    """Remove columns that are populated exclusively by pad_token_id"""
+    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
+    if attention_mask is None:
+        return input_ids[:, keep_column_mask]
+    else:
+        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
+
 class SummarizationDataset(Dataset):
     def __init__(
         self,
......
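
A brief usage sketch for the trim_batch helper moved into this module above; the tensors and pad id are made up for illustration:

    import torch

    # Batch padded to length 8; the longest real sequence uses only 5 columns.
    pad_token_id = 0
    input_ids = torch.tensor([
        [101, 7592, 2088, 102, 0, 0, 0, 0],
        [101, 2023, 2003, 1037, 102, 0, 0, 0],
    ])
    attention_mask = input_ids.ne(pad_token_id).long()

    # Drop the columns that contain nothing but padding.
    trimmed_ids, trimmed_mask = trim_batch(input_ids, pad_token_id, attention_mask=attention_mask)
    print(trimmed_ids.shape)  # torch.Size([2, 5])
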
@@ -108,7 +108,7 @@ setup(
     packages=find_packages("src"),
     install_requires=[
         "numpy",
-        "tokenizers == 0.7.0",
+        "tokenizers == 0.8.0-rc1",
         # dataclasses for Python versions that don't have it
         "dataclasses;python_version<'3.7'",
         # utilities from PyPA to e.g. compare versions
......
@@ -133,13 +133,16 @@ from .tokenization_reformer import ReformerTokenizer
 from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from .tokenization_t5 import T5Tokenizer
 from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
-from .tokenization_utils import (
-    BatchEncoding,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerFast,
-    SpecialTokensMixin,
-    TensorType,
-)
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_base import (
+    BatchEncoding,
+    CharSpan,
+    PreTrainedTokenizerBase,
+    SpecialTokensMixin,
+    TensorType,
+    TokenSpan,
+)
+from .tokenization_utils_fast import PreTrainedTokenizerFast
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
 from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
......
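
CharSpan and TokenSpan, newly exported above, are the named tuples returned by the alignment helpers on a fast tokenizer's BatchEncoding. A hedged illustration; the checkpoint name and input text are only examples:

    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    encoding = tokenizer("Hello world")  # BatchEncoding backed by a fast (Rust) tokenizer

    # token_to_chars returns a CharSpan(start, end) into the original string;
    # char_to_token maps a character offset back to a token index.
    char_span = encoding.token_to_chars(1)   # span of the first non-special token
    token_index = encoding.char_to_token(0)  # token covering character 0
    print(char_span, token_index)
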
@@ -1213,7 +1213,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
         model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
         choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
-        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
+        input_ids = tokenizer(choices, add_special_tokens=True, return_tensors='tf', truncation=True, padding=True)[None, :]  # Batch size 1, 2 choices
         labels = tf.reshape(tf.constant(1), (-1, 1))
         outputs = model(input_ids, labels=labels)
......
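
The docstring change above folds the example into the new tokenizer __call__. One plausible way to write the same multiple-choice example end to end, expanding dimensions on the input_ids tensor rather than on the BatchEncoding; treat this as a sketch following the docstring, not the canonical docs:

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForMultipleChoice

    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')

    choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
    encoded = tokenizer(choices, padding=True, truncation=True, return_tensors='tf')

    # Reshape from (num_choices, seq_len) to (batch_size=1, num_choices, seq_len).
    input_ids = tf.expand_dims(encoded["input_ids"], 0)
    labels = tf.reshape(tf.constant(1), (-1, 1))  # index of the "correct" choice
    outputs = model(input_ids, labels=labels)
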
@@ -23,7 +23,8 @@ from typing import List, Optional
 from tokenizers import BertWordPieceTokenizer
 
-from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
 
 logger = logging.getLogger(__name__)
......
@@ -23,7 +23,9 @@ from functools import lru_cache
 import regex as re
 from tokenizers import ByteLevelBPETokenizer
 
-from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_base import BatchEncoding
+from .tokenization_utils_fast import PreTrainedTokenizerFast
 
 logger = logging.getLogger(__name__)
@@ -346,3 +348,24 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
             unk_token=unk_token,
             **kwargs,
         )
+        self.add_prefix_space = add_prefix_space
+
+    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_pretokenized = kwargs.get("is_pretokenized", False)
+        assert self.add_prefix_space or not is_pretokenized, (
+            "You need to instantiate GPT2TokenizerFast with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+        return super()._batch_encode_plus(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_pretokenized = kwargs.get("is_pretokenized", False)
+        assert self.add_prefix_space or not is_pretokenized, (
+            "You need to instantiate GPT2TokenizerFast with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+        return super()._encode_plus(*args, **kwargs)
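
A short usage sketch for the guard added above, assuming the standard gpt2 checkpoint: word-split input with is_pretokenized=True only works when the tokenizer was created with add_prefix_space=True.

    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)

    words = ["Hello", "world", "!"]  # already split into words
    encoded = tokenizer(words, is_pretokenized=True)
    print(encoded["input_ids"])

    # With the default add_prefix_space=False, the same call trips the assertion above.
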
@@ -23,7 +23,8 @@ import re
 from tokenizers import CharBPETokenizer
 
 from .tokenization_bert import BasicTokenizer
-from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
 
 logger = logging.getLogger(__name__)
......
@@ -35,7 +35,8 @@ from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
 from tokenizers.processors import BertProcessing
 
 from .file_utils import cached_path, is_torch_available
-from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_fast import PreTrainedTokenizerFast
 
 if is_torch_available():
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
@@ -51,7 +51,7 @@ class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def get_tokenizer(self, **kwargs):
         return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
-    def get_input_output_texts(self):
+    def get_input_output_texts(self, tokenizer):
         input_text = "UNwant\u00E9d,running"
         output_text = "unwanted, running"
         return input_text, output_text
......
@@ -36,7 +36,7 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_input_output_texts(self):
+    def get_input_output_texts(self, tokenizer):
         input_text = "this is a test"
         output_text = "this is a test"
         return input_text, output_text
......
@@ -44,6 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "[UNK]",
             "[CLS]",
             "[SEP]",
+            "[PAD]",
+            "[MASK]",
             "want",
             "##want",
             "##ed",
@@ -62,7 +64,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def get_rust_tokenizer(self, **kwargs):
         return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
 
-    def get_input_output_texts(self):
+    def get_input_output_texts(self, tokenizer):
         input_text = "UNwant\u00E9d,running"
         output_text = "unwanted, running"
         return input_text, output_text
@@ -72,7 +74,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokens = tokenizer.tokenize("UNwant\u00E9d,running")
         self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
 
     def test_rust_and_python_full_tokenizers(self):
         if not self.test_rust_tokenizer:
@@ -96,6 +98,25 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         rust_ids = rust_tokenizer.encode(sequence)
         self.assertListEqual(ids, rust_ids)
 
+        # With lower casing
+        tokenizer = self.get_tokenizer(do_lower_case=True)
+        rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
+
+        sequence = "UNwant\u00E9d,running"
+
+        tokens = tokenizer.tokenize(sequence)
+        rust_tokens = rust_tokenizer.tokenize(sequence)
+        self.assertListEqual(tokens, rust_tokens)
+
+        ids = tokenizer.encode(sequence, add_special_tokens=False)
+        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
+        self.assertListEqual(ids, rust_ids)
+
+        rust_tokenizer = self.get_rust_tokenizer()
+        ids = tokenizer.encode(sequence)
+        rust_ids = rust_tokenizer.encode(sequence)
+        self.assertListEqual(ids, rust_ids)
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
......
This diff is collapsed.
This diff is collapsed.
@@ -46,7 +46,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         kwargs.update(self.special_tokens_map)
         return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
-    def get_input_output_texts(self):
+    def get_input_output_texts(self, tokenizer):
         input_text = "adapt react readapt apt"
         output_text = "adapt react readapt apt"
         return input_text, output_text
......
This diff is collapsed.
This diff is collapsed.