Unverified Commit a04d6858 authored by Wauplin's avatar Wauplin
Browse files

Don't use deprecated Repository anymore

parent 10016fb0
......@@ -19,28 +19,27 @@
import logging
import os
import re
import sys
import shutil
import sys
import time
from multiprocess import set_start_method
from dataclasses import dataclass, field
from datetime import timedelta
import evaluate
from tqdm import tqdm
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Set
from typing import Dict, List, Optional, Set, Union
import datasets
import evaluate
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
from huggingface_hub import Repository, create_repo
import transformers
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, set_seed
from accelerate.utils.memory import release_memory
from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
from huggingface_hub import HfApi
from multiprocess import set_start_method
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoFeatureExtractor,
AutoModel,
......@@ -48,26 +47,19 @@ from transformers import (
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments,
pipeline,
)
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.utils import send_example_telemetry
from transformers import AutoModel
from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
from wandb import Audio
from parler_tts import (
ParlerTTSForConditionalGeneration,
ParlerTTSConfig,
ParlerTTSForConditionalGeneration,
build_delay_pattern_mask,
)
from wandb import Audio
logger = logging.getLogger(__name__)
......@@ -1415,14 +1407,13 @@ def main():
if accelerator.is_main_process:
if training_args.push_to_hub:
# Retrieve of infer repo_name
api = HfApi(token=training_args.hub_token)
# Create repo (repo_name from args or inferred)
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
if "wandb" not in gitignore:
......@@ -1624,9 +1615,11 @@ def main():
unwrapped_model.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(
api.upload_folder(
repo_id=repo_id,
folder_path=training_args.output_dir,
commit_message=f"Saving train state of step {cur_step}",
blocking=False,
run_as_future=True,
)
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment