Unverified Commit a04d6858 authored by Wauplin

Don't use deprecated Repository anymore

parent 10016fb0
@@ -19,28 +19,27 @@
 import logging
 import os
 import re
-import sys
 import shutil
+import sys
 import time
-from multiprocess import set_start_method
+from dataclasses import dataclass, field
 from datetime import timedelta
-import evaluate
-from tqdm import tqdm
 from pathlib import Path
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union, Set
+from typing import Dict, List, Optional, Set, Union
 
 import datasets
+import evaluate
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
-from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
-from huggingface_hub import Repository, create_repo
 import transformers
+from accelerate import Accelerator
+from accelerate.utils import AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin, set_seed
+from accelerate.utils.memory import release_memory
+from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
+from huggingface_hub import HfApi
+from multiprocess import set_start_method
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import (
     AutoFeatureExtractor,
     AutoModel,
@@ -48,26 +47,19 @@ from transformers import (
     AutoTokenizer,
     HfArgumentParser,
     Seq2SeqTrainingArguments,
+    pipeline,
 )
-from transformers.trainer_pt_utils import LengthGroupedSampler
-from transformers import pipeline
 from transformers.optimization import get_scheduler
+from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.utils import send_example_telemetry
-from transformers import AutoModel
-from accelerate import Accelerator
-from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
-from accelerate.utils.memory import release_memory
+from wandb import Audio
 
 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
     ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )
-from wandb import Audio
 
 logger = logging.getLogger(__name__)
@@ -1415,14 +1407,13 @@ def main():
     if accelerator.is_main_process:
         if training_args.push_to_hub:
-            # Retrieve of infer repo_name
+            api = HfApi(token=training_args.hub_token)
+
+            # Create repo (repo_name from args or inferred)
             repo_name = training_args.hub_model_id
             if repo_name is None:
                 repo_name = Path(training_args.output_dir).absolute().name
-            # Create repo and retrieve repo_id
-            repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-            # Clone repo locally
-            repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
+            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
 
             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:
@@ -1624,9 +1615,11 @@ def main():
                 unwrapped_model.save_pretrained(training_args.output_dir)
 
                 if training_args.push_to_hub:
-                    repo.push_to_hub(
+                    api.upload_folder(
+                        repo_id=repo_id,
+                        folder_path=training_args.output_dir,
                         commit_message=f"Saving train state of step {cur_step}",
-                        blocking=False,
+                        run_as_future=True,
                     )
 
             if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
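Likewise, a sketch of the new push path (repo id, folder path, and step number are placeholders). upload_folder commits the folder contents directly through the Hub HTTP API, so no local git clone is needed, and run_as_future=True returns a concurrent.futures.Future so training continues while the upload runs in the background, matching the old blocking=False behaviour of Repository.push_to_hub.

from huggingface_hub import HfApi

api = HfApi(token="hf_xxx")  # placeholder token

# Non-blocking upload: the commit is queued on a background thread and a
# Future is returned immediately instead of waiting for the push.
future = api.upload_folder(
    repo_id="my-username/parler-tts-finetuned",  # hypothetical repo
    folder_path="./parler-tts-finetuned",        # hypothetical output dir
    commit_message="Saving train state of step 1000",  # placeholder step
    run_as_future=True,
)

# Later (e.g. at the end of training) you can block on the result:
commit_info = future.result()
print(commit_info.commit_url)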