Unverified commit 2156662d authored by Arthur, committed by GitHub

[TF] Fix creating a PR while pushing in TF framework (#21968)

* add create pr arg

* style

* add test

* fixup

* update test

* last nit fix typo

* add `is_pt_tf_cross_test` marker for the tests
parent d128f2ff
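In practice, the fix lets TF models accept the same `create_pr` flag that the PyTorch `push_to_hub` already honors. A minimal sketch of the new behavior, using the tiny config from the tests; `your-username/tiny-bert-tf` is a placeholder repo id and a valid Hub token is assumed:

```python
from transformers import BertConfig, TFBertModel

# Tiny config (taken from the tests) so the pushed checkpoint stays small.
config = BertConfig(
    vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
model(model.dummy_inputs)  # build the weights before saving

# `create_pr` is now forwarded to the Hub upload, so the files land in a
# pull request instead of a direct commit to the main branch.
model.push_to_hub("your-username/tiny-bert-tf", create_pr=True)
```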
@@ -2905,9 +2905,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         use_temp_dir: Optional[bool] = None,
         commit_message: Optional[str] = None,
         private: Optional[bool] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
         max_shard_size: Optional[Union[int, str]] = "10GB",
-        **model_card_kwargs,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        create_pr: bool = False,
+        **base_model_card_args,
     ) -> str:
         """
         Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
@@ -2931,8 +2932,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
                 will then be each of size lower than this size. If expressed as a string, needs to be digits followed
                 by a unit (like `"5MB"`).
-            model_card_kwargs:
-                Additional keyword arguments passed along to the [`~TFPreTrainedModel.create_model_card`] method.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.

         Examples:
@@ -2948,15 +2949,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         model.push_to_hub("huggingface/my-finetuned-bert")
         ```
         """
-        if "repo_path_or_name" in model_card_kwargs:
+        if "repo_path_or_name" in base_model_card_args:
             warnings.warn(
                 "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
                 "`repo_id` instead."
             )
-            repo_id = model_card_kwargs.pop("repo_path_or_name")
+            repo_id = base_model_card_args.pop("repo_path_or_name")
         # Deprecation warning will be sent after for repo_url and organization
-        repo_url = model_card_kwargs.pop("repo_url", None)
-        organization = model_card_kwargs.pop("organization", None)
+        repo_url = base_model_card_args.pop("repo_url", None)
+        organization = base_model_card_args.pop("organization", None)

         if os.path.isdir(repo_id):
             working_dir = repo_id
@@ -2982,11 +2983,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 "output_dir": work_dir,
                 "model_name": Path(repo_id).name,
             }
-            base_model_card_args.update(model_card_kwargs)
+            base_model_card_args.update(base_model_card_args)
             self.create_model_card(**base_model_card_args)

         self._upload_modified_files(
-            work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
+            work_dir,
+            repo_id,
+            files_timestamps,
+            commit_message=commit_message,
+            token=use_auth_token,
+            create_pr=create_pr,
         )

     @classmethod
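For context, `_upload_modified_files` essentially forwards `create_pr` to `huggingface_hub`'s commit API, so the TF path now behaves like the PyTorch one. A rough standalone equivalent of the upload step, assuming `huggingface_hub` is installed; the repo id and folder path are placeholders:

```python
from huggingface_hub import upload_folder

# With create_pr=True the Hub opens a pull request containing the uploaded
# files instead of committing them directly to the repo's main branch.
upload_folder(
    repo_id="your-username/my-model",  # placeholder repo
    folder_path="./my_model",          # placeholder directory with a saved checkpoint
    commit_message="Upload TF model",
    create_pr=True,
)
```

The test-side changes below import both `PreTrainedModel` and `TFPreTrainedModel` so the new test can compare the two frameworks' `push_to_hub` signatures directly.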
@@ -85,6 +85,7 @@ if is_tf_available():
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         BertConfig,
+        PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
         TFAutoModel,
@@ -92,6 +93,7 @@ if is_tf_available():
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
+        TFPreTrainedModel,
         TFRagModel,
         TFSharedEmbeddings,
     )
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break

         self.assertTrue(models_equal)

+    @is_pt_tf_cross_test
     def test_push_to_hub_callback(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
                 break

         self.assertTrue(models_equal)

+        tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
+        tf_push_to_hub_params.pop("base_model_card_args")
+        pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
+        pt_push_to_hub_params.pop("deprecated_kwargs")
+        self.assertDictEqual(tf_push_to_hub_params, pt_push_to_hub_params)
+
     def test_push_to_hub_in_organization(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
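The new assertion is a signature-parity check: `inspect.signature` exposes each method's parameters, and popping the framework-specific catch-all kwargs (`base_model_card_args` on the TF side, `deprecated_kwargs` on the PyTorch side) leaves the rest to be compared directly. A self-contained sketch of the same technique, using two hypothetical functions rather than the real model classes:

```python
import inspect

# Hypothetical stand-ins for the PyTorch and TF `push_to_hub` methods.
def pt_push_to_hub(repo_id, private=None, create_pr=False, **deprecated_kwargs):
    ...

def tf_push_to_hub(repo_id, private=None, create_pr=False, **base_model_card_args):
    ...

pt_params = dict(inspect.signature(pt_push_to_hub).parameters)
pt_params.pop("deprecated_kwargs")
tf_params = dict(inspect.signature(tf_push_to_hub).parameters)
tf_params.pop("base_model_card_args")

# `Parameter` objects compare by name, kind, default, and annotation, so this
# fails as soon as the two signatures drift apart.
assert pt_params == tf_params
```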