# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_albert import AlbertTokenizer
from .tokenization_t5 import T5Tokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer

# Module-level logger, following the library-wide `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)

class AutoTokenizer(object):
    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library
        when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct tokenizer class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
            - contains `roberta`: RobertaTokenizer (RoBERTa model)
            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)
            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        # AutoTokenizer is a pure factory class; direct instantiation is always an error.
        raise EnvironmentError("AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.

        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
            - contains `roberta`: RobertaTokenizer (RoBERTa model)
            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)
            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exists.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

        Examples::

            # Download vocabulary from S3 and cache.
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

            # Download vocabulary from S3 (user-uploaded) and cache.
            tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

        """
        # NOTE: the order of these checks matters — longer, more specific
        # patterns ('xlm-roberta', 'bert-base-japanese', 'distilbert', ...)
        # must be tested before the shorter substrings they contain
        # ('xlm', 'bert', 'roberta'), otherwise the wrong class is picked.
        if 't5' in pretrained_model_name_or_path:
            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'distilbert' in pretrained_model_name_or_path:
            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'albert' in pretrained_model_name_or_path:
            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'camembert' in pretrained_model_name_or_path:
            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'xlm-roberta' in pretrained_model_name_or_path:
            return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'roberta' in pretrained_model_name_or_path:
            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'bert-base-japanese' in pretrained_model_name_or_path:
            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'bert' in pretrained_model_name_or_path:
            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'openai-gpt' in pretrained_model_name_or_path:
            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'gpt2' in pretrained_model_name_or_path:
            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'transfo-xl' in pretrained_model_name_or_path:
            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'xlnet' in pretrained_model_name_or_path:
            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'xlm' in pretrained_model_name_or_path:
            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        elif 'ctrl' in pretrained_model_name_or_path:
            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        # No pattern matched: list every recognized pattern in the error
        # (previously 't5' and 'bert-base-japanese' were missing and a comma
        # was misplaced inside a quoted name).
        raise ValueError("Unrecognized model identifier in {}. Should contain one of "
                         "'t5', 'distilbert', 'albert', 'camembert', 'xlm-roberta', 'roberta', "
                         "'bert-base-japanese', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', "
                         "'xlnet', 'xlm', 'ctrl'".format(pretrained_model_name_or_path))