Unverified Commit 6232c380 authored by Lucain, committed by GitHub
Browse files

Fix `.push_to_hub` and cleanup `get_full_repo_name` usage (#25120)

* Fix .push_to_hub and cleanup get_full_repo_name usage

* Do not rely on Python bool conversion magic

* request changes
parent 400e76ef
...@@ -53,7 +53,7 @@ from transformers import ( ...@@ -53,7 +53,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
is_tensorboard_available, is_tensorboard_available,
) )
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry from transformers.utils import is_offline_mode, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -424,14 +424,14 @@ def main(): ...@@ -424,14 +424,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -59,7 +59,7 @@ from transformers import ( ...@@ -59,7 +59,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.models.bart.modeling_flax_bart import shift_tokens_right from transformers.models.bart.modeling_flax_bart import shift_tokens_right
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
...@@ -496,14 +496,14 @@ def main(): ...@@ -496,14 +496,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -58,7 +58,7 @@ from transformers import ( ...@@ -58,7 +58,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.testing_utils import CaptureLogger from transformers.testing_utils import CaptureLogger
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -372,14 +372,14 @@ def main(): ...@@ -372,14 +372,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -59,7 +59,7 @@ from transformers import ( ...@@ -59,7 +59,7 @@ from transformers import (
is_tensorboard_available, is_tensorboard_available,
set_seed, set_seed,
) )
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
...@@ -410,14 +410,14 @@ def main(): ...@@ -410,14 +410,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -59,7 +59,7 @@ from transformers import ( ...@@ -59,7 +59,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
...@@ -537,14 +537,14 @@ def main(): ...@@ -537,14 +537,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -55,7 +55,7 @@ from transformers import ( ...@@ -55,7 +55,7 @@ from transformers import (
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
is_tensorboard_available, is_tensorboard_available,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -462,14 +462,14 @@ def main(): ...@@ -462,14 +462,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# region Load Data # region Load Data
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
......
...@@ -56,7 +56,7 @@ from transformers import ( ...@@ -56,7 +56,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
is_tensorboard_available, is_tensorboard_available,
) )
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry from transformers.utils import is_offline_mode, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -452,14 +452,14 @@ def main(): ...@@ -452,14 +452,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -49,7 +49,7 @@ from transformers import ( ...@@ -49,7 +49,7 @@ from transformers import (
TrainingArguments, TrainingArguments,
is_tensorboard_available, is_tensorboard_available,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -342,14 +342,14 @@ def main(): ...@@ -342,14 +342,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
......
...@@ -49,7 +49,7 @@ from transformers import ( ...@@ -49,7 +49,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
is_tensorboard_available, is_tensorboard_available,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -398,14 +398,14 @@ def main(): ...@@ -398,14 +398,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
......
...@@ -54,7 +54,7 @@ from transformers import ( ...@@ -54,7 +54,7 @@ from transformers import (
is_tensorboard_available, is_tensorboard_available,
set_seed, set_seed,
) )
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -293,14 +293,14 @@ def main(): ...@@ -293,14 +293,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if training_args.push_to_hub: if training_args.push_to_hub:
if training_args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name( repo_name = training_args.hub_model_id
Path(training_args.output_dir).absolute().name, token=training_args.hub_token if repo_name is None:
) repo_name = Path(training_args.output_dir).absolute().name
else: # Create repo and retrieve repo_id
repo_name = training_args.hub_model_id repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
create_repo(repo_name, exist_ok=True, token=training_args.hub_token) # Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token) repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
# Initialize datasets and pre-processing transforms # Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing # We use torchvision here for faster pre-processing
......
...@@ -42,7 +42,7 @@ from tqdm.auto import tqdm ...@@ -42,7 +42,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -236,12 +236,14 @@ def main(): ...@@ -236,12 +236,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -25,7 +25,7 @@ import torch ...@@ -25,7 +25,7 @@ import torch
from accelerate import Accelerator, DistributedType from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed from accelerate.utils import set_seed
from datasets import load_dataset from datasets import load_dataset
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -41,7 +41,7 @@ from transformers import ( ...@@ -41,7 +41,7 @@ from transformers import (
SchedulerType, SchedulerType,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -406,11 +406,14 @@ def main(): ...@@ -406,11 +406,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
repo = Repository(args.output_dir, clone_from=repo_name) # Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -52,7 +52,7 @@ from transformers import ( ...@@ -52,7 +52,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -286,12 +286,14 @@ def main(): ...@@ -286,12 +286,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -52,7 +52,7 @@ from transformers import ( ...@@ -52,7 +52,7 @@ from transformers import (
SchedulerType, SchedulerType,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -295,12 +295,14 @@ def main(): ...@@ -295,12 +295,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -52,7 +52,7 @@ from transformers import ( ...@@ -52,7 +52,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import PaddingStrategy, check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import PaddingStrategy, check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
...@@ -313,12 +313,14 @@ def main(): ...@@ -313,12 +313,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -51,7 +51,7 @@ from transformers import ( ...@@ -51,7 +51,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -328,12 +328,14 @@ def main(): ...@@ -328,12 +328,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -52,7 +52,7 @@ from transformers import ( ...@@ -52,7 +52,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -366,12 +366,14 @@ def main(): ...@@ -366,12 +366,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -45,7 +45,7 @@ from transformers import ( ...@@ -45,7 +45,7 @@ from transformers import (
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -350,12 +350,14 @@ def main(): ...@@ -350,12 +350,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -43,7 +43,7 @@ from transformers import ( ...@@ -43,7 +43,7 @@ from transformers import (
set_seed, set_seed,
) )
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from transformers.utils import get_full_repo_name, send_example_telemetry from transformers.utils import send_example_telemetry
logger = get_logger(__name__) logger = get_logger(__name__)
...@@ -418,12 +418,14 @@ def main(): ...@@ -418,12 +418,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub and not args.preprocessing_only: if args.push_to_hub and not args.preprocessing_only:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -51,7 +51,7 @@ from transformers import ( ...@@ -51,7 +51,7 @@ from transformers import (
SchedulerType, SchedulerType,
get_scheduler, get_scheduler,
) )
from transformers.utils import check_min_version, get_full_repo_name, is_offline_mode, send_example_telemetry from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -360,12 +360,14 @@ def main(): ...@@ -360,12 +360,14 @@ def main():
# Handle the repository creation # Handle the repository creation
if accelerator.is_main_process: if accelerator.is_main_process:
if args.push_to_hub: if args.push_to_hub:
if args.hub_model_id is None: # Retrieve or infer repo_name
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = args.hub_model_id
else: if repo_name is None:
repo_name = args.hub_model_id repo_name = Path(args.output_dir).absolute().name
create_repo(repo_name, exist_ok=True, token=args.hub_token) # Create repo and retrieve repo_id
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) repo_id = create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
# Clone repo locally
repo = Repository(args.output_dir, clone_from=repo_id, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment