Unverified Commit 870e6f29 authored by Patrick Deutschmann, committed by GitHub

Fix DeBERTa `token_type_ids` (#17082)

parent 279bc584
@@ -407,7 +407,7 @@ class DebertaConverter(Converter):
         tokenizer.decoder = decoders.ByteLevel()
         tokenizer.post_processor = processors.TemplateProcessing(
             single="[CLS]:0 $A:0 [SEP]:0",
-            pair="[CLS]:0 $A:0 [SEP]:0 $B:0 [SEP]:0",
+            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
             special_tokens=[
                 ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                 ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
...
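The hunk above fixes the template the converted fast tokenizer uses for sequence pairs: the `:1` suffix assigns segment id 1 to the second sequence and its trailing `[SEP]`. A minimal sketch of the mechanism, using a toy WordLevel vocabulary (an assumption for illustration, not DeBERTa's actual BPE vocab):

```python
from tokenizers import Tokenizer, models, processors

# Toy vocabulary; the ids here are made up for the example.
vocab = {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "hello": 3, "world": 4}
tok = Tokenizer(models.WordLevel(vocab, unk_token="[UNK]"))
tok.post_processor = processors.TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    # The fixed template: segment id 1 for $B and its closing [SEP];
    # the old template used ":0" everywhere, so pairs got all-zero type ids.
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[("[CLS]", vocab["[CLS]"]), ("[SEP]", vocab["[SEP]"])],
)

enc = tok.encode("hello", "world")
print(enc.tokens)    # ['[CLS]', 'hello', '[SEP]', 'world', '[SEP]']
print(enc.type_ids)  # [0, 0, 0, 1, 1]  (was [0, 0, 0, 0, 0] before the fix)
```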
@@ -210,7 +210,7 @@ class DebertaTokenizer(GPT2Tokenizer):
         if token_ids_1 is None:
             return len(cls + token_ids_0 + sep) * [0]
-        return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
...
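For the slow tokenizer, the fixed `create_token_type_ids_from_sequences` builds the mask piecewise instead of zero-filling the whole pair. A standalone worked example with hypothetical token ids:

```python
# Hypothetical ids for [CLS] / [SEP] and the two sequences; the real method
# pulls these from the tokenizer's vocabulary.
cls, sep = [101], [102]
token_ids_0 = [7, 8, 9]   # first sequence
token_ids_1 = [10, 11]    # second sequence

# Before the fix: every position was segment 0.
old_mask = len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
# After the fix: [CLS] A [SEP] -> 0, then B [SEP] -> 1 (BERT-style).
new_mask = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

assert old_mask == [0, 0, 0, 0, 0, 0, 0, 0]
assert new_mask == [0, 0, 0, 0, 0, 1, 1, 1]
```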
@@ -183,7 +183,7 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
         sequence pair mask has the following format:

         ```
-        0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
         | first sequence    | second sequence |
         ```
@@ -203,4 +203,4 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
         if token_ids_1 is None:
             return len(cls + token_ids_0 + sep) * [0]
-        return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
...
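With both tokenizer classes (and the fast-tokenizer converter) aligned on the BERT-style convention, a quick end-to-end check looks like this. A sketch, assuming transformers is installed and the microsoft/deberta-base checkpoint can be downloaded; the exact length of the output depends on the BPE splits:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
enc = tokenizer("Hello", "World")
print(enc["token_type_ids"])
# With this fix, the entries covering "World" and its trailing [SEP] are 1;
# before the fix the whole list was zeros.
```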
@@ -88,6 +88,12 @@ class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

+    def test_token_type_ids(self):
+        tokenizer = self.get_tokenizer()
+        tokd = tokenizer("Hello", "World")
+        expected_token_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
+        self.assertListEqual(tokd["token_type_ids"], expected_token_type_ids)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
...
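Reading the expected mask in the new test: 7 zeros cover `[CLS]`, the tokens of "Hello", and the first `[SEP]`; 6 ones cover the tokens of "World" and the final `[SEP]`. The per-word token counts below are inferred from those lengths, not taken from the test fixture's vocabulary:

```python
# Inferred decomposition (assumption: "Hello" and "World" each split into
# 5 tokens under the tester's toy vocabulary).
expected = [0] * (1 + 5 + 1) + [1] * (5 + 1)  # [CLS] A [SEP] | B [SEP]
assert expected == [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
```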