Unverified commit bdb03638, authored by Yoach Lacombe, committed by GitHub

Merge pull request #48 from ylacombe/pr/Wauplin/18

Pr/wauplin/18
parents b2b749d1 3f5fd26c
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
...
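For orientation, this config is a standard `PretrainedConfig` subclass. A minimal sketch of instantiating it with the documented defaults, assuming `ParlerTTSDecoderConfig` is exported from the `parler_tts` package (the sibling classes further down in this diff are, but this one is not confirmed here):

```python
# Hedged sketch: instantiating the decoder config with the defaults documented above.
from parler_tts import ParlerTTSDecoderConfig  # assumption: exported like ParlerTTSConfig

config = ParlerTTSDecoderConfig(
    vocab_size=2049,       # documented default
    hidden_size=1024,
    num_hidden_layers=24,
)
print(config.vocab_size)  # 2049
```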
@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs
 
         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
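For readers unfamiliar with the delay pattern used here: `apply_delay_pattern_mask` keeps a generated token only where the mask holds the sentinel -1 and otherwise forces the value stored in the mask (pad/BOS tokens for the delayed codebooks). A minimal sketch, assuming the same convention as the MusicGen implementation in `transformers` (not a verbatim copy of this repo's method):

```python
import torch

def apply_delay_pattern_mask(input_ids, delay_pattern_mask):
    # Keep model predictions where the mask is -1; elsewhere, overwrite with
    # the forced token stored in the mask (MusicGen-style codebook delay).
    seq_len = input_ids.shape[-1]
    delay_pattern_mask = delay_pattern_mask[..., :seq_len]
    return torch.where(delay_pattern_mask == -1, input_ids, delay_pattern_mask)
```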
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
             )
 
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 outputs.sequences = output_values
             return outputs
         else:
-            return output_values
\ No newline at end of file
+            return output_values
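The `output_values` returned here are decoded audio waveforms rather than token ids. A hedged usage sketch of the `generate()` path touched above, following the description/prompt pattern from the Parler-TTS README (checkpoint name is an example):

```python
# Hedged usage sketch of ParlerTTSForConditionalGeneration.generate().
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

repo = "parler-tts/parler_tts_mini_v0.1"  # example checkpoint
model = ParlerTTSForConditionalGeneration.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)

description = tokenizer("A calm female voice, studio-quality recording.", return_tensors="pt")
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt")

# generate() returns the audio tensor (`output_values` in the hunk above)
audio = model.generate(input_ids=description.input_ids, prompt_input_ids=prompt.input_ids)
print(audio.shape)
```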
@@ -3,6 +3,7 @@ from typing import Optional
 
 from transformers import Seq2SeqTrainingArguments
 
+
 @dataclass
 class ModelArguments:
     """
@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
 
 
 @dataclass
 class DataTrainingArguments:
     """
...
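These `metadata={"help": ...}` dicts are consumed by `HfArgumentParser`, which turns each dataclass field into a CLI flag and uses the help string for `--help`. A minimal sketch of that mechanism (a hypothetical reduced dataclass; the parser calls are standard `transformers` API):

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class EvalModelArguments:  # hypothetical reduced version of ModelArguments above
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={"help": "Used to compute WER during evaluation."},
    )

parser = HfArgumentParser(EvalModelArguments)
(args,) = parser.parse_args_into_dataclasses(["--asr_model_name_or_path", "openai/whisper-small"])
print(args.asr_model_name_or_path)  # openai/whisper-small
```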
@@ -11,6 +11,7 @@ from tqdm import tqdm
 
 from accelerate import Accelerator
 
+
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)
 
-    return interleaved_dataset
\ No newline at end of file
+    return interleaved_dataset
import torch
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline
@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
         clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")
 
+
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
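For context, `clap_similarity` scores how well generated audio matches its text description via cosine similarity of CLAP embeddings. A hedged, self-contained sketch of that pattern (the repo's exact batching and device handling differ; the audio here is a placeholder):

```python
import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

clap_name = "laion/larger_clap_music_and_speech"  # default from ModelArguments above
clap = AutoModel.from_pretrained(clap_name)
processor = AutoProcessor.from_pretrained(clap_name)

audio = np.random.randn(48_000).astype(np.float32)  # 1 s placeholder; CLAP expects 48 kHz
inputs = processor(text=["a calm female voice"], audios=[audio], sampling_rate=48_000,
                   return_tensors="pt", padding=True)
with torch.no_grad():
    text_emb = clap.get_text_features(input_ids=inputs["input_ids"],
                                      attention_mask=inputs["attention_mask"])
    audio_emb = clap.get_audio_features(input_features=inputs["input_features"])
score = torch.nn.functional.cosine_similarity(audio_emb, text_emb).mean()
print(score.item())
```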
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
         predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
     )
 
-    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+    return word_error, [t["text"] for t in transcriptions]
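The WER path in this hunk boils down to transcribing generated audio with an ASR pipeline and comparing against the prompts. A minimal sketch of that flow (the audio path is a placeholder):

```python
import evaluate
from transformers import pipeline

metric = evaluate.load("wer")
asr = pipeline(model="distil-whisper/distil-large-v2")  # default from ModelArguments above

transcription = asr("generated_sample.wav")["text"]  # placeholder file
word_error = metric.compute(predictions=[transcription.lower()],
                            references=["hey, how are you doing today?"])
print(word_error)  # 0.0 means a perfect word-level match
```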
@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader
 
 import datasets
 from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
-from huggingface_hub import Repository, create_repo
+from huggingface_hub import HfApi
 
 import transformers
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry
 
 from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory
 
 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
     ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )
@@ -301,9 +299,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
@@ -574,16 +570,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
 
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score
 
-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                         prompts,
-                                         audios,
-                                         device,
-                                         training_args.per_device_eval_batch_size,
-                                         sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error
 
         return results, texts, prompts, audios, transcriptions
@@ -673,14 +671,13 @@ def main():
     if accelerator.is_main_process:
         if training_args.push_to_hub:
-            # Retrieve of infer repo_name
+            api = HfApi(token=training_args.hub_token)
+
+            # Create repo (repo_name from args or inferred)
             repo_name = training_args.hub_model_id
             if repo_name is None:
                 repo_name = Path(training_args.output_dir).absolute().name
-            # Create repo and retrieve repo_id
-            repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-            # Clone repo locally
-            repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
 
+            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
 
             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:
@@ -874,7 +871,9 @@ def main():
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                        rotate_checkpoints(
+                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                        )
 
                     if cur_step == total_train_steps:
                         # un-wrap student model for save
@@ -882,9 +881,11 @@ def main():
                         unwrapped_model.save_pretrained(training_args.output_dir)
 
                         if training_args.push_to_hub:
-                            repo.push_to_hub(
+                            api.upload_folder(
+                                repo_id=repo_id,
+                                folder_path=training_args.output_dir,
                                 commit_message=f"Saving train state of step {cur_step}",
-                                blocking=False,
+                                run_as_future=True,
                             )
 
                 if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
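Taken together with the earlier hunk, the push-to-Hub flow now uses the HTTP-based `HfApi` instead of a local git clone via the deprecated `Repository`: create the repo once, then upload the output folder, with `run_as_future=True` making the upload non-blocking (the `HfApi` analogue of the old `blocking=False`). A condensed sketch of the new flow, with placeholder names:

```python
from huggingface_hub import HfApi

api = HfApi(token="hf_...")  # placeholder token
repo_id = api.create_repo("my-parler-tts-run", exist_ok=True).repo_id

# Non-blocking upload: returns a Future immediately instead of waiting
future = api.upload_folder(
    repo_id=repo_id,
    folder_path="./output_dir",  # placeholder path
    commit_message="Saving train state of step 1000",
    run_as_future=True,
)
future.result()  # optionally block until the upload finishes
```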
@@ -1014,4 +1015,4 @@ def main():
 
 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
\ No newline at end of file
+    main()
@@ -8,6 +8,7 @@ from typing import Dict, List
 
 import torch
 from wandb import Audio
 
+
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
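The `list_field` helper exists because dataclasses reject mutable defaults; wrapping the list in a `default_factory` sidesteps that. For example (field name is hypothetical):

```python
from dataclasses import dataclass, field

def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)

@dataclass
class Example:
    # a bare `eval_steps: list = [100, 200]` would raise ValueError at class creation
    eval_steps: list = list_field(default=[100, 200], metadata={"help": "hypothetical list argument"})

print(Example().eval_steps)  # [100, 200]
```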
@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
\ No newline at end of file
+    )
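`log_pred` ends by logging audio samples to Weights & Biases; the `wandb.Audio` wrapper imported above takes a raw array plus a sample rate. A self-contained sketch of that logging pattern (offline mode so no login is needed; all values are placeholders):

```python
import numpy as np
import wandb
from wandb import Audio

wandb.init(project="parler-tts-logging-demo", mode="offline")  # placeholder project
samples = np.random.randn(16_000).astype(np.float32)  # 1 s of placeholder audio at 16 kHz
wandb.log(
    {"eval/audio": [Audio(samples, sample_rate=16_000, caption="prediction at step 100")]},
    step=100,
)
wandb.finish()
```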