Commit aa4cbf27 authored by Yoach Lacombe

make style

parent 9271958b
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
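For readers unfamiliar with the pattern: applying a delay pattern mask in MusicGen-style decoders keeps the model's prediction wherever the mask is -1 and overwrites every other position with the token the mask prescribes (pad/BOS). A minimal, illustrative sketch of that convention, not necessarily the exact Parler-TTS implementation, with made-up tensors:

```python
import torch


def apply_delay_pattern_mask_sketch(input_ids, delay_pattern_mask):
    # Keep predictions where the mask is -1; force the mask's token elsewhere.
    # Illustrative only, following the MusicGen convention.
    seq_len = input_ids.shape[-1]
    mask = delay_pattern_mask[..., :seq_len]
    return torch.where(mask == -1, input_ids, mask)


# Toy example: the second position is pinned to a hypothetical pad id 2048.
ids = torch.tensor([[5, 7, 9]])
mask = torch.tensor([[-1, 2048, -1]])
print(apply_delay_pattern_mask_sketch(ids, mask))  # tensor([[   5, 2048,    9]])
```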
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-inputs_tensor, model_kwargs, model_input_name, generation_config,
+inputs_tensor,
+model_kwargs,
+model_input_name,
+generation_config,
)
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
outputs.sequences = output_values
return outputs
else:
-return output_values
\ No newline at end of file
+return output_values
@@ -3,6 +3,7 @@ from typing import Optional
from transformers import Seq2SeqTrainingArguments
@dataclass
class ModelArguments:
"""
@@ -67,15 +68,18 @@ class ModelArguments:
)
asr_model_name_or_path: str = field(
default="distil-whisper/distil-large-v2",
metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
metadata={
"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
},
)
clap_model_name_or_path: str = field(
default="laion/larger_clap_music_and_speech",
metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
metadata={
"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
},
)
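For context, field declarations like these are consumed by transformers' HfArgumentParser, which turns each dataclass field into a command-line flag and uses the metadata "help" string as its help text. A small self-contained sketch under that assumption; ExampleArguments and the argv values below are hypothetical, not part of this repo:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


# Hypothetical dataclass mirroring the pattern used in ModelArguments above.
@dataclass
class ExampleArguments:
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={"help": "Model used to compute WER during evaluation."},
    )


parser = HfArgumentParser(ExampleArguments)
# Parse an explicit argv list so the sketch runs anywhere;
# a training script would omit `args` and read sys.argv instead.
(example_args,) = parser.parse_args_into_dataclasses(
    args=["--asr_model_name_or_path", "openai/whisper-small"]
)
print(example_args.asr_model_name_or_path)  # openai/whisper-small
```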
@dataclass
class DataTrainingArguments:
"""
@@ -11,6 +11,7 @@ from tqdm import tqdm
from accelerate import Accelerator
@dataclass
class DataCollatorEncodecWithPadding:
"""
@@ -301,4 +302,4 @@ def load_multiple_datasets(
with accelerator.main_process_first():
interleaved_dataset = concatenate_datasets(all_datasets)
-return interleaved_dataset
\ No newline at end of file
+return interleaved_dataset
import torch
import torch
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline
@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
clap_inputs.to("cpu")
return cosine_sim.mean().to("cpu")
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
metric = evaluate.load("wer")
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
)
-return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+return word_error, [t["text"] for t in transcriptions]
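The word error rate itself comes from the evaluate library: compute() takes the lower-cased ASR transcriptions as predictions and the prompts as references. A minimal usage sketch with made-up strings; the scaling and return shape of the repo's wer() helper may differ:

```python
import evaluate

wer_metric = evaluate.load("wer")

# Made-up transcriptions and prompts, only to show the call shape used above.
transcriptions = [{"text": "Hello word"}, {"text": "the cat sat on the mat"}]
prompts = ["hello world", "the cat sat on the mat"]

word_error = wer_metric.compute(
    predictions=[t["text"].lower() for t in transcriptions],
    references=[p.lower() for p in prompts],
)
print(word_error)  # 0.125 -> one substitution over eight reference words
```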
@@ -21,7 +21,6 @@ import os
import re
import sys
import time
from dataclasses import dataclass, field
from datetime import timedelta
from tqdm import tqdm
@@ -38,11 +37,7 @@ from huggingface_hub import HfApi
from multiprocess import set_start_method
from torch.utils.data import DataLoader
from tqdm import tqdm
-from transformers import (
-AutoFeatureExtractor,
-AutoTokenizer,
-HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
-from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler
+from transformers.trainer_pt_utils import LengthGroupedSampler
@@ -306,9 +301,7 @@ def main():
# update pad token id and decoder_start_token_id
config.update(
{
"pad_token_id": model_args.pad_token_id
if model_args.pad_token_id is not None
else config.pad_token_id,
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id
if model_args.decoder_start_token_id is not None
else config.decoder_start_token_id,
@@ -579,16 +572,18 @@ def main():
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios]
clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
results["clap"] = clap_score
-word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-prompts,
-audios,
-device,
-training_args.per_device_eval_batch_size,
-sampling_rate)
+word_error, transcriptions = wer(
+model_args.asr_model_name_or_path,
+prompts,
+audios,
+device,
+training_args.per_device_eval_batch_size,
+sampling_rate,
+)
results["wer"] = word_error
return results, texts, prompts, audios, transcriptions
@@ -878,7 +873,9 @@ def main():
accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
-rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+rotate_checkpoints(
+training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+)
if cur_step == total_train_steps:
# un-wrap student model for save
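The rotate_checkpoints call reformatted in this hunk is the usual "keep only the N newest checkpoints" helper. A rough, illustrative sketch of what such a helper typically does, assuming checkpoint-<step> directory names; this is not the project's actual implementation:

```python
import os
import re
import shutil


def rotate_checkpoints_sketch(save_total_limit, output_dir):
    """Keep only the `save_total_limit` newest checkpoint-<step> directories.

    Illustrative only; the real helper may differ (logging, naming, symlinks).
    """
    if save_total_limit is None or save_total_limit <= 0:
        return
    checkpoints = []
    for name in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        path = os.path.join(output_dir, name)
        if match is not None and os.path.isdir(path):
            checkpoints.append((int(match.group(1)), path))
    checkpoints.sort()  # oldest (lowest step) first
    for _, path in checkpoints[: max(len(checkpoints) - save_total_limit, 0)]:
        shutil.rmtree(path)
```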
@@ -1020,4 +1017,4 @@ def main():
if __name__ == "__main__":
set_start_method("spawn")
-main()
\ No newline at end of file
+main()
@@ -8,6 +8,7 @@ from typing import Dict, List
import torch
from wandb import Audio
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@@ -121,4 +122,4 @@ def log_pred(
]
},
step=step,
-)
\ No newline at end of file
+)
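The structure logged in this hunk wraps each generated waveform in a wandb.Audio object so the samples are playable in the Weights & Biases UI. A tiny self-contained sketch with offline mode and synthetic audio, so it runs without an account; the project and key names here are made up:

```python
import numpy as np
import wandb

# Offline mode so the sketch runs anywhere; drop it for real logging.
wandb.init(project="parler-tts-sketch", mode="offline")

waveform = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz
wandb.log(
    {"eval/audio": [wandb.Audio(waveform, sample_rate=16_000, caption="example prompt")]},
    step=0,
)
wandb.finish()
```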