Unverified Commit cc6775cd authored by Funtowicz Morgan, committed by GitHub

Fix max_length not taken into account when using pad_to_max_length on fast tokenizers (#2961)



* enable_padding should pad up to max_length if set.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added more testing on padding.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 94ff2d6e
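
The patch routes the user-supplied max_length into the fast tokenizers' padding configuration, so pad_to_max_length behaves the same as it does for the slow Python tokenizers. A minimal sketch of the behaviour the fix targets, assuming a transformers release from this era and the "bert-base-uncased" checkpoint (an assumption made for illustration; any slow/fast tokenizer pair with a pad token would do):

    from transformers import BertTokenizer, BertTokenizerFast

    slow = BertTokenizer.from_pretrained("bert-base-uncased")      # pure-Python tokenizer
    fast = BertTokenizerFast.from_pretrained("bert-base-uncased")  # Rust-backed fast tokenizer

    ids_slow = slow.encode("This is a simple input", max_length=15, pad_to_max_length=True)
    ids_fast = fast.encode("This is a simple input", max_length=15, pad_to_max_length=True)

    # Before this commit the fast tokenizer configured its backend with
    # max_length=None, so ids_fast was not padded up to 15 tokens.
    # With the fix, both tokenizers return identical 15-token sequences.
    assert len(ids_slow) == 15
    assert ids_fast == ids_slow
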
@@ -79,7 +79,7 @@ def truncate_and_pad(
     if pad_to_max_length and (pad_token and pad_token_id >= 0):
         tokenizer.enable_padding(
-            max_length=None,
+            max_length=max_length,
             direction=padding_side,
             pad_id=pad_token_id,
             pad_type_id=pad_token_type_id,
...
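
For context on the one-line change above: the backing tokenizers object only pads to a fixed size when enable_padding is given one; with max_length=None it pads a batch up to its longest member, so a single encoded sequence is not padded at all. An annotated restatement of the call, using only the keyword names visible in the hunk (the behaviour notes are a hedged reading of the tokenizers library of this era, not taken from the diff):

    tokenizer.enable_padding(
        max_length=max_length,          # was hard-coded to None, so the caller's
                                        # max_length never reached the backend
        direction=padding_side,         # pad on the "left" or "right"
        pad_id=pad_token_id,            # id inserted for padded positions
        pad_type_id=pad_token_type_id,  # token_type_id used for padded positions
    )
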
@@ -76,6 +76,63 @@ class FastTokenizerMatchingTest(unittest.TestCase):
         for key in filter(lambda x: x in ["input_ids", "token_type_ids", "attention_mask"], input_p.keys()):
             self.assert_sequence_almost_equals(input_p[key], input_r[key], threshold)
 
+    def assert_padding(self, tokenizer_r, tokenizer_p):
+        # Simple input
+        input_r = tokenizer_r.encode("This is a simple input", max_length=15, pad_to_max_length=True)
+        input_p = tokenizer_p.encode("This is a simple input", max_length=15, pad_to_max_length=True)
+        self.assertSequenceEqual(input_r, input_p)
+
+        # Simple input
+        input_r = tokenizer_r.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
+        input_p = tokenizer_p.encode_plus("This is a simple input", max_length=15, pad_to_max_length=True)
+        self.assertSequenceEqual(input_r, input_p)
+
+        # Simple input
+        # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
+        # input_r = tokenizer_r.batch_encode_plus(
+        #     ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
+        # )
+        # input_p = tokenizer_p.batch_encode_plus(
+        #     ["This is a simple input 1", "This is a simple input 2"], max_length=15, pad_to_max_length=True
+        # )
+        # self.assertSequenceEqual(input_r, input_p)
+
+        # Pair input
+        input_r = tokenizer_r.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
+        input_p = tokenizer_p.encode("This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True)
+        self.assertSequenceEqual(input_r, input_p)
+
+        # Pair input
+        input_r = tokenizer_r.encode_plus(
+            "This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
+        )
+        input_p = tokenizer_p.encode_plus(
+            "This is a simple input", "This is a pair", max_length=15, pad_to_max_length=True
+        )
+        self.assertSequenceEqual(input_r, input_p)
+
+        # Pair input
+        # TODO: Re-enable this test when batch_encode_plus with padding correctly handles padding
+        # input_r = tokenizer_r.batch_encode_plus(
+        #     ["This is a simple input 1", "This is a simple input 2"],
+        #     ["This is a simple pair 1", "This is a simple pair 2"],
+        #     max_length=15,
+        #     pad_to_max_length=True,
+        # )
+        # input_p = tokenizer_p.batch_encode_plus(
+        #     ["This is a simple input 1", "This is a simple input 2"],
+        #     ["This is a simple pair 1", "This is a simple pair 2"],
+        #     max_length=15,
+        #     pad_to_max_length=True,
+        # )
+        # self.assertSequenceEqual(input_r, input_p)
+
     def assert_add_tokens(self, tokenizer_r):
         vocab_size = tokenizer_r.vocab_size
         self.assertEqual(tokenizer_r.add_tokens(""), 0)
@@ -239,6 +296,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            self.assert_padding(tokenizer_r, tokenizer_p)
+
     @require_torch
     def test_transfoxl(self):
         for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
@@ -278,6 +338,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
+
     def test_distilbert(self):
         for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
@@ -317,6 +380,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            self.assert_padding(tokenizer_r, tokenizer_p)
+
     def test_gpt2(self):
         for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
@@ -355,6 +421,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
+
     def test_roberta(self):
         for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
@@ -393,6 +462,10 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            # TODO: Re-enable this test as soon as Roberta aligns with the python tokenizer.
+            # self.assert_padding(tokenizer_r, tokenizer_p)
+
     def test_openai(self):
         for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
             tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
@@ -431,6 +504,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
             # Check the number of returned files for save_vocabulary
             self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
 
+            # Check for padding
+            self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
+
 
 if __name__ == "__main__":
     unittest.main()