test_tokenization_deberta.py
# coding=utf-8
# Copyright 2018 Microsoft, the Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re
import unittest
from typing import Tuple

from transformers.models.deberta.tokenization_deberta import DebertaTokenizer
from transformers.testing_utils import require_torch

from .test_tokenization_common import TokenizerTesterMixin


@require_torch
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = DebertaTokenizer

    def setUp(self):
        super().setUp()

    def get_tokenizer(self, name="microsoft/deberta-base", **kwargs):
        # Loads the pretrained byte-level BPE vocabulary from the Hugging Face Hub.
        return DebertaTokenizer.from_pretrained(name, **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
        # Collect (id, decoded string) pairs for a slice of the vocabulary.
        toks = [
            (i, tokenizer.decode([i], clean_up_tokenization_spaces=False))
            for i in range(5, min(len(tokenizer), 50260))
        ]
        # Keep only tokens made of plain letters and spaces, and only those that
        # round-trip: re-encoding the decoded string must yield the same single id.
        toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
        toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
        if max_length is not None and len(toks) > max_length:
            toks = toks[:max_length]
        toks_ids = [t[0] for t in toks]

        # Ensure the decoded sequence contains at least one space so that it
        # splits into more than one token when re-encoded.
        output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
        if " " not in output_txt and len(toks_ids) > 1:
            output_txt = (
                tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
                + " "
                + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
            )
        if with_prefix_space and not output_txt.startswith(" "):
            output_txt = " " + output_txt
        output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
        return output_txt, output_ids
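
    # Sanity-check sketch for get_clean_sequence (an illustrative addition, not
    # part of the original file): the helper promises a text/ids pair where the
    # ids are exactly what the tokenizer produces for the text. Assumes the
    # pretrained vocabulary can be downloaded.
    def test_get_clean_sequence(self):
        tokenizer = self.get_tokenizer()
        output_txt, output_ids = self.get_clean_sequence(tokenizer)
        self.assertTrue(len(output_ids) > 0)
        self.assertEqual(tokenizer.encode(output_txt, add_special_tokens=False), output_ids)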

    def test_full_tokenizer(self):
        tokenizer = self.get_tokenizer("microsoft/deberta-base")
        # Byte-level BPE should round-trip arbitrary text, including accented
        # characters and punctuation, without loss.
        input_str = "UNwant\u00E9d,running"
        tokens = tokenizer.tokenize(input_str)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        self.assertEqual(tokenizer.decode(token_ids), input_str)
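
    # Minimal round-trip sketch (an addition, not part of the original suite):
    # encoding plain ASCII text without special tokens and decoding it should
    # reproduce the input exactly. Assumes the "microsoft/deberta-base"
    # vocabulary is available for download.
    def test_encode_decode_round_trip(self):
        tokenizer = self.get_tokenizer()
        input_text, output_text = self.get_input_output_texts(tokenizer)
        token_ids = tokenizer.encode(input_text, add_special_tokens=False)
        self.assertEqual(tokenizer.decode(token_ids), output_text)


if __name__ == "__main__":
    # Convenience entry point (added) so the file can be run directly.
    unittest.main()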