Commit aa4cbf27 authored by Yoach Lacombe

make style

parent 9271958b
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
             )
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
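This reflow is black's "magic trailing comma" at work: the original call already ended its argument list with a trailing comma, which tells the formatter to keep one argument per line. A minimal sketch of the behavior (hypothetical helper, for illustration only):

def prepare_kwargs(*args):  # hypothetical stand-in for the real method
    return args

# A trailing comma after the last argument makes black keep the exploded form:
kwargs = prepare_kwargs(
    "inputs_tensor",
    "model_kwargs",
    "model_input_name",
    "generation_config",
)

# Without that trailing comma, black collapses the call back onto one line
# whenever it fits under the configured line-length limit.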
@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import Seq2SeqTrainingArguments
 
+
 @dataclass
 class ModelArguments:
     """
@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
 
 
 @dataclass
 class DataTrainingArguments:
     """
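For context, argument dataclasses like these are normally consumed through transformers' HfArgumentParser (imported by the training script below), which turns each field and its metadata["help"] string into a CLI flag. A minimal sketch of that wiring, assuming ModelArguments and DataTrainingArguments are the classes above:

from transformers import HfArgumentParser, Seq2SeqTrainingArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Defaults come straight from the dataclass fields:
print(model_args.asr_model_name_or_path)   # "distil-whisper/distil-large-v2"
print(model_args.clap_model_name_or_path)  # "laion/larger_clap_music_and_speech"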
@@ -11,6 +11,7 @@ from tqdm import tqdm
 from accelerate import Accelerator
 
+
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
     clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")
 
+
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
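The wer helper transcribes the generated audio with a reference ASR model and scores the transcriptions against the prompts. A simplified sketch of that loop (the batching and text normalization details here are my assumptions):

import evaluate
from transformers import pipeline

def compute_wer(asr_model_name_or_path, prompts, audios, device, batch_size, sampling_rate):
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
    # Transcribe each generated waveform with the reference ASR model.
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=batch_size,
    )
    # Word error rate between the transcriptions and the original text prompts.
    metric = evaluate.load("wer")
    word_error = 100 * metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[p.lower() for p in prompts],
    )
    return word_error, [t["text"] for t in transcriptions]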
@@ -21,7 +21,6 @@ import os
 import re
 import sys
 import time
-from dataclasses import dataclass, field
 from datetime import timedelta
 
 from tqdm import tqdm
@@ -38,11 +37,7 @@ from huggingface_hub import HfApi
 from multiprocess import set_start_method
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.trainer_pt_utils import LengthGroupedSampler
@@ -306,9 +301,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
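PretrainedConfig.update assigns each dict key as an attribute on the config, so falling back to the existing value makes the call a no-op when a CLI argument was left unset. The same guard pattern in isolation (an arbitrary pretrained config standing in for the model's):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2")  # illustrative; any config works
pad_token_id = None  # stand-in for model_args.pad_token_id

# Override only when the user actually supplied a value.
config.update({"pad_token_id": pad_token_id if pad_token_id is not None else config.pad_token_id})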
@@ -583,12 +576,14 @@ def main():
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score
 
-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
             prompts,
             audios,
             device,
             training_args.per_device_eval_batch_size,
-            sampling_rate)
+            sampling_rate,
+        )
 
         results["wer"] = word_error
         return results, texts, prompts, audios, transcriptions
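clap_similarity, called just above, embeds the text descriptions and the generated audio with a CLAP model and averages their cosine similarity. A rough reconstruction of such a function (my sketch; the processor arguments, in particular the sampling rate, are assumptions):

import torch
from transformers import AutoProcessor, ClapModel

def clap_similarity(clap_model_name_or_path, texts, audios, device):
    clap = ClapModel.from_pretrained(clap_model_name_or_path).to(device)
    processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    clap_inputs = processor(
        text=texts, audios=audios, padding=True, return_tensors="pt", sampling_rate=48_000
    ).to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask")
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])
    # Mean text-audio cosine similarity over the evaluation batch.
    cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1)
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")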
@@ -878,7 +873,9 @@ def main():
                 accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                 accelerator.wait_for_everyone()
                 if accelerator.is_main_process:
-                    rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                    rotate_checkpoints(
+                        training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                    )
 
                 if cur_step == total_train_steps:
                     # un-wrap student model for save
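rotate_checkpoints enforces save_total_limit by pruning the oldest intermediate checkpoints. Its implementation isn't shown in this diff; a hypothetical version of such a helper (modeled on the checkpoint rotation in transformers' Trainer, with the "checkpoint" prefix as an assumption) could look like:

import os
import re
import shutil

def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None):
    """Keep only the `save_total_limit` most recent checkpoint-* directories."""
    if save_total_limit is None or save_total_limit <= 0:
        return
    # Sort checkpoint directories by their step number, oldest first.
    pattern = re.compile(rf"{checkpoint_prefix}-(\d+)")
    checkpoints = []
    for name in os.listdir(output_dir):
        match = pattern.search(name)
        path = os.path.join(output_dir, name)
        if match is not None and os.path.isdir(path):
            checkpoints.append((int(match.group(1)), path))
    checkpoints.sort()
    # Delete everything beyond the limit.
    for _, path in checkpoints[: max(0, len(checkpoints) - save_total_limit)]:
        if logger is not None:
            logger.info(f"Deleting older checkpoint [{path}] due to save_total_limit")
        shutil.rmtree(path, ignore_errors=True)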
@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 from wandb import Audio
 
+
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
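list_field exists because a dataclass field cannot take a mutable list as a plain default; wrapping it in default_factory defers the lookup to call time and satisfies the dataclass check. For example:

from dataclasses import dataclass, field

def list_field(default=None, metadata=None):
    # `default=[...]` directly would raise "mutable default ... is not allowed";
    # note the lambda still hands every instance the same underlying list object.
    return field(default_factory=lambda: default, metadata=metadata)

@dataclass
class Example:
    layers: list = list_field(default=[1, 2, 3])

print(Example().layers)  # [1, 2, 3]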