Unverified Commit bdb03638 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Merge pull request #48 from ylacombe/pr/Wauplin/18

Pr/wauplin/18
parents b2b749d1 3f5fd26c
...@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel): ...@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if "encoder_outputs" not in model_kwargs: if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs` # encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation( model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config, inputs_tensor,
model_kwargs,
model_input_name,
generation_config,
) )
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs: if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
......
...@@ -3,6 +3,7 @@ from typing import Optional ...@@ -3,6 +3,7 @@ from typing import Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
@dataclass @dataclass
class ModelArguments: class ModelArguments:
""" """
...@@ -67,15 +68,18 @@ class ModelArguments: ...@@ -67,15 +68,18 @@ class ModelArguments:
) )
asr_model_name_or_path: str = field( asr_model_name_or_path: str = field(
default="distil-whisper/distil-large-v2", default="distil-whisper/distil-large-v2",
metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"} metadata={
"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
},
) )
clap_model_name_or_path: str = field( clap_model_name_or_path: str = field(
default="laion/larger_clap_music_and_speech", default="laion/larger_clap_music_and_speech",
metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"} metadata={
"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
},
) )
@dataclass @dataclass
class DataTrainingArguments: class DataTrainingArguments:
""" """
......
...@@ -11,6 +11,7 @@ from tqdm import tqdm ...@@ -11,6 +11,7 @@ from tqdm import tqdm
from accelerate import Accelerator from accelerate import Accelerator
@dataclass @dataclass
class DataCollatorEncodecWithPadding: class DataCollatorEncodecWithPadding:
""" """
......
...@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device): ...@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
clap_inputs.to("cpu") clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu") return cosine_sim.mean().to("cpu")
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
metric = evaluate.load("wer") metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
......
...@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader ...@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader
import datasets import datasets
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
from huggingface_hub import Repository, create_repo from huggingface_hub import HfApi
import transformers import transformers
from transformers import ( from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
AutoFeatureExtractor,
AutoTokenizer,
HfArgumentParser
)
from transformers.trainer_pt_utils import LengthGroupedSampler from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.utils import send_example_telemetry from transformers.utils import send_example_telemetry
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory from accelerate.utils.memory import release_memory
from parler_tts import ( from parler_tts import (
ParlerTTSForConditionalGeneration,
ParlerTTSConfig, ParlerTTSConfig,
ParlerTTSForConditionalGeneration,
build_delay_pattern_mask, build_delay_pattern_mask,
) )
...@@ -301,9 +299,7 @@ def main(): ...@@ -301,9 +299,7 @@ def main():
# update pad token id and decoder_start_token_id # update pad token id and decoder_start_token_id
config.update( config.update(
{ {
"pad_token_id": model_args.pad_token_id "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
if model_args.pad_token_id is not None
else config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id "decoder_start_token_id": model_args.decoder_start_token_id
if model_args.decoder_start_token_id is not None if model_args.decoder_start_token_id is not None
else config.decoder_start_token_id, else config.decoder_start_token_id,
...@@ -578,12 +574,14 @@ def main(): ...@@ -578,12 +574,14 @@ def main():
clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device) clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
results["clap"] = clap_score results["clap"] = clap_score
word_error, transcriptions = wer(model_args.asr_model_name_or_path, word_error, transcriptions = wer(
model_args.asr_model_name_or_path,
prompts, prompts,
audios, audios,
device, device,
training_args.per_device_eval_batch_size, training_args.per_device_eval_batch_size,
sampling_rate) sampling_rate,
)
results["wer"] = word_error results["wer"] = word_error
return results, texts, prompts, audios, transcriptions return results, texts, prompts, audios, transcriptions
...@@ -673,14 +671,13 @@ def main(): ...@@ -673,14 +671,13 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
if training_args.push_to_hub: if training_args.push_to_hub:
# Retrieve of infer repo_name api = HfApi(token=training_args.hub_token)
# Create repo (repo_name from args or inferred)
repo_name = training_args.hub_model_id repo_name = training_args.hub_model_id
if repo_name is None: if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
if "wandb" not in gitignore: if "wandb" not in gitignore:
...@@ -874,7 +871,9 @@ def main(): ...@@ -874,7 +871,9 @@ def main():
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False) accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger) rotate_checkpoints(
training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
)
if cur_step == total_train_steps: if cur_step == total_train_steps:
# un-wrap student model for save # un-wrap student model for save
...@@ -882,9 +881,11 @@ def main(): ...@@ -882,9 +881,11 @@ def main():
unwrapped_model.save_pretrained(training_args.output_dir) unwrapped_model.save_pretrained(training_args.output_dir)
if training_args.push_to_hub: if training_args.push_to_hub:
repo.push_to_hub( api.upload_folder(
repo_id=repo_id,
folder_path=training_args.output_dir,
commit_message=f"Saving train state of step {cur_step}", commit_message=f"Saving train state of step {cur_step}",
blocking=False, run_as_future=True,
) )
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
......
...@@ -8,6 +8,7 @@ from typing import Dict, List ...@@ -8,6 +8,7 @@ from typing import Dict, List
import torch import torch
from wandb import Audio from wandb import Audio
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata) return field(default_factory=lambda: default, metadata=metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment