Unverified Commit 0d0c392c authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

CLI: use hub's `create_commit` (#17755)

* use create_commit

* better commit message and description

* touch setup.py to trigger cache update

* add hub version gating
parent c366ce10
...@@ -27,7 +27,7 @@ jobs: ...@@ -27,7 +27,7 @@ jobs:
id: cache id: cache
with: with:
path: ~/venv/ path: ~/venv/
key: v3-tests_model_like-${{ hashFiles('setup.py') }} key: v4-tests_model_like-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss - name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
......
...@@ -21,7 +21,7 @@ jobs: ...@@ -21,7 +21,7 @@ jobs:
id: cache id: cache
with: with:
path: ~/venv/ path: ~/venv/
key: v3-tests_templates-${{ hashFiles('setup.py') }} key: v4-tests_templates-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss - name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
......
...@@ -21,7 +21,7 @@ jobs: ...@@ -21,7 +21,7 @@ jobs:
id: cache id: cache
with: with:
path: ~/venv/ path: ~/venv/
key: v2-metadata-${{ hashFiles('setup.py') }} key: v3-metadata-${{ hashFiles('setup.py') }}
- name: Create virtual environment on cache miss - name: Create virtual environment on cache miss
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
......
...@@ -18,8 +18,9 @@ from importlib import import_module ...@@ -18,8 +18,9 @@ from importlib import import_module
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
from packaging import version
from huggingface_hub import Repository, upload_file import huggingface_hub
from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from ..utils import logging from ..utils import logging
...@@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace): ...@@ -45,7 +46,9 @@ def convert_command_factory(args: Namespace):
Returns: ServeCommand Returns: ServeCommand
""" """
return PTtoTFCommand(args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push) return PTtoTFCommand(
args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description
)
class PTtoTFCommand(BaseTransformersCLICommand): class PTtoTFCommand(BaseTransformersCLICommand):
...@@ -89,6 +92,12 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -89,6 +92,12 @@ class PTtoTFCommand(BaseTransformersCLICommand):
action="store_true", action="store_true",
help="Optional flag to push the weights directly to `main` (requires permissions)", help="Optional flag to push the weights directly to `main` (requires permissions)",
) )
train_parser.add_argument(
"--extra-commit-description",
type=str,
default="",
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
)
train_parser.set_defaults(func=convert_command_factory) train_parser.set_defaults(func=convert_command_factory)
@staticmethod @staticmethod
...@@ -134,13 +143,23 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -134,13 +143,23 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return _find_pt_tf_differences(pt_outputs, tf_outputs, {}) return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
def __init__(self, model_name: str, local_dir: str, new_weights: bool, no_pr: bool, push: bool, *args): def __init__(
self,
model_name: str,
local_dir: str,
new_weights: bool,
no_pr: bool,
push: bool,
extra_commit_description: str,
*args
):
self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._logger = logging.get_logger("transformers-cli/pt_to_tf")
self._model_name = model_name self._model_name = model_name
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
self._new_weights = new_weights self._new_weights = new_weights
self._no_pr = no_pr self._no_pr = no_pr
self._push = push self._push = push
self._extra_commit_description = extra_commit_description
def get_text_inputs(self): def get_text_inputs(self):
tokenizer = AutoTokenizer.from_pretrained(self._local_dir) tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
...@@ -170,10 +189,17 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -170,10 +189,17 @@ class PTtoTFCommand(BaseTransformersCLICommand):
return pt_input, tf_input return pt_input, tf_input
def run(self): def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"):
raise ImportError(
"The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub"
" installation."
)
else:
from huggingface_hub import Repository, create_commit
from huggingface_hub._commit_api import CommitOperationAdd
# Fetch remote data # Fetch remote data
# TODO: implement a solution to pull a specific PR/commit, so we can use this CLI to validate pushes.
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name) repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
repo.git_pull() # in case the repo already exists locally, but with an older commit
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir) config = AutoConfig.from_pretrained(self._local_dir)
...@@ -240,32 +266,29 @@ class PTtoTFCommand(BaseTransformersCLICommand): ...@@ -240,32 +266,29 @@ class PTtoTFCommand(BaseTransformersCLICommand):
) )
) )
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
if self._push: if self._push:
repo.git_add(auto_lfs_track=True) repo.git_add(auto_lfs_track=True)
repo.git_commit("Add TF weights") repo.git_commit(commit_message)
repo.git_push(blocking=True) # this prints a progress bar with the upload repo.git_push(blocking=True) # this prints a progress bar with the upload
self._logger.warn(f"TF weights pushed into {self._model_name}") self._logger.warn(f"TF weights pushed into {self._model_name}")
elif not self._no_pr: elif not self._no_pr:
# TODO: remove try/except when the upload to PR feature is released
# (https://github.com/huggingface/huggingface_hub/pull/884)
try:
self._logger.warn("Uploading the weights into a new PR...") self._logger.warn("Uploading the weights into a new PR...")
hub_pr_url = upload_file( commit_descrition = (
path_or_fileobj=tf_weights_path, "Model converted by the [`transformers`' `pt_to_tf`"
path_in_repo=TF_WEIGHTS_NAME, " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)."
"\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart."
f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
)
if self._extra_commit_description:
commit_descrition += "\n\n" + self._extra_commit_description
hub_pr_url = create_commit(
repo_id=self._model_name, repo_id=self._model_name,
operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)],
commit_message=commit_message,
commit_description=commit_descrition,
repo_type="model",
create_pr=True, create_pr=True,
pr_commit_summary="Add TF weights",
pr_commit_description=(
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f" difference={max_crossload_diff:.3e}; Maximum converted output"
f" difference={max_conversion_diff:.3e}."
),
) )
self._logger.warn(f"PR open in {hub_pr_url}") self._logger.warn(f"PR open in {hub_pr_url}")
except TypeError:
self._logger.warn(
f"You can now open a PR in https://huggingface.co/{self._model_name}/discussions, manually"
f" uploading the file in {tf_weights_path}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment