Unverified Commit 36434220 authored by Anthony MOI, committed by GitHub

[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests (#4510)

* Use tokenizers pre-tokenized pipeline

* failing pretokenized test

* Fix is_pretokenized in python

* add pretokenized tests

* style and quality

* better tests for batched pretokenized inputs

* tokenizers clean up - new padding_strategy - split the files

* [HUGE] refactoring tokenizers - padding - truncation - tests

* style and quality

* bump up required tokenizers version to 0.8.0-rc1

* switched padding/truncation API - simpler, better backward compat

* updating tests for custom tokenizers

* style and quality - tests on pad

* fix QA pipeline

* fix backward compatibility for max_length only

* style and quality

* Various cleanups - add verbose

* fix tests

* update docstrings

* Fix tests

* Docs reformatted

* __call__ method documented
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent ebba39e4
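
As context for the padding/truncation API switch listed in the commit messages above, here is a minimal usage sketch of the unified tokenizer call this refactor introduces; the checkpoint name and sample sentences are illustrative assumptions, not taken from the commit.

from transformers import AutoTokenizer

# Any checkpoint works here; "bert-base-uncased" is only an example.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

sentences = ["Hello world", "A slightly longer second sentence that needs padding"]

# New-style call: padding/truncation are explicit strategies instead of the
# older pad_to_max_length / max_length-only flags.
batch = tokenizer(
    sentences,
    padding=True,       # pad to the longest sequence in the batch
    truncation=True,    # truncate to max_length (or the model maximum)
    max_length=16,
    return_tensors="pt",
)
print(batch["input_ids"].shape)    # 2 x longest-in-batch, capped at max_length=16
print(batch["attention_mask"][0])

# Pre-tokenized input goes through the same entry point.
pretok = tokenizer([["Hello", "world"]], is_pretokenized=True, padding=True)
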
......@@ -17,12 +17,14 @@ The base classes ``PreTrainedTokenizer`` and ``PreTrainedTokenizerFast`` impleme
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedTokenizer
:special-members: __call__
:members:
``PreTrainedTokenizerFast``
~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedTokenizerFast
:special-members: __call__
:members:
``BatchEncoding``
......
......@@ -3,8 +3,6 @@ import os
import torch
from torch.utils.data import Dataset
from transformers.tokenization_utils import trim_batch
def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
examples = []
......@@ -17,6 +15,17 @@ def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return
return examples
def trim_batch(
input_ids, pad_token_id, attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
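
A quick illustration of what trim_batch above does to a right-padded batch; the id values and pad_token_id=0 are made up for the sketch, and the trim_batch definition shown above is assumed to be in scope.

import torch

# Invented example ids; a real script would use tokenizer.pad_token_id instead of 0.
input_ids = torch.tensor([
    [101, 7592, 102, 0, 0],
    [101, 2088, 999, 102, 0],
])
attention_mask = (input_ids != 0).long()

# The last column is all padding, so it is dropped; every other column is kept.
trimmed_ids, trimmed_mask = trim_batch(input_ids, pad_token_id=0, attention_mask=attention_mask)
print(trimmed_ids.shape)  # torch.Size([2, 4])
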
class SummarizationDataset(Dataset):
def __init__(
self,
......
......@@ -108,7 +108,7 @@ setup(
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.7.0",
"tokenizers == 0.8.0-rc1",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# utilities from PyPA to e.g. compare versions
......
......@@ -133,13 +133,16 @@ from .tokenization_reformer import ReformerTokenizer
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
from .tokenization_utils import (
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import (
BatchEncoding,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
CharSpan,
PreTrainedTokenizerBase,
SpecialTokensMixin,
TensorType,
TokenSpan,
)
from .tokenization_utils_fast import PreTrainedTokenizerFast
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
......
......@@ -1213,7 +1213,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
model = TFAlbertForMultipleChoice.from_pretrained('albert-base-v2')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
input_ids = tokenizer(choices, add_special_tokens=True, return_tensors='tf', truncation=True, padding=True)["input_ids"][None, :] # Batch size 1, 2 choices
labels = tf.reshape(tf.constant(1), (-1, 1))
outputs = model(input_ids, labels=labels)
......
......@@ -23,7 +23,8 @@ from typing import List, Optional
from tokenizers import BertWordPieceTokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
......
......@@ -23,7 +23,9 @@ from functools import lru_cache
import regex as re
from tokenizers import ByteLevelBPETokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import BatchEncoding
from .tokenization_utils_fast import PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
......@@ -346,3 +348,24 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
unk_token=unk_token,
**kwargs,
)
self.add_prefix_space = add_prefix_space
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
"to use it with pretokenized inputs."
)
return super()._batch_encode_plus(*args, **kwargs)
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
"to use it with pretokenized inputs."
)
return super()._encode_plus(*args, **kwargs)
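
The guard above means pretokenized input to the fast GPT-2 tokenizer only works when the tokenizer was built with add_prefix_space=True; a short sketch of the intended call pattern follows (the "gpt2" checkpoint name is assumed for illustration).

from transformers import GPT2TokenizerFast

# add_prefix_space=True so that pre-split words are encoded the same way as
# words that appear mid-sentence in raw text.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)

words = ["Hello", "world", "!"]
encoding = tokenizer(words, is_pretokenized=True)
print(encoding["input_ids"])

# Instantiating without add_prefix_space=True and passing is_pretokenized=True
# would trip the assertion in _encode_plus / _batch_encode_plus above.
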
......@@ -23,7 +23,8 @@ import re
from tokenizers import CharBPETokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
logger = logging.getLogger(__name__)
......
......@@ -35,7 +35,8 @@ from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
from tokenizers.processors import BertProcessing
from .file_utils import cached_path, is_torch_available
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast
if is_torch_available():
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -51,7 +51,7 @@ class XxxTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_tokenizer(self, **kwargs):
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
return input_text, output_text
......
......@@ -36,7 +36,7 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "this is a test"
output_text = "this is a test"
return input_text, output_text
......
......@@ -44,6 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"[UNK]",
"[CLS]",
"[SEP]",
"[PAD]",
"[MASK]",
"want",
"##want",
"##ed",
......@@ -62,7 +64,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
return input_text, output_text
......@@ -72,7 +74,7 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
......@@ -96,6 +98,25 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# With lower casing
tokenizer = self.get_tokenizer(do_lower_case=True)
rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)
sequence = "UNwant\u00E9d,running"
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
def test_chinese(self):
tokenizer = BasicTokenizer()
......
......@@ -60,11 +60,26 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
return input_text, output_text
def get_clean_sequence(self, tokenizer):
input_text, output_text = self.get_input_output_texts(tokenizer)
ids = tokenizer.encode(output_text, add_special_tokens=False)
text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
return text, ids
def test_pretokenized_inputs(self):
pass # TODO add if relevant
def test_maximum_encoding_length_pair_input(self):
pass # TODO add if relevant
def test_maximum_encoding_length_single_input(self):
pass # TODO add if relevant
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file)
......@@ -157,11 +172,20 @@ class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestC
def get_tokenizer(self, **kwargs):
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
return input_text, output_text
def test_pretokenized_inputs(self):
pass # TODO add if relevant
def test_maximum_encoding_length_pair_input(self):
pass # TODO add if relevant
def test_maximum_encoding_length_single_input(self):
pass # TODO add if relevant
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")
......
This diff is collapsed.
......@@ -46,7 +46,7 @@ class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
kwargs.update(self.special_tokens_map)
return CTRLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
def get_input_output_texts(self, tokenizer):
input_text = "adapt react readapt apt"
output_text = "adapt react readapt apt"
return input_text, output_text
......
This diff is collapsed.
This diff is collapsed.