# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Tokenizer class. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging

from .tokenization_bert import BertTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_xlnet import XLNetTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_albert import AlbertTokenizer
from .tokenization_t5 import T5Tokenizer


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class AutoTokenizer(object):
    r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
        that will be instantiated as one of the tokenizer classes of the library
        when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
        class method.

        The `from_pretrained()` method takes care of returning the correct tokenizer class instance
        using pattern matching on the `pretrained_model_name_or_path` string.

        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
            - contains `roberta`: RobertaTokenizer (RoBERTa model)
            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)
            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
    def __init__(self):
        # Direct construction is always rejected; `from_pretrained` is the only entry point.
        raise EnvironmentError("AutoTokenizer is designed to be instantiated "
            "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.

        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
            - contains `roberta`: RobertaTokenizer (RoBERTa model)
            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)
            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)

        Params:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a predefined tokenizer that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

        Raises:
            ValueError: if `pretrained_model_name_or_path` contains none of the patterns listed above.

        Examples::

            # Download vocabulary from S3 and cache.
            tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

            # Download vocabulary from S3 (user-uploaded) and cache.
            tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-german-cased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/bert_saved_model/')`)
            tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

        """
        # The checks below are order-sensitive: more specific patterns must be
        # tested before the shorter patterns they contain (e.g. 'distilbert',
        # 'albert', 'camembert', 'roberta' and 'bert-base-japanese' all contain
        # 'bert', and 'xlnet' must not fall through to a later pattern).
        if 't5' in pretrained_model_name_or_path:
            return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'distilbert' in pretrained_model_name_or_path:
            return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'albert' in pretrained_model_name_or_path:
            return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'camembert' in pretrained_model_name_or_path:
            return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'roberta' in pretrained_model_name_or_path:
            return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'bert-base-japanese' in pretrained_model_name_or_path:
            return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'bert' in pretrained_model_name_or_path:
            return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'openai-gpt' in pretrained_model_name_or_path:
            return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'gpt2' in pretrained_model_name_or_path:
            return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'transfo-xl' in pretrained_model_name_or_path:
            return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'xlnet' in pretrained_model_name_or_path:
            return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'xlm' in pretrained_model_name_or_path:
            return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if 'ctrl' in pretrained_model_name_or_path:
            return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

        # No pattern matched: report every supported pattern, in matching order.
        raise ValueError("Unrecognized model identifier in {}. Should contain one of "
                         "'t5', 'distilbert', 'albert', 'camembert', 'roberta', "
                         "'bert-base-japanese', 'bert', 'openai-gpt', 'gpt2', "
                         "'transfo-xl', 'xlnet', 'xlm', 'ctrl'".format(pretrained_model_name_or_path))