Unverified commit e6767642 authored by Morgan Funtowicz, committed by GitHub

Override build_inputs_with_special_tokens for fast tokenizers (#2912)



* Override build_inputs_with_special_tokens for fast impl + unittest.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Quality + format.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 59c23ad9
@@ -572,3 +572,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
        )
        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        if token_ids_1:
            output += token_ids_1 + [self.sep_token_id]

        return output
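For readers skimming the diff, here is a minimal, illustrative sketch (not part of the commit) of what the BERT override is expected to produce, assuming the usual bert-base-uncased special token ids [CLS]=101 and [SEP]=102; the other ids are placeholders:

```python
# Illustrative sketch only: mirrors the BertTokenizerFast override above,
# assuming cls_token_id=101 and sep_token_id=102 (bert-base-uncased defaults).
CLS_ID, SEP_ID = 101, 102

def build_bert_inputs(token_ids_0, token_ids_1=None):
    # Single sequence: [CLS] A [SEP]
    output = [CLS_ID] + token_ids_0 + [SEP_ID]
    # Pair of sequences: [CLS] A [SEP] B [SEP]
    if token_ids_1:
        output += token_ids_1 + [SEP_ID]
    return output

print(build_bert_inputs([1037, 1038]))    # [101, 1037, 1038, 102]
print(build_bert_inputs([1037], [1039]))  # [101, 1037, 102, 1039, 102]
```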
@@ -210,3 +210,10 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
        # We need to recompute max_len according to the newly registered post_processor to get real values.
        self.max_len_single_sentence = self.max_len - self.num_added_tokens(False)  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - self.num_added_tokens(True)  # take into account special tokens

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output
        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
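Again purely as illustration (not part of the commit), the RoBERTa pattern places a doubled separator between the two segments. The sketch below assumes bos_token_id=0 (`<s>`) and eos_token_id=2 (`</s>`), the roberta-base defaults, with placeholder ids elsewhere:

```python
# Illustrative sketch only: mirrors the RobertaTokenizerFast override above,
# assuming bos_token_id=0 (<s>) and eos_token_id=2 (</s>), as in roberta-base.
BOS_ID, EOS_ID = 0, 2

def build_roberta_inputs(token_ids_0, token_ids_1=None):
    # Single sequence: <s> A </s>
    output = [BOS_ID] + token_ids_0 + [EOS_ID]
    if token_ids_1 is None:
        return output
    # Pair of sequences: <s> A </s></s> B </s>
    return output + [EOS_ID] + token_ids_1 + [EOS_ID]

print(build_roberta_inputs([100, 200]))  # [0, 100, 200, 2]
print(build_roberta_inputs([100], [200]))  # [0, 100, 2, 2, 200, 2]
```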
@@ -1669,6 +1669,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
        self._update_special_tokens()
        return added

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        if token_ids_1 is None:
            return token_ids_0
        else:
            return token_ids_0 + token_ids_1
    def num_added_tokens(self, pair=False):
        return self.tokenizer.num_special_tokens_to_add(pair)
...
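For context, the base-class fallback added above deliberately adds no special tokens at all; it only concatenates the id lists, which is why model-specific fast tokenizers such as the BERT and RoBERTa ones above override it. A minimal sketch of that behaviour, for illustration only:

```python
# Illustrative sketch only: the PreTrainedTokenizerFast fallback simply
# concatenates the inputs and leaves special-token handling to subclasses.
def base_build_inputs(token_ids_0, token_ids_1=None):
    if token_ids_1 is None:
        return token_ids_0
    return token_ids_0 + token_ids_1

assert base_build_inputs([1, 2]) == [1, 2]
assert base_build_inputs([1, 2], [3]) == [1, 2, 3]
```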
@@ -172,6 +172,35 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            self.assertEqual(len(tokens[key].shape), 2)
            self.assertEqual(tokens[key].shape[-1], 6)
    def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p):
        # Input string
        input_simple = tokenizer_p.tokenize("This is a sample input")
        input_pair = tokenizer_p.tokenize("This is a sample pair")

        # Generate output
        output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
        output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
        self.assertEqual(output_p, output_r)

        # Generate pair output
        output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
        output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
        self.assertEqual(output_p, output_r)

        # Input tokens id
        input_simple = tokenizer_p.encode("This is a sample input")
        input_pair = tokenizer_p.encode("This is a sample pair")

        # Generate output
        output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
        output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
        self.assertEqual(output_p, output_r)

        # Generate pair output
        output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
        output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
        self.assertEqual(output_p, output_r)
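As a hedged usage sketch (not taken from the test file), the helper compares the Rust-backed "fast" tokenizer against the Python reference implementation for one checkpoint; the checkpoint name below is only an example:

```python
# Illustrative usage only: compare the Python and fast BERT tokenizers on the
# same inputs. "bert-base-uncased" is used here purely as an example checkpoint.
from transformers import BertTokenizer, BertTokenizerFast

tokenizer_p = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer_r = BertTokenizerFast.from_pretrained("bert-base-uncased")

ids = tokenizer_p.encode("This is a sample input")
assert tokenizer_p.build_inputs_with_special_tokens(ids) == \
    tokenizer_r.build_inputs_with_special_tokens(ids)
```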
    def test_bert(self):
        for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
@@ -204,6 +233,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)

    @require_torch
    def test_transfoxl(self):
        for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
@@ -237,6 +269,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)

    def test_distilbert(self):
        for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
@@ -270,6 +305,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)

    def test_gpt2(self):
        for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
@@ -302,6 +340,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)

    def test_roberta(self):
        for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
@@ -334,6 +375,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assert_batch_encode_dynamic_overflowing(tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)

    def test_openai(self):
        for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
            tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
@@ -366,6 +410,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
            # Check for dynamic encoding sequence handling in batch_encode_plus
            self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)

            # Check alignment for build_inputs_with_special_tokens
            self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)


if __name__ == "__main__":
    unittest.main()