Unverified Commit 62449570 authored by Sam Shleifer, committed by GitHub

T5Tokenizer adds EOS token if not already added (#5866)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e11d923b
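
In short: encoding text with ``T5Tokenizer`` now appends the EOS token ``</s>`` (id 1 in the pretrained T5 vocab, as the updated tests show) unless the text already ends with it, so inputs with and without a trailing ``</s>`` produce identical ids. A minimal sketch of the new behavior, mirroring the tests below (the checkpoint name and example strings are illustrative only):

    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-base")

    # With this change, EOS is added exactly once and never duplicated.
    with_eos = tokenizer(["I went to the gym</s>"])["input_ids"]
    without_eos = tokenizer(["I went to the gym"])["input_ids"]
    assert with_eos == without_eos
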
@@ -18,6 +18,7 @@
import logging
import os
import re
import warnings
from shutil import copyfile
from typing import List, Optional
@@ -148,6 +149,74 @@ class T5Tokenizer(PreTrainedTokenizer):
        vocab.update(self.added_tokens_encoder)
        return vocab

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of ids.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Set to True if the token list is already formatted with special tokens for the model

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1], 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        # normal case: some special tokens
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
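
    # Illustrative worked example (not part of the original commit), using made-up token ids:
    # with the single-sequence format ``X </s>``, only the trailing EOS position is special.
    #   get_special_tokens_mask([71, 307, 8986])                -> [0, 0, 0, 1]
    #   get_special_tokens_mask([71, 307], token_ids_1=[4505])  -> [0, 0, 1, 0, 1]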

    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
            warnings.warn(
                f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
            )
            return token_ids
        else:
            return token_ids + [self.eos_token_id]
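
    # Illustrative worked example (not part of the original commit): with T5's eos_token_id
    # of 1 (see the expected ids in the updated tests below), EOS is appended exactly once.
    #   _add_eos_if_not_present([71, 307, 8986])     -> [71, 307, 8986, 1]
    #   _add_eos_if_not_present([71, 307, 8986, 1])  -> [71, 307, 8986, 1]  (emits a warning)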

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        For some T5 tasks, ``model.config.prefix`` is specified; it must be prepended to the text before tokenization.
        A sequence has the following format:

        - single sequence: ``X </s>``
        - pair of sequences: ``A </s> B </s>``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        token_ids_0 = self._add_eos_if_not_present(token_ids_0)
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0
        else:
            token_ids_1 = self._add_eos_if_not_present(token_ids_1)
            return self.prefix_tokens + token_ids_0 + token_ids_1
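
    # Illustrative worked example (not part of the original commit), assuming eos_token_id == 1
    # and an empty self.prefix_tokens (the updated tests below assert it is reset to []):
    #   build_inputs_with_special_tokens([71, 307])          -> [71, 307, 1]
    #   build_inputs_with_special_tokens([71, 307], [4505])  -> [71, 307, 1, 4505, 1]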

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
@@ -210,31 +279,6 @@ class T5Tokenizer(PreTrainedTokenizer):
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens. The special tokens depend on calling source text or target text.
        A T5 sequence has the following format, where ``X`` represents the sequence:

        - ``input_ids`` (for encoder) ``X [eos]``
        - ``decoder_input_ids``: (for decoder) ``[pad] X [eos]``

        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
......
@@ -18,6 +18,7 @@ import os
import unittest
from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
@@ -107,28 +108,37 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
            ],
        )

    @cached_property
    def t5_base_tokenizer(self):
        return T5Tokenizer.from_pretrained("t5-base")

    def test_eos_treatment(self):
        tokenizer = self.t5_base_tokenizer
        batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
        batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
        self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])

    def test_prepare_seq2seq_batch(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
        )
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK,)
        self.assertIsInstance(batch, BatchEncoding)
        self.assertEqual((2, 9), batch.input_ids.shape)
        self.assertEqual((2, 9), batch.attention_mask.shape)
        result = list(batch.input_ids.numpy()[0])
        self.assertListEqual(expected_src_tokens, result)
        self.assertEqual((2, 10), batch.input_ids.shape)
        self.assertEqual((2, 10), batch.attention_mask.shape)
        # Test that special tokens are reset
        self.assertEqual(tokenizer.prefix_tokens, [])

    def test_empty_target_text(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
        # check if input_ids are returned and no decoder_input_ids
@@ -138,7 +148,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_target_length(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
@@ -158,7 +168,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

    def test_outputs_not_longer_than_maxlen(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        tokenizer = self.t5_base_tokenizer
        batch = tokenizer.prepare_seq2seq_batch(
            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
@@ -167,7 +177,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        self.assertEqual(batch.input_ids.shape, (2, 512))

    def test_eos_in_input(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summrization. </s>"]
        tgt_text = ["Summary of the text. </s>"]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
......