Unverified Commit fb650df8 authored by Julien Chaumond, committed by GitHub

Support for private models from huggingface.co (#9141)



* minor wording tweaks

* Create private model repo + exist_ok flag

* file_utils: `use_auth_token`

* Update src/transformers/file_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Propagate doc from @sgugger
Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c69d19fa
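This change threads a new `use_auth_token` argument from the user-facing `from_pretrained` methods (configs, tokenizers, and the PyTorch/TensorFlow/Flax model classes) down to `cached_path`/`get_from_cache`, where it becomes an `authorization: Bearer <token>` header on requests to huggingface.co. A minimal usage sketch, assuming you have already run `transformers-cli login` and that `my-org/my-private-model` stands in for a real private repo:

    # Placeholder repo id; requires access to a private model on huggingface.co.
    from transformers import AutoModel, AutoTokenizer

    # With use_auth_token=True, the token saved by `transformers-cli login`
    # (stored in ~/.huggingface) is sent as a Bearer token on every request.
    model = AutoModel.from_pretrained("my-org/my-private-model", use_auth_token=True)
    tokenizer = AutoTokenizer.from_pretrained("my-org/my-private-model", use_auth_token=True)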
@@ -317,6 +317,9 @@ class PretrainedConfig(object):
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            use_auth_token (:obj:`str` or `bool`, `optional`):
+                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
             revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
@@ -332,6 +335,10 @@ class PretrainedConfig(object):
             values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
             by the ``return_unused_kwargs`` keyword parameter.
+        .. note::
+
+            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
         Returns:
             :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
@@ -373,6 +380,7 @@ class PretrainedConfig(object):
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
@@ -395,6 +403,7 @@ class PretrainedConfig(object):
             proxies=proxies,
             resume_download=resume_download,
             local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
         )
         # Load config dict
         config_dict = cls._dict_from_json_file(resolved_config_file)
......
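The `use_auth_token` value may also be an explicit API token string, which is forwarded as-is to the `authorization` header (see the `get_from_cache` change below). A small sketch, where both the repo id and the token are placeholders:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "my-org/my-private-model",          # placeholder private repo id
        revision="main",
        use_auth_token="<your-api-token>",  # explicit token string instead of True
    )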
@@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file is adapte
 https://github.com/allenai/allennlp.
 """
+import copy
 import fnmatch
 import io
 import json
@@ -42,6 +43,7 @@ import requests
 from filelock import FileLock
 from . import __version__
+from .hf_api import HfFolder
 from .utils import logging
@@ -1024,6 +1026,7 @@ def cached_path(
     user_agent: Union[Dict, str, None] = None,
     extract_compressed_file=False,
     force_extract=False,
+    use_auth_token: Union[bool, str, None] = None,
     local_files_only=False,
 ) -> Optional[str]:
     """
@@ -1036,6 +1039,8 @@ def cached_path(
         force_download: if True, re-download the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if incompletely received file is found.
         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
+            will get token from ~/.huggingface.
         extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
             file in a folder along the archive.
         force_extract: if True when extract_compressed_file is True and the archive was already extracted,
@@ -1063,6 +1068,7 @@ def cached_path(
             proxies=proxies,
             resume_download=resume_download,
             user_agent=user_agent,
+            use_auth_token=use_auth_token,
             local_files_only=local_files_only,
         )
     elif os.path.exists(url_or_filename):
@@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua
-def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
     """
     Donwload remote file. Do not gobble up errors.
     """
-    headers = {"user-agent": http_user_agent(user_agent)}
+    headers = copy.deepcopy(headers)
     if resume_size > 0:
         headers["Range"] = "bytes=%d-" % (resume_size,)
     r = requests.get(url, stream=True, proxies=proxies, headers=headers)
@@ -1159,6 +1165,7 @@ def get_from_cache(
     etag_timeout=10,
     resume_download=False,
     user_agent: Union[Dict, str, None] = None,
+    use_auth_token: Union[bool, str, None] = None,
     local_files_only=False,
 ) -> Optional[str]:
     """
@@ -1178,11 +1185,19 @@ def get_from_cache(
     os.makedirs(cache_dir, exist_ok=True)
+    headers = {"user-agent": http_user_agent(user_agent)}
+    if isinstance(use_auth_token, str):
+        headers["authorization"] = "Bearer {}".format(use_auth_token)
+    elif use_auth_token:
+        token = HfFolder.get_token()
+        if token is None:
+            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
+        headers["authorization"] = "Bearer {}".format(token)
     url_to_download = url
     etag = None
     if not local_files_only:
         try:
-            headers = {"user-agent": http_user_agent(user_agent)}
             r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
             r.raise_for_status()
             etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
@@ -1272,7 +1287,7 @@ def get_from_cache(
     with temp_file_manager() as temp_file:
         logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
-        http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+        http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
     logger.info("storing %s in cache at %s", url, cache_path)
     os.replace(temp_file.name, cache_path)
......
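The core of the change in `src/transformers/file_utils.py` (the file named in the commit message) is that `get_from_cache` now builds the request headers once, optionally adding an `authorization: Bearer ...` entry, and `http_get` receives those prebuilt headers instead of a `user_agent`. A standalone sketch of that header-building logic; the function name is mine, not part of the diff:

    from typing import Dict, Union

    from transformers.hf_api import HfFolder  # same helper the diff imports


    def build_request_headers(user_agent: str, use_auth_token: Union[bool, str, None] = None) -> Dict[str, str]:
        """Sketch of the header construction added to get_from_cache()."""
        headers = {"user-agent": user_agent}
        if isinstance(use_auth_token, str):
            # An explicit token string is used verbatim.
            headers["authorization"] = "Bearer {}".format(use_auth_token)
        elif use_auth_token:
            # use_auth_token=True: fall back to the token saved by `transformers-cli login`.
            token = HfFolder.get_token()
            if token is None:
                raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
            headers["authorization"] = "Bearer {}".format(token)
        return headers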
@@ -206,7 +206,7 @@ class HfApi:
     def model_list(self) -> List[ModelInfo]:
         """
-        Get the public list of all the models on huggingface, including the community models
+        Get the public list of all the models on huggingface.co
         """
         path = "{}/api/models".format(self.endpoint)
         r = requests.get(path)
@@ -228,7 +228,13 @@ class HfApi:
         return [RepoObj(**x) for x in d]
     def create_repo(
-        self, token: str, name: str, organization: Optional[str] = None, lfsmultipartthresh: Optional[int] = None
+        self,
+        token: str,
+        name: str,
+        organization: Optional[str] = None,
+        private: Optional[bool] = None,
+        exist_ok=False,
+        lfsmultipartthresh: Optional[int] = None,
     ) -> str:
         """
         HuggingFace git-based system, used for models.
@@ -236,10 +242,14 @@ class HfApi:
         Call HF API to create a whole repo.
         Params:
+            private: Whether the model repo should be private (requires a paid huggingface.co account)
+            exist_ok: Do not raise an error if repo already exists
             lfsmultipartthresh: Optional: internal param for testing purposes.
         """
         path = "{}/api/repos/create".format(self.endpoint)
-        json = {"name": name, "organization": organization}
+        json = {"name": name, "organization": organization, "private": private}
         if lfsmultipartthresh is not None:
             json["lfsmultipartthresh"] = lfsmultipartthresh
         r = requests.post(
@@ -247,6 +257,8 @@ class HfApi:
             headers={"authorization": "Bearer {}".format(token)},
             json=json,
         )
+        if exist_ok and r.status_code == 409:
+            return ""
         r.raise_for_status()
         d = r.json()
         return d["url"]
......
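`create_repo` gains `private` and `exist_ok` parameters; with `exist_ok=True` a 409 Conflict from the Hub is swallowed and an empty string is returned instead of the new repo URL. A usage sketch (repo and organization names are placeholders; at the time, private repos required a paid huggingface.co plan):

    from transformers.hf_api import HfApi, HfFolder

    api = HfApi()
    token = HfFolder.get_token()  # token saved by `transformers-cli login`

    url = api.create_repo(
        token,
        name="my-private-model",   # placeholder repo name
        organization="my-org",     # placeholder; omit to create under your own namespace
        private=True,              # create the repo as private
        exist_ok=True,             # return "" instead of raising if the repo already exists
    )
    print(url or "repo already existed")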
@@ -226,6 +226,7 @@ class FlaxPreTrainedModel(ABC):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         # Load config if we don't provide a configuration
@@ -240,6 +241,7 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
                 revision=revision,
                 **kwargs,
             )
@@ -283,6 +285,7 @@ class FlaxPreTrainedModel(ABC):
                 proxies=proxies,
                 resume_download=resume_download,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
             )
         except EnvironmentError as err:
             logger.error(err)
......
@@ -894,6 +894,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try doanloading the model).
+        use_auth_token (:obj:`str` or `bool`, `optional`):
+            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
         revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
             git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
@@ -916,6 +919,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
             attribute will be passed to the underlying model's ``__init__`` function.
+        .. note::
+
+            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
         Examples::
             >>> from transformers import BertConfig, TFBertModel
@@ -939,6 +946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)
@@ -954,6 +962,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
                 revision=revision,
                 **kwargs,
             )
@@ -996,6 +1005,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 proxies=proxies,
                 resume_download=resume_download,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
             )
         except EnvironmentError as err:
             logger.error(err)
......
@@ -886,6 +886,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        use_auth_token (:obj:`str` or `bool`, `optional`):
+            The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+            generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
         revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
             git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
@@ -908,6 +911,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
             attribute will be passed to the underlying model's ``__init__`` function.
+        .. note::
+
+            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
         Examples::
             >>> from transformers import BertConfig, BertModel
@@ -931,6 +938,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)
@@ -946,6 +954,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
                 revision=revision,
                 **kwargs,
             )
@@ -998,6 +1007,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 proxies=proxies,
                 resume_download=resume_download,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
             )
         except EnvironmentError as err:
             logger.error(err)
......
@@ -744,8 +744,8 @@ class TextGenerationPipeline(Pipeline):
     task identifier: :obj:`"text-generation"`.
     The models that this pipeline can use are models that have been trained with an autoregressive language modeling
-    objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available
-    community models on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
+    objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models
+    on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
     """
     # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
......
@@ -1648,6 +1648,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             proxies (:obj:`Dict[str, str], `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            use_auth_token (:obj:`str` or `bool`, `optional`):
+                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
             revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
@@ -1662,6 +1665,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
             ``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.
+        .. note::
+
+            Passing :obj:`use_auth_token=True` is required when you want to use a private model.
         Examples::
             # We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
@@ -1689,6 +1696,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
@@ -1770,6 +1778,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                 proxies=proxies,
                 resume_download=resume_download,
                 local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
             )
         except requests.exceptions.HTTPError as err:
             if "404 Client Error" in str(err):
......
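The tokenizer loading path threads the same kwarg down to `cached_path`. A short sketch of loading a private tokenizer and handling the failure case (placeholder repo id; the exact exception can vary, but a missing or rejected token generally surfaces as an `EnvironmentError`/`OSError`):

    from transformers import AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained("my-org/my-private-model", use_auth_token=True)
    except EnvironmentError as err:
        # Raised, for example, when no token was stored by `transformers-cli login`
        # or when the authenticated user cannot access the repository.
        print("Could not load the private tokenizer:", err)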