Unverified Commit 6232c380 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Fix `.push_to_hub` and cleanup `get_full_repo_name` usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
parent 400e76ef
...@@ -43,7 +43,7 @@ from transformers import ( ...@@ -43,7 +43,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -240,12 +240,14 @@ def main(): ...@@ -240,12 +240,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -51,7 +51,7 @@ from transformers import ( ...@@ -51,7 +51,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -295,12 +295,14 @@ def main(): ...@@ -295,12 +295,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -52,7 +52,7 @@ from transformers import ( ...@@ -52,7 +52,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -340,12 +340,14 @@ def main(): ...@@ -340,12 +340,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -29,7 +29,7 @@ import datasets ...@@ -29,7 +29,7 @@ import datasets
import torch import torch
from accelerate import Accelerator, DistributedDataParallelKwargs from accelerate import Accelerator, DistributedDataParallelKwargs
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset, load_metric
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -45,7 +45,6 @@ from transformers import ( ...@@ -45,7 +45,6 @@ from transformers import (
get_scheduler, get_scheduler,
set_seed, set_seed,
) )
from transformers.file_utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -258,11 +257,14 @@ def main(): ...@@ -258,11 +257,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
repo = Repository(args.output_dir, clone_from=repo_name) # Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -17,6 +17,8 @@ File utilities: utilities related to download and cache models ...@@ -17,6 +17,8 @@ File utilities: utilities related to download and cache models
This module should not be updated anymore and is only left for backward compatibility. This module should not be updated anymore and is only left for backward compatibility.
""" """
from huggingface_hub import get_full_repo_name # for backward compatibility
from . import __version__ from . import __version__
# Backward compatibility imports, to make sure all those objects can be found in file_utils # Backward compatibility imports, to make sure all those objects can be found in file_utils
...@@ -71,7 +73,6 @@ from .utils import ( ...@@ -71,7 +73,6 @@ from .utils import (
define_sagemaker_information, define_sagemaker_information,
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_full_repo_name,
get_torch_version, get_torch_version,
has_file, has_file,
http_user_agent, http_user_agent,
......
...@@ -12,7 +12,6 @@ from tensorflow.keras.callbacks import Callback ...@@ -12,7 +12,6 @@ from tensorflow.keras.callbacks import Callback
from . import IntervalStrategy, PreTrainedTokenizerBase from . import IntervalStrategy, PreTrainedTokenizerBase
from .modelcard import TrainingSummary from .modelcard import TrainingSummary
from .utils import get_full_repo_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -334,14 +333,13 @@ class PushToHubCallback(Callback): ...@@ -334,14 +333,13 @@ class PushToHubCallback(Callback):
raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!") raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
self.save_steps = save_steps self.save_steps = save_steps
output_dir = Path(output_dir) output_dir = Path(output_dir)
# Create repo and retrieve repo_id
if hub_model_id is None: if hub_model_id is None:
hub_model_id = output_dir.absolute().name hub_model_id = output_dir.absolute().name
if "/" not in hub_model_id: self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
hub_model_id = get_full_repo_name(hub_model_id, token=hub_token)
self.output_dir = output_dir self.output_dir = output_dir
self.hub_model_id = hub_model_id
create_repo(self.hub_model_id, exist_ok=True)
self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token) self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
self.tokenizer = tokenizer self.tokenizer = tokenizer
......
...@@ -1357,21 +1357,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1357,21 +1357,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"Checkpoint loading failed as no optimizer is attached to the model. " "Checkpoint loading failed as no optimizer is attached to the model. "
"This is most likely caused by the model not being compiled." "This is most likely caused by the model not being compiled."
) )
if not os.path.isdir(repo_path_or_name): if os.path.isdir(repo_path_or_name):
local_dir = repo_path_or_name
else:
# If this isn't a local path, check that the remote repo exists and has a checkpoint in it # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
repo_files = list_repo_files(repo_path_or_name) repo_files = list_repo_files(repo_path_or_name)
for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"): for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
if file not in repo_files: if file not in repo_files:
raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!") raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
if "/" not in repo_path_or_name: repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
model_id = repo_path_or_name
repo_path_or_name = self.get_full_repo_name(repo_path_or_name)
else:
model_id = repo_path_or_name.split("/")[-1]
repo = Repository(model_id, clone_from=f"https://huggingface.co/{repo_path_or_name}")
local_dir = repo.local_dir local_dir = repo.local_dir
else:
local_dir = repo_path_or_name
# Now make sure the repo actually has a checkpoint in it. # Now make sure the repo actually has a checkpoint in it.
checkpoint_dir = os.path.join(local_dir, "checkpoint") checkpoint_dir = os.path.join(local_dir, "checkpoint")
......
...@@ -129,7 +129,6 @@ from .utils import ( ...@@ -129,7 +129,6 @@ from .utils import (
WEIGHTS_NAME, WEIGHTS_NAME,
can_return_loss, can_return_loss,
find_labels, find_labels,
get_full_repo_name,
is_accelerate_available, is_accelerate_available,
is_apex_available, is_apex_available,
is_datasets_available, is_datasets_available,
...@@ -3396,22 +3395,22 @@ class Trainer: ...@@ -3396,22 +3395,22 @@ class Trainer:
""" """
if not self.is_world_process_zero(): if not self.is_world_process_zero():
return return
if self.args.hub_model_id is None:
# Make sure the repo exists + retrieve "real" repo_id
repo_name = self.args.hub_model_id
if repo_name is None:
repo_name = Path(self.args.output_dir).absolute().name repo_name = Path(self.args.output_dir).absolute().name
else: repo_id = create_repo(
repo_name = self.args.hub_model_id repo_id=repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True
if "/" not in repo_name: ).repo_id
repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
# Make sure the repo exists.
create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
try: try:
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token)
except EnvironmentError: except EnvironmentError:
if self.args.overwrite_output_dir and at_init: if self.args.overwrite_output_dir and at_init:
# Try again after wiping output_dir # Try again after wiping output_dir
shutil.rmtree(self.args.output_dir) shutil.rmtree(self.args.output_dir)
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token)
else: else:
raise raise
......
...@@ -24,6 +24,7 @@ from enum import Enum ...@@ -24,6 +24,7 @@ from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from huggingface_hub import get_full_repo_name
from packaging import version from packaging import version
from .debug_utils import DebugOption from .debug_utils import DebugOption
...@@ -38,7 +39,6 @@ from .trainer_utils import ( ...@@ -38,7 +39,6 @@ from .trainer_utils import (
from .utils import ( from .utils import (
ExplicitEnum, ExplicitEnum,
cached_property, cached_property,
get_full_repo_name,
is_accelerate_available, is_accelerate_available,
is_safetensors_available, is_safetensors_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from huggingface_hub import get_full_repo_name # for backward compatibility
from packaging import version from packaging import version
from .. import __version__ from .. import __version__
...@@ -79,7 +80,6 @@ from .hub import ( ...@@ -79,7 +80,6 @@ from .hub import (
extract_commit_hash, extract_commit_hash,
get_cached_models, get_cached_models,
get_file_from_repo, get_file_from_repo,
get_full_repo_name,
has_file, has_file,
http_user_agent, http_user_agent,
is_offline_mode, is_offline_mode,
......
...@@ -36,7 +36,6 @@ from huggingface_hub import ( ...@@ -36,7 +36,6 @@ from huggingface_hub import (
get_hf_file_metadata, get_hf_file_metadata,
hf_hub_download, hf_hub_download,
hf_hub_url, hf_hub_url,
whoami,
) )
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import ( from huggingface_hub.utils import (
...@@ -690,6 +689,10 @@ class PushToHubMixin: ...@@ -690,6 +689,10 @@ class PushToHubMixin:
"The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` " "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
"instead." "instead."
) )
if repo_id is not None:
raise ValueError(
"`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`."
)
repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "") repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
if organization is not None: if organization is not None:
warnings.warn( warnings.warn(
...@@ -702,11 +705,7 @@ class PushToHubMixin: ...@@ -702,11 +705,7 @@ class PushToHubMixin:
repo_id = f"{organization}/{repo_id}" repo_id = f"{organization}/{repo_id}"
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True) url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
return url.repo_id
# If the namespace is not there, add it or `upload_file` will complain
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
repo_id = get_full_repo_name(repo_id, token=token)
return repo_id
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]): def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
""" """
...@@ -786,8 +785,7 @@ class PushToHubMixin: ...@@ -786,8 +785,7 @@ class PushToHubMixin:
**deprecated_kwargs, **deprecated_kwargs,
) -> str: ) -> str:
""" """
Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in Upload the {object_files} to the 🤗 Model Hub.
`repo_path_or_name`.
Parameters: Parameters:
repo_id (`str`): repo_id (`str`):
...@@ -838,22 +836,35 @@ class PushToHubMixin: ...@@ -838,22 +836,35 @@ class PushToHubMixin:
) )
token = use_auth_token token = use_auth_token
if "repo_path_or_name" in deprecated_kwargs: repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
if repo_path_or_name is not None:
# Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
# repo_id from the folder path, if it exists.
warnings.warn( warnings.warn(
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use " "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
"`repo_id` instead." "`repo_id` instead.",
FutureWarning,
) )
repo_id = deprecated_kwargs.pop("repo_path_or_name") if repo_id is not None:
raise ValueError(
"`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`."
)
if os.path.isdir(repo_path_or_name):
# repo_path: infer repo_id from the path
repo_id = repo_path_or_name.split(os.path.sep)[-1]
working_dir = repo_path_or_name
else:
# repo_name: use it as repo_id
repo_id = repo_path_or_name
working_dir = repo_id.split("/")[-1]
else:
# Repo_id is passed correctly: infer working_dir from it
working_dir = repo_id.split("/")[-1]
# Deprecation warning will be sent after for repo_url and organization # Deprecation warning will be sent after for repo_url and organization
repo_url = deprecated_kwargs.pop("repo_url", None) repo_url = deprecated_kwargs.pop("repo_url", None)
organization = deprecated_kwargs.pop("organization", None) organization = deprecated_kwargs.pop("organization", None)
if os.path.isdir(repo_id):
working_dir = repo_id
repo_id = repo_id.split(os.path.sep)[-1]
else:
working_dir = repo_id.split("/")[-1]
repo_id = self._create_repo( repo_id = self._create_repo(
repo_id, private=private, token=token, repo_url=repo_url, organization=organization repo_id, private=private, token=token, repo_url=repo_url, organization=organization
) )
...@@ -877,14 +888,6 @@ class PushToHubMixin: ...@@ -877,14 +888,6 @@ class PushToHubMixin:
) )
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def send_example_telemetry(example_name, *example_args, framework="pytorch"): def send_example_telemetry(example_name, *example_args, framework="pytorch"):
""" """
Sends telemetry that helps tracking the examples use. Sends telemetry that helps tracking the examples use.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment