Unverified commit bf2e0cf7, authored by Sylvain Gugger and committed by GitHub

Trainer push to hub (#11328)



* Initial support for upload to hub

* push -> upload

* Fixes + examples

* Fix torchhub test

* Torchhub test I hate you

* push_model_to_hub -> push_to_hub

* Apply mixin to other pretrained models

* Remove ABC inheritance

* Add tests

* Typo

* Run tests

* Install git-lfs

* Change approach

* Add push_to_hub to all

* Staging test suite

* Typo

* Maybe like this?

* More deps

* Cache

* Adapt name

* Quality

* MOAR tests

* Put it in testing_utils

* Docs + torchhub last hope

* Styling

* Wrong method

* Typos

* Update src/transformers/file_utils.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address review comments

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7bc86bea
@@ -317,24 +317,33 @@ jobs:
     - store_artifacts:
         path: ~/transformers/reports

-  run_tests_git_lfs:
+  run_tests_hub:
     working_directory: ~/transformers
     docker:
       - image: circleci/python:3.7
     environment:
+      HUGGINGFACE_CO_STAGING: yes
       RUN_GIT_LFS_TESTS: yes
       TRANSFORMERS_IS_CI: yes
     resource_class: xlarge
     parallelism: 1
     steps:
       - checkout
+      - restore_cache:
+          keys:
+            - v0.4-hub-{{ checksum "setup.py" }}
+            - v0.4-{{ checksum "setup.py" }}
       - run: sudo apt-get install git-lfs
       - run: |
           git config --global user.email "ci@dummy.com"
           git config --global user.name "ci"
       - run: pip install --upgrade pip
-      - run: pip install .[testing]
-      - run: python -m pytest -sv ./tests/test_hf_api.py -k "HfLargefilesTest"
+      - run: pip install .[torch,sentencepiece,testing]
+      - save_cache:
+          key: v0.4-hub-{{ checksum "setup.py" }}
+          paths:
+            - '~/.cache/pip'
+      - run: python -m pytest -sv ./tests/ -m is_staging_test

   build_doc:
     working_directory: ~/transformers

@@ -469,7 +478,7 @@ workflows:
       - run_tests_flax
       - run_tests_pipelines_torch
       - run_tests_pipelines_tf
-      - run_tests_git_lfs
+      - run_tests_hub
       - build_doc
       - deploy_doc: *workflow_filters
 #    tpu_testing_jobs:
...
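Editorial note: for contributors who want to run the new staging suite locally, the job above suggests roughly the following (a sketch only: the environment variables and the pytest command are taken verbatim from the config; being logged in against the staging hub is assumed):

    # Reproduce the run_tests_hub CI job locally (assumes git-lfs is installable
    # and valid staging-hub credentials are cached).
    sudo apt-get install git-lfs
    git config --global user.email "ci@dummy.com"
    git config --global user.name "ci"
    pip install .[torch,sentencepiece,testing]
    HUGGINGFACE_CO_STAGING=yes RUN_GIT_LFS_TESTS=yes TRANSFORMERS_IS_CI=yes \
        python -m pytest -sv ./tests/ -m is_staging_test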
@@ -73,3 +73,10 @@ Generation
 .. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
     :members:
+
+
+Pushing to the Hub
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.file_utils.PushToHubMixin
+    :members:
...
@@ -22,8 +22,6 @@ the `model hub <https://huggingface.co/models>`__.

 Optionally, you can join an existing organization or create a new one.

-Prepare your model for uploading
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 We have seen in the :doc:`training tutorial <training>`: how to fine-tune a model on a given task. You have probably
 done something similar on your task, either using the model directly in your own training loop or using the

@@ -31,7 +29,7 @@ done something similar on your task, either using the model directly in your own
 `model hub <https://huggingface.co/models>`__.

 Model versioning
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 Since version v3.5.0, the model hub has built-in model versioning based on git and git-lfs. It is based on the paradigm
 that one model *is* one repo.

@@ -54,6 +52,106 @@ For instance:

     >>> revision="v2.0.1"  # tag name, or branch name, or commit hash
     >>> )
+
+Push your model from Python
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Preparation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The first step is to make sure your credentials to the hub are stored somewhere. This can be done in two ways. If you
+have access to a terminal, you can just run the following command in the virtual environment where you installed 🤗
+Transformers:
+
+.. code-block:: bash
+
+    transformers-cli login
+
+It will store your access token in the Hugging Face cache folder (by default :obj:`~/.cache/`).
+
+If you don't have easy access to a terminal (for instance in a Colab session), you can find a token linked to your
+account by going to `huggingface.co <https://huggingface.co/>`__, clicking on your avatar in the top left corner, then
+on `Edit profile` on the left, just beneath your profile picture. In the submenu `API Tokens`, you will find your API
+token, which you can just copy.
+
+Directly push your model to the hub
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Once you have an API token (either stored in the cache or copied and pasted in your notebook), you can directly push a
+finetuned model you saved in :obj:`save_directory` by calling:
+
+.. code-block:: python
+
+    finetuned_model.push_to_hub("my-awesome-model")
+
+If your API token is not stored in the cache, you will need to pass it with :obj:`use_auth_token=your_token`. This will
+also be the case for all the examples below, so we won't mention it again.
+
+This will create a repository in your namespace named :obj:`my-awesome-model`, so anyone can now run:
+
+.. code-block:: python
+
+    from transformers import AutoModel
+
+    model = AutoModel.from_pretrained("your_username/my-awesome-model")
+
+Even better, you can combine this push to the hub with the call to :obj:`save_pretrained`:
+
+.. code-block:: python
+
+    finetuned_model.save_pretrained(save_directory, push_to_hub=True, repo_name="my-awesome-model")
+
+If you are a premium user and want your model to be private, just add :obj:`private=True` to this call. If you are a
+member of an organization and want to push it inside the namespace of the organization instead of yours, just add
+:obj:`organization=my_amazing_org`.
+
+Add new files to your model repo
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Once you have pushed your model to the hub, you might want to add the tokenizer, or a version of your model for another
+framework (TensorFlow, PyTorch, Flax). This is super easy to do! Let's begin with the tokenizer. You can add it to the
+repo you created before like this:
+
+.. code-block:: python
+
+    tokenizer.push_to_hub("my-awesome-model")
+
+If you know its URL (it should be :obj:`https://huggingface.co/username/repo_name`), you can also do:
+
+.. code-block:: python
+
+    tokenizer.push_to_hub(repo_url=my_repo_url)
+
+And that's all there is to it! It's also a very easy way to fix a mistake if one of the files online has a bug.
+
+Adding a model for another backend is just as easy. Let's say you have fine-tuned a TensorFlow model and want to add
+the PyTorch model files to your model repo, so that anyone in the community can use it. The following allows you to
+directly create a PyTorch version of your TensorFlow model:
+
+.. code-block:: python
+
+    from transformers import AutoModel
+
+    model = AutoModel.from_pretrained(save_directory, from_tf=True)
+
+You can also replace :obj:`save_directory` with the identifier of your model (:obj:`username/repo_name`) if you don't
+have a local save of it anymore. Then, just do the same as before:
+
+.. code-block:: python
+
+    model.push_to_hub("my-awesome-model")
+
+or
+
+.. code-block:: python
+
+    model.push_to_hub(repo_url=my_repo_url)
+
+Use your terminal and git
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 Basic steps
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
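Editorial note: taken together, the options documented in the new section compose as follows (a sketch only; the repository and organization names are placeholders, and `use_auth_token` is only needed when no token is cached):

    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("./my-finetuned-dir")
    tokenizer = AutoTokenizer.from_pretrained("./my-finetuned-dir")

    # Push to a private repo under an organization, authenticating with an explicit token.
    url = model.push_to_hub(
        "my-awesome-model",
        organization="my_amazing_org",
        private=True,
        use_auth_token="<your-api-token>",  # placeholder token
    )

    # Add the tokenizer files to the same repo.
    tokenizer.push_to_hub("my-awesome-model", organization="my_amazing_org", use_auth_token="<your-api-token>")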
@@ -447,6 +447,9 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...
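Editorial note: since `push_to_hub` is read off `training_args`, each of these example scripts presumably gains a `--push_to_hub` command-line flag through `HfArgumentParser`; a hedged invocation sketch (script, model, and task names are illustrative):

    python run_glue.py \
        --model_name_or_path bert-base-cased \
        --task_name mrpc \
        --do_train --do_eval \
        --output_dir ./mrpc-finetuned \
        --push_to_hub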
@@ -476,6 +476,9 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -452,6 +452,9 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -428,6 +428,9 @@ def main():
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -599,6 +599,9 @@ def main():
         trainer.log_metrics("test", metrics)
         trainer.save_metrics("test", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -638,6 +638,9 @@ def main():
         trainer.log_metrics("test", metrics)
         trainer.save_metrics("test", metrics)

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -579,6 +579,9 @@ def main():
             with open(output_test_preds_file, "w") as writer:
                 writer.write("\n".join(test_preds))

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+
     return results
...

@@ -517,6 +517,9 @@ def main():
                     item = label_list[item]
                     writer.write(f"{index}\t{item}\n")

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -491,6 +491,9 @@ def main():
                 for prediction in true_predictions:
                     writer.write(" ".join(prediction) + "\n")

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+

 def _mp_fn(index):
     # For xla_spawn (TPUs)
...

@@ -571,6 +571,9 @@ def main():
             with open(output_test_preds_file, "w") as writer:
                 writer.write("\n".join(test_preds))

+    if training_args.push_to_hub:
+        trainer.push_to_hub()
+
     return results
...
@@ -31,7 +31,7 @@ from transformers import (
 )

-dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata", "huggingface_hub"]


 @add_start_docstrings(AutoConfig.__doc__)
...
@@ -22,14 +22,14 @@ import os

 from typing import Any, Dict, Tuple, Union

 from . import __version__
-from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
+from .file_utils import CONFIG_NAME, PushToHubMixin, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
 from .utils import logging


 logger = logging.get_logger(__name__)


-class PretrainedConfig(object):
+class PretrainedConfig(PushToHubMixin):
     r"""
     Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
     methods for loading/downloading/saving configurations.

@@ -310,7 +310,7 @@ class PretrainedConfig(object):
             self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
             self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))

-    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
         Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
         :func:`~transformers.PretrainedConfig.from_pretrained` class method.

@@ -318,6 +318,11 @@ class PretrainedConfig(object):
         Args:
             save_directory (:obj:`str` or :obj:`os.PathLike`):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+            kwargs:
+                Additional keyword arguments passed along to the
+                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

@@ -328,6 +333,10 @@ class PretrainedConfig(object):
         self.to_json_file(output_config_file, use_diff=True)
         logger.info(f"Configuration saved in {output_config_file}")

+        if push_to_hub:
+            url = self._push_to_hub(save_files=[output_config_file], **kwargs)
+            logger.info(f"Configuration pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
         r"""
...
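Editorial note: a minimal sketch of the new code path for configurations, assuming a token is already cached via `transformers-cli login` (`repo_name` is forwarded through `**kwargs` to the mixin's `_push_to_hub`):

    from transformers import BertConfig

    config = BertConfig.from_pretrained("bert-base-cased")
    # Writes config.json locally, then pushes just that file to the hub.
    config.save_pretrained("./my-config-dir", push_to_hub=True, repo_name="my-bert-config")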
@@ -31,6 +31,7 @@ import types

 from collections import OrderedDict, UserDict
 from contextlib import contextmanager
 from dataclasses import fields
+from distutils.dir_util import copy_tree
 from enum import Enum
 from functools import partial, wraps
 from hashlib import sha256

@@ -47,10 +48,10 @@ from tqdm.auto import tqdm

 import requests
 from filelock import FileLock
+from huggingface_hub import HfApi, HfFolder, Repository
 from transformers.utils.versions import importlib_metadata

 from . import __version__
-from .hf_api import HfFolder
 from .utils import logging

@@ -229,7 +230,12 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
-HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
+
+_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
+_default_endpoint = "https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"
+
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", _default_endpoint)
+HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"

 PRESET_MIRROR_DICT = {
     "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",

@@ -1684,3 +1690,125 @@ def copy_func(f):
     g = functools.update_wrapper(g, f)
     g.__kwdefaults__ = f.__kwdefaults__
     return g
+
+
+class PushToHubMixin:
+    """
+    A Mixin containing the functionality to push a model or tokenizer to the hub.
+    """
+
+    def push_to_hub(
+        self,
+        repo_name: Optional[str] = None,
+        repo_url: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        organization: Optional[str] = None,
+        private: Optional[bool] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+    ) -> str:
+        """
+        Upload model checkpoint or tokenizer files to the 🤗 model hub.
+
+        Parameters:
+            repo_name (:obj:`str`, `optional`):
+                Repository name for your model or tokenizer in the hub. If not specified, the repository name will be
+                the stem of :obj:`save_directory`.
+            repo_url (:obj:`str`, `optional`):
+                Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
+                repository will be created in your namespace (unless you specify an :obj:`organization`) with
+                :obj:`repo_name`.
+            commit_message (:obj:`str`, `optional`):
+                Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or
+                :obj:`"add model"` depending on the type of the class.
+            organization (:obj:`str`, `optional`):
+                Organization in which you want to push your model or tokenizer (you must be a member of this
+                organization).
+            private (:obj:`bool`, `optional`):
+                Whether or not the repository created should be private (requires a paying subscription).
+            use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
+                The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
+                generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default
+                to :obj:`True` if :obj:`repo_url` is not specified.
+
+        Returns:
+            The url of the commit of your model in the given repository.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            self.save_pretrained(tmp_dir)
+            # Return the commit url, as promised by the docstring and the return annotation.
+            return self._push_to_hub(
+                save_directory=tmp_dir,
+                repo_name=repo_name,
+                repo_url=repo_url,
+                commit_message=commit_message,
+                organization=organization,
+                private=private,
+                use_auth_token=use_auth_token,
+            )
+
+    @classmethod
+    def _push_to_hub(
+        cls,
+        save_directory: Optional[str] = None,
+        save_files: Optional[List[str]] = None,
+        repo_name: Optional[str] = None,
+        repo_url: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        organization: Optional[str] = None,
+        private: Optional[bool] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+    ) -> str:
+        # Private version of push_to_hub, that either accepts a folder to push or a list of files.
+        if save_directory is None and save_files is None:
+            raise ValueError("_push_to_hub requires either a `save_directory` or a list of `save_files`.")
+        if repo_name is None and repo_url is None and save_directory is None:
+            raise ValueError("Need either a `repo_name` or `repo_url` to know where to push!")
+
+        if repo_name is None and repo_url is None and save_files is None:
+            repo_name = Path(save_directory).name
+        if use_auth_token is None and repo_url is None:
+            use_auth_token = True
+
+        if isinstance(use_auth_token, str):
+            token = use_auth_token
+        elif use_auth_token:
+            token = HfFolder.get_token()
+            if token is None:
+                raise ValueError(
+                    "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
+                    "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
+                    "token as the `use_auth_token` argument."
+                )
+        else:
+            token = None
+
+        if repo_url is None:
+            # Special provision for the test endpoint (CI)
+            repo_url = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
+                token,
+                repo_name,
+                organization=organization,
+                private=private,
+                repo_type=None,
+                exist_ok=True,
+            )
+
+        if commit_message is None:
+            if "Tokenizer" in cls.__name__:
+                commit_message = "add tokenizer"
+            elif "Config" in cls.__name__:
+                commit_message = "add config"
+            else:
+                commit_message = "add model"
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # First create the repo (and clone its content if it's nonempty), then add the files (otherwise there is
+            # no diff so nothing is pushed).
+            repo = Repository(tmp_dir, clone_from=repo_url, use_auth_token=use_auth_token)
+            if save_directory is None:
+                for filename in save_files:
+                    shutil.copy(filename, Path(tmp_dir) / Path(filename).name)
+            else:
+                copy_tree(save_directory, tmp_dir)
+
+            return repo.push_to_hub(commit_message=commit_message)
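
Editorial note: because the mixin only relies on `save_pretrained`, any class that provides one inherits hub uploads. A hypothetical toy subclass to illustrate the contract (a sketch only, not part of the commit; running it would actually create a repo and requires a cached token):

    import json
    import os

    from transformers.file_utils import PushToHubMixin


    class MyArtifact(PushToHubMixin):
        # Hypothetical example: anything with a save_pretrained() can be pushed.
        def __init__(self, payload):
            self.payload = payload

        def save_pretrained(self, save_directory):
            os.makedirs(save_directory, exist_ok=True)
            with open(os.path.join(save_directory, "artifact.json"), "w") as f:
                json.dump(self.payload, f)


    # Creates (or reuses) the repo "my-artifact" under your namespace and pushes artifact.json.
    url = MyArtifact({"hello": "world"}).push_to_hub("my-artifact")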
@@ -14,7 +14,6 @@
 # limitations under the License.

 import os
-from abc import ABC
 from functools import partial
 from pickle import UnpicklingError
 from typing import Dict, Set, Tuple, Union

@@ -29,8 +28,10 @@ from jax.random import PRNGKey

 from .configuration_utils import PretrainedConfig
 from .file_utils import (
+    CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
     WEIGHTS_NAME,
+    PushToHubMixin,
     add_start_docstrings_to_model_forward,
     cached_path,
     copy_func,

@@ -54,7 +55,7 @@ ACT2FN = {
 }


-class FlaxPreTrainedModel(ABC):
+class FlaxPreTrainedModel(PushToHubMixin):
     r"""
     Base class for all models.

@@ -385,7 +386,7 @@ class FlaxPreTrainedModel(ABC):
         return model

-    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub=False, **kwargs):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
         :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

@@ -393,6 +394,11 @@ class FlaxPreTrainedModel(ABC):
         Arguments:
             save_directory (:obj:`str` or :obj:`os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
+            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+            kwargs:
+                Additional keyword arguments passed along to the
+                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")

@@ -406,10 +412,18 @@ class FlaxPreTrainedModel(ABC):
         self.config.save_pretrained(save_directory)

         # save model
-        with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
+        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
+        with open(output_model_file, "wb") as f:
             model_bytes = to_bytes(self.params)
             f.write(model_bytes)
+        logger.info(f"Model weights saved in {output_model_file}")
+
+        if push_to_hub:
+            saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
+            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            logger.info(f"Model pushed to the hub in this commit: {url}")


 def overwrite_call_docstring(model_class, docstring):
     # copy __call__ function to be sure docstring is changed only for this function
...
@@ -30,10 +30,12 @@ from tensorflow.python.keras.saving import hdf5_format

 from .configuration_utils import PretrainedConfig
 from .file_utils import (
+    CONFIG_NAME,
     DUMMY_INPUTS,
     TF2_WEIGHTS_NAME,
     WEIGHTS_NAME,
     ModelOutput,
+    PushToHubMixin,
     cached_path,
     hf_bucket_url,
     is_offline_mode,

@@ -591,7 +593,7 @@ def init_copy_embeddings(old_embeddings, new_num_tokens):
     return mask, current_weights


-class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
+class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
     r"""
     Base class for all TF models.

@@ -1011,7 +1013,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         """
         raise NotImplementedError

-    def save_pretrained(self, save_directory, saved_model=False, version=1):
+    def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
         :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

@@ -1025,6 +1027,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
                 TensorFlow Serving as detailed in the official documentation
                 https://www.tensorflow.org/tfx/serving/serving_basic
+            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+            kwargs:
+                Additional keyword arguments passed along to the
+                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")

@@ -1045,6 +1052,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         self.save_weights(output_model_file)
         logger.info(f"Model weights saved in {output_model_file}")

+        if push_to_hub:
+            saved_files = [os.path.join(save_directory, CONFIG_NAME), output_model_file]
+            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            logger.info(f"Model pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""
...
@@ -29,12 +29,14 @@ from torch.nn import functional as F

 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .file_utils import (
+    CONFIG_NAME,
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_NAME,
     ModelOutput,
+    PushToHubMixin,
     cached_path,
     hf_bucket_url,
     is_offline_mode,

@@ -385,7 +387,7 @@ class ModuleUtilsMixin:
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
     Base class for all models.

@@ -799,6 +801,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         save_config: bool = True,
         state_dict: Optional[dict] = None,
         save_function: Callable = torch.save,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the

@@ -818,6 +822,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             save_function (:obj:`Callable`):
                 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                 need to replace :obj:`torch.save` by another method.
+            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+            kwargs:
+                Additional keyword arguments passed along to the
+                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")

@@ -848,6 +857,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         logger.info(f"Model weights saved in {output_model_file}")

+        if push_to_hub:
+            saved_files = [output_model_file]
+            if save_config:
+                saved_files.append(os.path.join(save_directory, CONFIG_NAME))
+            url = self._push_to_hub(save_files=saved_files, **kwargs)
+            logger.info(f"Model pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         r"""
...
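Editorial note: the PyTorch path mirrors the configuration one; a sketch assuming a cached token, showing an extra kwarg (`commit_message`) being forwarded through `**kwargs` to the mixin's `_push_to_hub`:

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
    # Saves pytorch_model.bin plus config.json, then pushes both files in a single commit.
    model.save_pretrained(
        "./my-model-dir",
        push_to_hub=True,
        repo_name="my-awesome-model",
        commit_message="initial upload",
    )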