Unverified Commit 05e72aa0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Adapt repository creation to latest hf_hub (#21158)

* Adapt repository creation to latest hf_hub

* Update all examples

* Fix other tests, add Flax examples

* Address review comments
parent 32525428
...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util ...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate from flax.jax_utils import unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoTokenizer, AutoTokenizer,
...@@ -430,7 +430,8 @@ def main(): ...@@ -430,7 +430,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util ...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad from flax.jax_utils import pad_shard_unpad
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
...@@ -502,7 +502,8 @@ def main(): ...@@ -502,7 +502,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util ...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
...@@ -376,7 +376,8 @@ def main(): ...@@ -376,7 +376,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util ...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad from flax.jax_utils import pad_shard_unpad
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
...@@ -416,7 +416,8 @@ def main(): ...@@ -416,7 +416,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util ...@@ -45,7 +45,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad from flax.jax_utils import pad_shard_unpad
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
...@@ -542,7 +542,8 @@ def main(): ...@@ -542,7 +542,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -44,7 +44,7 @@ from flax import struct, traverse_util ...@@ -44,7 +44,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
...@@ -467,7 +467,8 @@ def main(): ...@@ -467,7 +467,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# region Load Data # region Load Data
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
......
...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util ...@@ -46,7 +46,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
...@@ -450,7 +450,8 @@ def main(): ...@@ -450,7 +450,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
......
...@@ -39,7 +39,7 @@ from flax import struct, traverse_util ...@@ -39,7 +39,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
...@@ -350,7 +350,8 @@ def main(): ...@@ -350,7 +350,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
......
...@@ -41,7 +41,7 @@ from flax import struct, traverse_util ...@@ -41,7 +41,7 @@ from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
...@@ -406,7 +406,8 @@ def main(): ...@@ -406,7 +406,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/ # or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
......
...@@ -43,7 +43,7 @@ from flax import jax_utils ...@@ -43,7 +43,7 @@ from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
...@@ -298,7 +298,8 @@ def main(): ...@@ -298,7 +298,8 @@ def main():
) )
else: else:
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
repo = Repository(training_args.output_dir, clone_from=repo_name, token=training_args.hub_token)
# Initialize datasets and pre-processing transforms # Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing # We use torchvision here for faster pre-processing
......
...@@ -40,7 +40,7 @@ import transformers ...@@ -40,7 +40,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoFeatureExtractor, AutoFeatureExtractor,
...@@ -246,7 +246,8 @@ def main(): ...@@ -246,7 +246,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -41,7 +41,7 @@ import transformers ...@@ -41,7 +41,7 @@ import transformers
from accelerate import Accelerator, DistributedType from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -282,7 +282,8 @@ def main(): ...@@ -282,7 +282,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -41,7 +41,7 @@ import transformers ...@@ -41,7 +41,7 @@ import transformers
from accelerate import Accelerator, DistributedType from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -291,7 +291,8 @@ def main(): ...@@ -291,7 +291,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -40,7 +40,7 @@ import transformers ...@@ -40,7 +40,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -317,7 +317,8 @@ def main(): ...@@ -317,7 +317,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -38,7 +38,7 @@ import transformers ...@@ -38,7 +38,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AdamW, AdamW,
DataCollatorWithPadding, DataCollatorWithPadding,
...@@ -332,7 +332,8 @@ def main(): ...@@ -332,7 +332,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -38,7 +38,7 @@ import transformers ...@@ -38,7 +38,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -370,7 +370,8 @@ def main(): ...@@ -370,7 +370,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -36,7 +36,7 @@ import transformers ...@@ -36,7 +36,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository, hf_hub_download from huggingface_hub import Repository, create_repo, hf_hub_download
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoFeatureExtractor, AutoFeatureExtractor,
...@@ -354,7 +354,8 @@ def main(): ...@@ -354,7 +354,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -31,7 +31,7 @@ from tqdm.auto import tqdm ...@@ -31,7 +31,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AdamW, AdamW,
SchedulerType, SchedulerType,
...@@ -422,7 +422,8 @@ def main(): ...@@ -422,7 +422,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -40,7 +40,7 @@ from accelerate import Accelerator ...@@ -40,7 +40,7 @@ from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from filelock import FileLock from filelock import FileLock
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
...@@ -373,7 +373,8 @@ def main(): ...@@ -373,7 +373,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -32,7 +32,7 @@ import transformers ...@@ -32,7 +32,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository, create_repo
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
...@@ -244,7 +244,8 @@ def main(): ...@@ -244,7 +244,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment