Unverified commit 0a8c17d5, authored by Suraj Patil, committed by GitHub

[T5Tokenizer] remove prefix_tokens (#7078)

parent 4cbd50e6
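
The change in a nutshell: T5Tokenizer previously kept a mutable class attribute prefix_tokens that prepare_seq2seq_batch temporarily set to [pad_token_id] while encoding target texts, so labels came back with a leading decoder-start token (id 0). This commit removes the attribute; labels now contain only the target token ids plus EOS. A minimal sketch of the new behavior (not part of the diff; assumes the t5-small checkpoint and PyTorch tensors):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["A long paragraph for summarization."],
    tgt_texts=["Summary of the text."],
    return_tensors="pt",
)
# After this commit the labels no longer start with pad_token_id (0);
# they hold only the target token ids followed by EOS (1),
# e.g. tensor([[20698, 13, 8, 1499, 5, 1]]) rather than a 0-prefixed sequence.
print(batch["labels"])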
@@ -96,8 +96,6 @@ class T5Tokenizer(PreTrainedTokenizer):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     model_input_names = ["attention_mask"]
 
-    prefix_tokens: List[int] = []
-
     def __init__(
         self,
         vocab_file,
@@ -210,10 +208,10 @@ class T5Tokenizer(PreTrainedTokenizer):
         """
         token_ids_0 = self._add_eos_if_not_present(token_ids_0)
         if token_ids_1 is None:
-            return self.prefix_tokens + token_ids_0
+            return token_ids_0
         else:
             token_ids_1 = self._add_eos_if_not_present(token_ids_1)
-            return self.prefix_tokens + token_ids_0 + token_ids_1
+            return token_ids_0 + token_ids_1
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -343,7 +341,6 @@ class T5Tokenizer(PreTrainedTokenizer):
         """
         if max_length is None:
             max_length = self.max_len
-        self.prefix_tokens = []
         model_inputs = self(
             src_texts,
             add_special_tokens=True,
@@ -358,8 +355,6 @@ class T5Tokenizer(PreTrainedTokenizer):
         # Process tgt_texts
         if max_target_length is None:
             max_target_length = max_length
-        # set prefix_tokens for target text
-        self.prefix_tokens = [self.pad_token_id]
         labels_and_decoder_mask = self(
             tgt_texts,
             add_special_tokens=True,
@@ -370,5 +365,4 @@ class T5Tokenizer(PreTrainedTokenizer):
             **kwargs,
         )
         model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
-        self.prefix_tokens = []
         return model_inputs
@@ -139,9 +139,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual((2, 9), batch.input_ids.shape)
         self.assertEqual((2, 9), batch.attention_mask.shape)
 
-        # Test that special tokens are reset
-        self.assertEqual(tokenizer.prefix_tokens, [])
-
     def test_empty_target_text(self):
         tokenizer = self.t5_base_tokenizer
         src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
@@ -184,7 +181,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         src_text = ["A long paragraph for summarization. </s>"]
         tgt_text = ["Summary of the text. </s>"]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
-        expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]
+        expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
         batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
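
The new expected_tgt_tokens can be reproduced outside the test suite; a quick sanity check, assuming the t5-base checkpoint that the test class wraps:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
ids = tokenizer("Summary of the text. </s>").input_ids
assert ids == [20698, 13, 8, 1499, 5, 1]  # no leading pad id (0) any more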