Unverified Commit fb650df8 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Support for private models from huggingface.co (#9141)



* minor wording tweaks

* Create private model repo + exist_ok flag

* file_utils: `use_auth_token`

* Update src/transformers/file_utils.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Propagate doc from @sgugger
Co-Authored-By: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c69d19fa
......@@ -317,6 +317,9 @@ class PretrainedConfig(object):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
......@@ -332,6 +335,10 @@ class PretrainedConfig(object):
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the ``return_unused_kwargs`` keyword parameter.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Returns:
:class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
......@@ -373,6 +380,7 @@ class PretrainedConfig(object):
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
......@@ -395,6 +403,7 @@ class PretrainedConfig(object):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
......
......@@ -16,6 +16,7 @@ Utilities for working with the local dataset cache. Parts of this file are adapted
https://github.com/allenai/allennlp.
"""
import copy
import fnmatch
import io
import json
......@@ -42,6 +43,7 @@ import requests
from filelock import FileLock
from . import __version__
from .hf_api import HfFolder
from .utils import logging
......@@ -1024,6 +1026,7 @@ def cached_path(
user_agent: Union[Dict, str, None] = None,
extract_compressed_file=False,
force_extract=False,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
......@@ -1036,6 +1039,8 @@ def cached_path(
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
will get token from ~/.huggingface.
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
file in a folder along the archive.
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
......@@ -1063,6 +1068,7 @@ def cached_path(
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
elif os.path.exists(url_or_filename):
......@@ -1125,11 +1131,11 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
"""
Download remote file. Do not gobble up errors.
"""
headers = {"user-agent": http_user_agent(user_agent)}
headers = copy.deepcopy(headers)
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
......@@ -1159,6 +1165,7 @@ def get_from_cache(
etag_timeout=10,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
......@@ -1178,11 +1185,19 @@ def get_from_cache(
os.makedirs(cache_dir, exist_ok=True)
headers = {"user-agent": http_user_agent(user_agent)}
if isinstance(use_auth_token, str):
headers["authorization"] = "Bearer {}".format(use_auth_token)
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = "Bearer {}".format(token)
url_to_download = url
etag = None
if not local_files_only:
try:
headers = {"user-agent": http_user_agent(user_agent)}
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
r.raise_for_status()
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
......@@ -1272,7 +1287,7 @@ def get_from_cache(
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)
......
......@@ -206,7 +206,7 @@ class HfApi:
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface, including the community models
Get the public list of all the models on huggingface.co
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
......@@ -228,7 +228,13 @@ class HfApi:
return [RepoObj(**x) for x in d]
def create_repo(
self, token: str, name: str, organization: Optional[str] = None, lfsmultipartthresh: Optional[int] = None
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
......@@ -236,10 +242,14 @@ class HfApi:
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = "{}/api/repos/create".format(self.endpoint)
json = {"name": name, "organization": organization}
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
......@@ -247,6 +257,8 @@ class HfApi:
headers={"authorization": "Bearer {}".format(token)},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]
......
......@@ -226,6 +226,7 @@ class FlaxPreTrainedModel(ABC):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
# Load config if we don't provide a configuration
......@@ -240,6 +241,7 @@ class FlaxPreTrainedModel(ABC):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
**kwargs,
)
......@@ -283,6 +285,7 @@ class FlaxPreTrainedModel(ABC):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except EnvironmentError as err:
logger.error(err)
......
......@@ -894,6 +894,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
......@@ -916,6 +919,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Examples::
>>> from transformers import BertConfig, TFBertModel
......@@ -939,6 +946,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
......@@ -954,6 +962,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
**kwargs,
)
......@@ -996,6 +1005,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except EnvironmentError as err:
logger.error(err)
......
......@@ -886,6 +886,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
......@@ -908,6 +911,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Examples::
>>> from transformers import BertConfig, BertModel
......@@ -931,6 +938,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
......@@ -946,6 +954,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
**kwargs,
)
......@@ -998,6 +1007,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except EnvironmentError as err:
logger.error(err)
......
......@@ -744,8 +744,8 @@ class TextGenerationPipeline(Pipeline):
task identifier: :obj:`"text-generation"`.
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available
community models on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models
on `huggingface.co/models <https://huggingface.co/models?filter=causal-lm>`__.
"""
# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
......
......@@ -1648,6 +1648,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
......@@ -1662,6 +1665,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
``mask_token``, ``additional_special_tokens``. See parameters in the ``__init__`` for more details.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Examples::
# We can't instantiate directly the base class `PreTrainedTokenizerBase` so let's show our examples on a derived class: BertTokenizer
......@@ -1689,6 +1696,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
......@@ -1770,6 +1778,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment