"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8ba4c5885f03e437edff225e53f8fd334ebc3819"
Unverified Commit af38c837 authored by Towdo, committed by GitHub

Fixed inconsistency in several fast tokenizers (#26561)

parent 8878eb1b
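The inconsistency being fixed: in these fast tokenizers, `build_inputs_with_special_tokens` guarded the second sequence with a truthiness test (`if token_ids_1:`). An empty list is falsy in Python, so a pair sequence that encodes to zero tokens was treated as absent and lost its trailing `[SEP]`, while the slow tokenizers (which compare against `None`) kept it. A minimal standalone sketch of the pitfall, with made-up token ids rather than the library's code:

```python
CLS, SEP = 101, 102  # hypothetical special-token ids, for illustration only

def build_buggy(token_ids_0, token_ids_1=None):
    output = [CLS] + token_ids_0 + [SEP]
    if token_ids_1:  # [] is falsy -> an empty pair is silently skipped
        output += token_ids_1 + [SEP]
    return output

def build_fixed(token_ids_0, token_ids_1=None):
    output = [CLS] + token_ids_0 + [SEP]
    if token_ids_1 is not None:  # only skip when no pair was passed at all
        output += token_ids_1 + [SEP]
    return output

print(build_buggy([7], []))  # [101, 7, 102]       -- second [SEP] lost
print(build_fixed([7], []))  # [101, 7, 102, 102]  -- matches the slow tokenizer
```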
@@ -265,7 +265,7 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

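The identical one-line fix recurs in each tokenizer below. The behavior it restores can be checked end to end; a sketch assuming the `bert-base-uncased` checkpoint is available (any BERT-style checkpoint would do):

```python
from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained("bert-base-uncased")
fast = BertTokenizerFast.from_pretrained("bert-base-uncased")

ids_0 = fast.encode("This is a sample input", add_special_tokens=False)
ids_1 = fast.encode("", add_special_tokens=False)  # empty pair -> []

# With this fix both paths emit [CLS] ... [SEP] [SEP]; before it, the fast
# tokenizer dropped the pair's [SEP] because [] is falsy.
assert fast.build_inputs_with_special_tokens(ids_0, ids_1) == \
       slow.build_inputs_with_special_tokens(ids_0, ids_1)
```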
@@ -159,7 +159,7 @@ class ConvBertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -164,7 +164,7 @@ class RetriBertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -190,7 +190,7 @@ class DistilBertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -192,7 +192,7 @@ class ElectraTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -212,7 +212,7 @@ class FunnelTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -166,7 +166,7 @@ class LayoutLMTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -152,7 +152,7 @@ class LxmertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -150,7 +150,7 @@ class MobileBertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -282,7 +282,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -163,7 +163,7 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -173,7 +173,7 @@ class SqueezeBertTokenizerFast(PreTrainedTokenizerFast):
         """
         output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-        if token_ids_1:
+        if token_ids_1 is not None:
             output += token_ids_1 + [self.sep_token_id]
         return output

@@ -3209,19 +3209,27 @@ class TokenizerTesterMixin:
                 # output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
                 # self.assertEqual(output_p, output_r)
 
-                # Input tokens id
-                input_simple = tokenizer_p.encode("This is a sample input", add_special_tokens=False)
-                input_pair = tokenizer_p.encode("This is a sample pair", add_special_tokens=False)
-
-                # Generate output
-                output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
-                output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
-                self.assertEqual(output_p, output_r)
-
-                # Generate pair output
-                output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
-                output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
-                self.assertEqual(output_p, output_r)
+                input_pairs = [
+                    ("", ""),
+                    ("", "This is a sample pair"),
+                    ("This is a sample input", ""),
+                    ("This is a sample input", "This is a sample pair"),
+                ]
+
+                for sample_input, sample_pair in input_pairs:
+                    # Input tokens id
+                    input_simple = tokenizer_p.encode(sample_input, add_special_tokens=False)
+                    input_pair = tokenizer_p.encode(sample_pair, add_special_tokens=False)
+
+                    # Generate output
+                    output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
+                    output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
+                    self.assertEqual(output_p, output_r)
+
+                    # Generate pair output
+                    output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
+                    output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
+                    self.assertEqual(output_p, output_r)
 
     def test_padding(self, max_length=50):
         if not self.test_slow_tokenizer:
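The rewritten mixin test sweeps all four empty/non-empty input combinations. Since `encode("", add_special_tokens=False)` yields `[]`, the empty-pair cases are precisely where `if token_ids_1:` and `if token_ids_1 is not None:` diverge, so the slow/fast comparison would have failed on the old code.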