Unverified commit 62449570 authored by Sam Shleifer and committed by GitHub

T5Tokenizer adds EOS token if not already added (#5866)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e11d923b
@@ -18,6 +18,7 @@
import logging
import os
import re
import warnings
from shutil import copyfile
from typing import List, Optional
@@ -148,6 +149,74 @@ class T5Tokenizer(PreTrainedTokenizer):
vocab.update(self.added_tokens_encoder)
return vocab
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of ids.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Set to True if the token list is already formatted with special tokens for the model
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1], 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
# normal case: some special tokens
if token_ids_1 is None:
return ([0] * len(token_ids_0)) + [1]
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
)
return token_ids
else:
return token_ids + [self.eos_token_id]
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks
by concatenating and adding special tokens.
For some T5 tasks, model.config.prefix is specified; this prefix must be added to the input text before tokenization.
A sequence has the following format:
- single sequence: ``X </s>``
- pair of sequences: ``A </s> B </s>``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0
else:
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
return self.prefix_tokens + token_ids_0 + token_ids_1
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
@@ -210,31 +279,6 @@ class T5Tokenizer(PreTrainedTokenizer):
return (out_vocab_file,)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling source text or target text.
A T5 sequence has the following format, where ``X`` represents the sequence:
- ``input_ids`` (for encoder) ``X [eos]``
- ``decoder_input_ids``: (for decoder) ``[pad] X [eos]``
Pairs of sequences are not the expected use case, but they will be handled without a separator.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1
def prepare_seq2seq_batch(
self,
src_texts: List[str],
...
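Taken together, the new `_add_eos_if_not_present` and `build_inputs_with_special_tokens` mean the slow T5 tokenizer now always produces input ids that end in exactly one EOS token, whether or not the user wrote `</s>` themselves. A minimal sketch of the resulting behavior (assuming a transformers version that includes this change and that the `t5-base` files can be downloaded):

```python
# Minimal sketch of the new behavior; assumes a transformers version that
# includes this change and that "t5-base" can be downloaded.
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")

ids_plain = tokenizer("I went to the gym").input_ids
ids_with_eos = tokenizer("I went to the gym</s>").input_ids  # emits the new warning

# Both encodings end with exactly one EOS id; nothing is duplicated.
assert ids_plain == ids_with_eos
assert ids_plain[-1] == tokenizer.eos_token_id
```

This mirrors `test_eos_treatment` in the test diff below, which checks the same equivalence on a small batch.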
@@ -18,6 +18,7 @@ import os
import unittest
from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
@@ -107,28 +108,37 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
],
)
@cached_property
def t5_base_tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
def test_eos_treatment(self):
tokenizer = self.t5_base_tokenizer
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
def test_prepare_seq2seq_batch(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK,)
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual((2, 9), batch.input_ids.shape)
self.assertEqual((2, 9), batch.attention_mask.shape)
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
# Test that special tokens are reset
self.assertEqual(tokenizer.prefix_tokens, [])
def test_empty_target_text(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
@@ -138,7 +148,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertNotIn("decoder_attention_mask", batch)
def test_max_target_length(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
@@ -158,7 +168,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
def test_outputs_not_longer_than_maxlen(self):
tokenizer = self.t5_base_tokenizer
batch = tokenizer.prepare_seq2seq_batch(
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
@@ -167,7 +177,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(batch.input_ids.shape, (2, 512))
def test_eos_in_input(self):
tokenizer = self.t5_base_tokenizer
src_text = ["A long paragraph for summrization. </s>"]
tgt_text = ["Summary of the text. </s>"]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
...
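The pair-handling branch added to `build_inputs_with_special_tokens` and `get_special_tokens_mask` is not exercised by these tests. A hypothetical illustration of what it produces, assuming the default empty `prefix_tokens` (i.e. outside `prepare_seq2seq_batch`, which the tests show resets `prefix_tokens` to `[]`):

```python
# Hypothetical example, not part of the committed tests; assumes
# tokenizer.prefix_tokens has its default value [].
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")

ids_a = tokenizer.encode("Translate English to German: Hello.", add_special_tokens=False)
ids_b = tokenizer.encode("Hallo.", add_special_tokens=False)

pair = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)
mask = tokenizer.get_special_tokens_mask(ids_a, ids_b)

# Each sequence gets one EOS appended: A </s> B </s>
assert pair == ids_a + [tokenizer.eos_token_id] + ids_b + [tokenizer.eos_token_id]
# The mask marks exactly the two EOS positions.
assert mask == [0] * len(ids_a) + [1] + [0] * len(ids_b) + [1]
```

As the docstring notes, sequence pairs are not the expected T5 use case; the branch exists mainly for API consistency with other tokenizers.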