"vscode:/vscode.git/clone" did not exist on "7cf5d8f77857e1cc64e585f46e2f656ea4eef8ec"
Commit aa4cbf27 authored by Yoach Lacombe

make style
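The hunks below are a mechanical restyle pass: long call sites and `metadata={...}` dicts are exploded one argument per line, multi-line imports that fit the line limit are collapsed, second blank lines are added before top-level definitions, and missing trailing newlines are added. As a minimal before/after illustration of the convention applied throughout (a hypothetical snippet, not taken from this diff):

```python
# before: a call that overflows the line-length limit
word_error, transcriptions = wer(model_name, prompts, audios, device, batch_size, sampling_rate)

# after: black-style layout with a magic trailing comma, one argument per line
word_error, transcriptions = wer(
    model_name,
    prompts,
    audios,
    device,
    batch_size,
    sampling_rate,
)
```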

parent 9271958b
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
             )

         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             outputs.sequences = output_values
             return outputs
         else:
-            return output_values
\ No newline at end of file
+            return output_values
@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import Seq2SeqTrainingArguments

+
 @dataclass
 class ModelArguments:
     """
@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )


 @dataclass
 class DataTrainingArguments:
     """
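For context, these `field(..., metadata={"help": ...})` declarations are what `HfArgumentParser` (imported later in the training script) turns into command-line flags. A minimal sketch of that round trip, using a reduced, hypothetical version of the dataclass above and an example override value:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class EvalModelArguments:  # reduced stand-in for the ModelArguments above
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={"help": "Used to compute WER during evaluation."},
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
        metadata={"help": "Used to compute audio similarity during evaluation."},
    )


parser = HfArgumentParser(EvalModelArguments)
# field names become flags; the metadata "help" strings become --help text
(args,) = parser.parse_args_into_dataclasses(["--asr_model_name_or_path", "openai/whisper-small"])
print(args.asr_model_name_or_path)  # openai/whisper-small
```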
@@ -11,6 +11,7 @@ from tqdm import tqdm
 from accelerate import Accelerator

+
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)
-    return interleaved_dataset
\ No newline at end of file
+    return interleaved_dataset
 import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
     clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")

+
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
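Only the tail of `clap_similarity` is visible in the hunk above (`clap_inputs.to("cpu")` and `cosine_sim.mean()`). A minimal sketch of what the full helper plausibly looks like, built from this file's `AutoModel`/`AutoProcessor` imports; the 48 kHz rate and the `input_features` key are assumptions based on the CLAP architecture, not taken from this diff:

```python
import torch
from transformers import AutoModel, AutoProcessor


def clap_similarity_sketch(clap_model_name_or_path, texts, audios, device):
    clap = AutoModel.from_pretrained(clap_model_name_or_path).to(device)
    processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    # CLAP checkpoints expect 48 kHz audio (assumption for this sketch)
    clap_inputs = processor(
        text=texts, audios=audios, sampling_rate=48_000, padding=True, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask")
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])
        # similarity between each description and its generated audio
        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1)
    # move everything back to CPU, as in the visible tail of the real function
    clap.to("cpu")
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")
```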
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
         predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
     )
-    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+    return word_error, [t["text"] for t in transcriptions]
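The two hunks above show only the head and tail of `wer`; the elided middle presumably runs the ASR pipeline over the generated waveforms in batches. A minimal sketch of the whole function under that assumption; the dict-input format is the standard `automatic-speech-recognition` pipeline contract, and the batched call in the middle is an assumption, not part of this diff:

```python
import evaluate
from transformers import pipeline


def wer_sketch(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
    # assumed middle: transcribe raw waveforms in batches
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=per_device_eval_batch_size,
    )
    # case-normalise both sides before scoring, as in the visible tail
    word_error = metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[t.lower() for t in prompts],
    )
    return word_error, [t["text"] for t in transcriptions]
```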
@@ -21,7 +21,6 @@ import os
 import re
 import sys
 import time
-from dataclasses import dataclass, field
 from datetime import timedelta
 from tqdm import tqdm
@@ -38,11 +37,7 @@ from huggingface_hub import HfApi
 from multiprocess import set_start_method
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.trainer_pt_utils import LengthGroupedSampler
@@ -306,9 +301,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
@@ -579,16 +572,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]

         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score

-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                         prompts,
-                                         audios,
-                                         device,
-                                         training_args.per_device_eval_batch_size,
-                                         sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error

         return results, texts, prompts, audios, transcriptions
@@ -878,7 +873,9 @@ def main():
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                        rotate_checkpoints(
+                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                        )

                     if cur_step == total_train_steps:
                         # un-wrap student model for save
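The body of `rotate_checkpoints` is not shown in this diff. As a rough, hypothetical re-implementation of what such a helper conventionally does (keep at most `save_total_limit` checkpoint directories, deleting the oldest first; every detail below is an assumption patterned on the analogous transformers Trainer helper):

```python
import os
import re
import shutil


def rotate_checkpoints_sketch(save_total_limit, output_dir, logger, checkpoint_prefix="checkpoint"):
    """Keep at most `save_total_limit` checkpoints; delete the oldest ones (hypothetical sketch)."""
    if save_total_limit is None or save_total_limit <= 0:
        return
    # collect checkpoint-<step> directories and sort them by step number
    pattern = re.compile(rf"{checkpoint_prefix}-(\d+)")
    checkpoints = []
    for name in os.listdir(output_dir):
        match = pattern.search(name)
        path = os.path.join(output_dir, name)
        if match is not None and os.path.isdir(path):
            checkpoints.append((int(match.group(1)), path))
    checkpoints.sort()
    # remove the oldest checkpoints beyond the limit
    for _, path in checkpoints[: max(0, len(checkpoints) - save_total_limit)]:
        logger.info(f"Deleting older checkpoint [{path}] due to save_total_limit")
        shutil.rmtree(path, ignore_errors=True)
```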
@@ -1020,4 +1017,4 @@ def main():
 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
\ No newline at end of file
+    main()
@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 from wandb import Audio

+
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
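The `list_field` helper above exists because dataclasses reject mutable defaults like lists at class-definition time; wrapping the default in a `default_factory` lambda sidesteps that check. A short usage example with a hypothetical dataclass (note that the closure returns the same list object on every call, so instances share the default until a field is reassigned):

```python
from dataclasses import dataclass, field


def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)


@dataclass
class ExampleArguments:  # hypothetical, for illustration only
    # `eval_metrics: list = ["wer", "clap"]` would raise ValueError; this form is accepted
    eval_metrics: list = list_field(default=["wer", "clap"], metadata={"help": "metrics to run"})


args = ExampleArguments()
print(args.eval_metrics)  # ['wer', 'clap']
```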
@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
\ No newline at end of file
+    )
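Only the closing lines of `log_pred` are visible, but the `from wandb import Audio` import and the `step=step` keyword indicate it logs generated audio to Weights & Biases. A minimal sketch of that pattern; the function name, its parameters, and the `"predictions"` key are hypothetical stand-ins for whatever `log_pred` actually receives:

```python
import wandb
from wandb import Audio


def log_pred_sketch(audios, transcriptions, sampling_rate, step):
    # log each generated waveform with its transcription as the caption
    wandb.log(
        {
            "predictions": [
                Audio(a, caption=t, sample_rate=sampling_rate)
                for a, t in zip(audios, transcriptions)
            ]
        },
        step=step,
    )
```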