Unverified Commit 6232c380 authored by Lucain, committed by GitHub

Fix `.push_to_hub` and cleanup `get_full_repo_name` usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
parent 400e76ef
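
The pattern this commit standardizes across the example scripts, `Trainer`, and `PushToHubCallback` is sketched below. This is an illustrative, minimal version using placeholder values in place of the scripts' CLI arguments, not code lifted from the diff itself:

```python
from pathlib import Path

from huggingface_hub import Repository, create_repo

# Placeholders standing in for the example scripts' CLI arguments.
output_dir = "my-finetuned-model"  # args.output_dir
hub_model_id = None                # args.hub_model_id, may be None
hub_token = None                   # args.hub_token; None falls back to the cached login

# Prefer an explicit hub_model_id; otherwise fall back to the output directory name.
# The None check is explicit on purpose ("Do not rely on Python bool conversion magic").
repo_name = hub_model_id if hub_model_id is not None else Path(output_dir).absolute().name

# create_repo returns a RepoUrl whose .repo_id is the fully qualified
# "namespace/name", which removes the need for get_full_repo_name.
repo_id = create_repo(repo_name, exist_ok=True, token=hub_token).repo_id

# Clone the repo locally into the output directory.
repo = Repository(output_dir, clone_from=repo_id, token=hub_token)
```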
......@@ -43,7 +43,7 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
......@@ -240,12 +240,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
......
......@@ -51,7 +51,7 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
......@@ -295,12 +295,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
......
......@@ -52,7 +52,7 @@ from transformers import (
default_data_collator,
get_scheduler,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
......@@ -340,12 +340,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
......
......@@ -29,7 +29,7 @@ import datasets
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from datasets import ClassLabel, load_dataset, load_metric
from huggingface_hub import Repository
from huggingface_hub import Repository, create_repo
from luke_utils import DataCollatorForLukeTokenClassification, is_punctuation, padding_tensor
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
......@@ -45,7 +45,6 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils.versions import require_version
......@@ -258,11 +257,14 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
# Retrieve or infer repo_name
repo_name = args.hub_model_id
if repo_name is None:
repo_name = Path(args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
......
......@@ -17,6 +17,8 @@ File utilities: utilities related to download and cache models
This module should not be updated anymore and is only left for backward compatibility.
"""
from huggingface_hub import get_full_repo_name # for backward compatibility
from . import __version__
# Backward compatibility imports, to make sure all those objects can be found in file_utils
......@@ -71,7 +73,6 @@ from .utils import (
define_sagemaker_information,
get_cached_models,
get_file_from_repo,
get_full_repo_name,
get_torch_version,
has_file,
http_user_agent,
......
......@@ -12,7 +12,6 @@ from tensorflow.keras.callbacks import Callback
from . import IntervalStrategy, PreTrainedTokenizerBase
from .modelcard import TrainingSummary
from .utils import get_full_repo_name
logger = logging.getLogger(__name__)
......@@ -334,14 +333,13 @@ class PushToHubCallback(Callback):
raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
self.save_steps = save_steps
output_dir = Path(output_dir)
# Create repo and retrieve repo_id
if hub_model_id is None:
hub_model_id = output_dir.absolute().name
if "/" not in hub_model_id:
hub_model_id = get_full_repo_name(hub_model_id, token=hub_token)
self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
self.output_dir = output_dir
self.hub_model_id = hub_model_id
create_repo(self.hub_model_id, exist_ok=True)
self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
self.tokenizer = tokenizer
......
......@@ -1357,21 +1357,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"Checkpoint loading failed as no optimizer is attached to the model. "
"This is most likely caused by the model not being compiled."
)
if not os.path.isdir(repo_path_or_name):
if os.path.isdir(repo_path_or_name):
local_dir = repo_path_or_name
else:
# If this isn't a local path, check that the remote repo exists and has a checkpoint in it
repo_files = list_repo_files(repo_path_or_name)
for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
if file not in repo_files:
raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
if "/" not in repo_path_or_name:
model_id = repo_path_or_name
repo_path_or_name = self.get_full_repo_name(repo_path_or_name)
else:
model_id = repo_path_or_name.split("/")[-1]
repo = Repository(model_id, clone_from=f"https://huggingface.co/{repo_path_or_name}")
repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
local_dir = repo.local_dir
else:
local_dir = repo_path_or_name
# Now make sure the repo actually has a checkpoint in it.
checkpoint_dir = os.path.join(local_dir, "checkpoint")
......
......@@ -129,7 +129,6 @@ from .utils import (
WEIGHTS_NAME,
can_return_loss,
find_labels,
get_full_repo_name,
is_accelerate_available,
is_apex_available,
is_datasets_available,
......@@ -3396,22 +3395,22 @@ class Trainer:
"""
if not self.is_world_process_zero():
return
if self.args.hub_model_id is None:
# Make sure the repo exists + retrieve "real" repo_id
repo_name = self.args.hub_model_id
if repo_name is None:
repo_name = Path(self.args.output_dir).absolute().name
else:
repo_name = self.args.hub_model_id
if "/" not in repo_name:
repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
repo_id = create_repo(
repo_id=repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True
).repo_id
# Make sure the repo exists.
create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
try:
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token)
except EnvironmentError:
if self.args.overwrite_output_dir and at_init:
# Try again after wiping output_dir
shutil.rmtree(self.args.output_dir)
self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token)
else:
raise
......
......@@ -24,6 +24,7 @@ from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import get_full_repo_name
from packaging import version
from .debug_utils import DebugOption
......@@ -38,7 +39,6 @@ from .trainer_utils import (
from .utils import (
ExplicitEnum,
cached_property,
get_full_repo_name,
is_accelerate_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
......
......@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import get_full_repo_name # for backward compatibility
from packaging import version
from .. import __version__
......@@ -79,7 +80,6 @@ from .hub import (
extract_commit_hash,
get_cached_models,
get_file_from_repo,
get_full_repo_name,
has_file,
http_user_agent,
is_offline_mode,
......
......@@ -36,7 +36,6 @@ from huggingface_hub import (
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
whoami,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
......@@ -690,6 +689,10 @@ class PushToHubMixin:
"The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
"instead."
)
if repo_id is not None:
raise ValueError(
"`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`."
)
repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
if organization is not None:
warnings.warn(
......@@ -702,11 +705,7 @@ class PushToHubMixin:
repo_id = f"{organization}/{repo_id}"
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
# If the namespace is not there, add it or `upload_file` will complain
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
repo_id = get_full_repo_name(repo_id, token=token)
return repo_id
return url.repo_id
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
"""
......@@ -786,8 +785,7 @@ class PushToHubMixin:
**deprecated_kwargs,
) -> str:
"""
Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
`repo_path_or_name`.
Upload the {object_files} to the 🤗 Model Hub.
Parameters:
repo_id (`str`):
......@@ -838,22 +836,35 @@ class PushToHubMixin:
)
token = use_auth_token
if "repo_path_or_name" in deprecated_kwargs:
repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
if repo_path_or_name is not None:
# Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
# repo_id from the folder path, if it exists.
warnings.warn(
"The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
"`repo_id` instead."
"`repo_id` instead.",
FutureWarning,
)
repo_id = deprecated_kwargs.pop("repo_path_or_name")
if repo_id is not None:
raise ValueError(
"`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`."
)
if os.path.isdir(repo_path_or_name):
# repo_path: infer repo_id from the path
repo_id = repo_id.split(os.path.sep)[-1]
working_dir = repo_id
else:
# repo_name: use it as repo_id
repo_id = repo_path_or_name
working_dir = repo_id.split("/")[-1]
else:
# Repo_id is passed correctly: infer working_dir from it
working_dir = repo_id.split("/")[-1]
# Deprecation warning will be sent after for repo_url and organization
repo_url = deprecated_kwargs.pop("repo_url", None)
organization = deprecated_kwargs.pop("organization", None)
if os.path.isdir(repo_id):
working_dir = repo_id
repo_id = repo_id.split(os.path.sep)[-1]
else:
working_dir = repo_id.split("/")[-1]
repo_id = self._create_repo(
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)
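
For reference, after this cleanup the call path is meant to be driven by `repo_id` alone; a hypothetical usage would look like the following (the model checkpoint is real, the target repo name is a placeholder):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

# Pass repo_id directly; repo_path_or_name, repo_url and organization are only
# accepted as deprecated keyword arguments for backward compatibility.
model.push_to_hub("my-username/my-finetuned-bert", private=True)
```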
......@@ -877,14 +888,6 @@ class PushToHubMixin:
)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def send_example_telemetry(example_name, *example_args, framework="pytorch"):
"""
Sends telemetry that helps tracking the examples use.
......
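
Since `get_full_repo_name` is removed from `transformers.utils` and only re-exported for backward compatibility, code that still needs it can import it from `huggingface_hub` directly. A small illustrative example ("my-model" is a placeholder and the result depends on the logged-in user):

```python
from huggingface_hub import get_full_repo_name

# With no organization, this resolves to "<username>/my-model" based on the token
# (or the cached login); with organization="my-org" it would return "my-org/my-model".
full_name = get_full_repo_name("my-model")
print(full_name)
```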