"benchmark/vscode:/vscode.git/clone" did not exist on "55de40f782d1949740aec74e88ae7cce00d59582"
Unverified Commit 870e6f29 authored by Patrick Deutschmann's avatar Patrick Deutschmann Committed by GitHub
Browse files

Fix DeBERTa `token_type_ids` (#17082)

parent 279bc584
......@@ -407,7 +407,7 @@ class DebertaConverter(Converter):
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.TemplateProcessing(
single="[CLS]:0 $A:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:0 [SEP]:0",
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
special_tokens=[
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
......
......@@ -210,7 +210,7 @@ class DebertaTokenizer(GPT2Tokenizer):
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
......
......@@ -183,7 +183,7 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
......@@ -203,4 +203,4 @@ class DebertaTokenizerFast(GPT2TokenizerFast):
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
......@@ -88,6 +88,12 @@ class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
def test_token_type_ids(self):
tokenizer = self.get_tokenizer()
tokd = tokenizer("Hello", "World")
expected_token_type_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
self.assertListEqual(tokd["token_type_ids"], expected_token_type_ids)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment