Unverified commit 7a0fccc6, authored by Younes Belkada and committed by GitHub
Browse files

FIX [`Trainer` / tags]: Fix trainer + tags when users do not pass `"tags"` to...

FIX [`Trainer` / tags]: Fix trainer + tags when users do not pass `"tags"` to `trainer.push_to_hub()` (#29009)

* fix trainer tags

* add test
parent 5f06053d
...@@ -3842,7 +3842,10 @@ class Trainer: ...@@ -3842,7 +3842,10 @@ class Trainer:
# Add additional tags in the case the model has already some tags and users pass # Add additional tags in the case the model has already some tags and users pass
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
# from all models since Trainer does not call `model.push_to_hub`. # from all models since Trainer does not call `model.push_to_hub`.
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None: if getattr(self.model, "model_tags", None) is not None:
if "tags" not in kwargs:
kwargs["tags"] = []
# If it is a string, convert it to a list # If it is a string, convert it to a list
if isinstance(kwargs["tags"], str): if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]] kwargs["tags"] = [kwargs["tags"]]
......
...@@ -30,7 +30,7 @@ from typing import Dict, List ...@@ -30,7 +30,7 @@ from typing import Dict, List
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import numpy as np import numpy as np
from huggingface_hub import HfFolder, delete_repo, list_repo_commits, list_repo_files from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
from parameterized import parameterized from parameterized import parameterized
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
...@@ -2564,7 +2564,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -2564,7 +2564,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step", "test-trainer-tensorboard"]: for model in [
"test-trainer",
"test-trainer-epoch",
"test-trainer-step",
"test-trainer-tensorboard",
"test-trainer-tags",
]:
try: try:
delete_repo(token=cls._token, repo_id=model) delete_repo(token=cls._token, repo_id=model)
except HTTPError: except HTTPError:
...@@ -2695,6 +2701,31 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -2695,6 +2701,31 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
assert found_log is True, "No tensorboard log found in repo" assert found_log is True, "No tensorboard log found in repo"
def test_push_to_hub_tags(self):
    # Checks that `trainer.push_to_hub()` publishes the tags attached to the
    # model via `add_model_tags` even when the caller does not pass `tags`
    # to `push_to_hub`.
    # see: PR #29009
    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = get_regression_trainer(
            output_dir=os.path.join(tmp_dir, "test-trainer-tags"),
            push_to_hub=True,
            hub_token=self._token,
        )

        trainer.model.add_model_tags(["test-trainer-tags"])

        url = trainer.push_to_hub()

        # Extract repo_name from the url. The endpoint is escaped so that the
        # dots in the URL are matched literally rather than as regex wildcards.
        re_search = re.search(re.escape(ENDPOINT_STAGING) + r"/([^/]+/[^/]+)/", url)
        # `assertIsNotNone` / `assertIn` report the offending values on failure,
        # unlike `assertTrue(<expr>)` which only reports "False is not true".
        self.assertIsNotNone(re_search)
        repo_name = re_search.group(1)

        self.assertEqual(repo_name, f"{USER}/test-trainer-tags")

        # The tag must show up in the card of the repo that was just pushed.
        model_card = ModelCard.load(repo_name)
        self.assertIn("test-trainer-tags", model_card.data.tags)
@require_torch @require_torch
@require_optuna @require_optuna
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.