Unverified Commit bdb03638 authored by Yoach Lacombe, committed by GitHub

Merge pull request #48 from ylacombe/pr/Wauplin/18

Pr/wauplin/18
parents b2b749d1 3f5fd26c
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
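
For orientation, a minimal sketch of instantiating this config with the defaults documented above; the top-level import is an assumption (the class is defined in parler_tts's configuration module):

    # Sketch only; import path assumed, defaults taken from the docstring above.
    from parler_tts import ParlerTTSDecoderConfig

    decoder_config = ParlerTTSDecoderConfig(
        vocab_size=2049,        # size of the audio token vocabulary
        hidden_size=1024,       # dimensionality of the layers
        num_hidden_layers=24,   # decoder depth
    )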
@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
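
The mask applied here follows the MusicGen convention: positions forced by the codebook delay (BOS/pad) are overwritten with their pre-computed values, while -1 entries in the mask keep the freely generated token. A minimal sketch of that behaviour, assuming the MusicGen-style semantics carry over unchanged:

    import torch

    def apply_delay_pattern_mask_sketch(input_ids: torch.LongTensor, delay_pattern_mask: torch.LongTensor) -> torch.LongTensor:
        # Trim the pre-computed mask to the generated length, then let the mask
        # win everywhere it is not -1 (i.e. at forced BOS/pad positions).
        mask = delay_pattern_mask[..., : input_ids.shape[-1]]
        return torch.where(mask == -1, input_ids, mask)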
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
             )

         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             outputs.sequences = output_values
             return outputs
         else:
-            return output_values
\ No newline at end of file
+            return output_values
@@ -3,6 +3,7 @@ from typing import Optional

 from transformers import Seq2SeqTrainingArguments

+
 @dataclass
 class ModelArguments:
     """
@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )


 @dataclass
 class DataTrainingArguments:
     """
@@ -11,6 +11,7 @@ from tqdm import tqdm

 from accelerate import Accelerator

+
 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)

-    return interleaved_dataset
\ No newline at end of file
+    return interleaved_dataset
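
main_process_first() lets the main process build and cache the concatenated dataset before the other ranks read it back from cache. A rough sketch of the pattern with placeholder dataset names:

    from accelerate import Accelerator
    from datasets import load_dataset, concatenate_datasets

    accelerator = Accelerator()
    with accelerator.main_process_first():
        # Dataset names are placeholders; any datasets with identical columns work.
        ds_a = load_dataset("some_org/dataset_a", split="train")
        ds_b = load_dataset("some_org/dataset_b", split="train")
        combined = concatenate_datasets([ds_a, ds_b])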
 import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline
@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
     clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")

+
 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
         predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
     )

-    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+    return word_error, [t["text"] for t in transcriptions]
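
The returned WER compares lower-cased ASR transcriptions of the generated audio against the prompts. A minimal sketch with the same ingredients (evaluate's wer metric plus a transformers ASR pipeline; the waveform is a placeholder and 16 kHz matches Whisper's expected input rate):

    import numpy as np
    import evaluate
    from transformers import pipeline

    metric = evaluate.load("wer")
    asr = pipeline(model="distil-whisper/distil-large-v2")

    prompts = ["hello world"]                      # placeholder reference text
    audios = [np.zeros(16_000, dtype=np.float32)]  # placeholder 1-second waveform
    transcriptions = asr([{"raw": a, "sampling_rate": 16_000} for a in audios])
    word_error = metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[p.lower() for p in prompts],
    )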
@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader

 import datasets
 from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
-from huggingface_hub import Repository, create_repo
+from huggingface_hub import HfApi

 import transformers
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry

 from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory

 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
     ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )
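
build_delay_pattern_mask is the MusicGen-style helper that offsets codebook k by k steps so earlier codebooks decode first, returning the shifted ids plus the mask later consumed by apply_delay_pattern_mask. A hedged sketch; the exact signature is assumed from the MusicGen convention:

    import torch
    from parler_tts import build_delay_pattern_mask

    num_codebooks, bos, pad, max_len = 9, 1025, 1024, 32  # illustrative values only
    input_ids = torch.full((num_codebooks, 1), bos, dtype=torch.long)
    # Returns the delayed ids plus the mask later re-applied during decoding.
    delayed_ids, delay_mask = build_delay_pattern_mask(
        input_ids, bos_token_id=bos, pad_token_id=pad, max_length=max_len, num_codebooks=num_codebooks
    )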
@@ -301,9 +299,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
@@ -574,16 +570,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]

         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score

-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                         prompts,
-                                         audios,
-                                         device,
-                                         training_args.per_device_eval_batch_size,
-                                         sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error

         return results, texts, prompts, audios, transcriptions
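
clap_similarity scores how well the generated audio matches its text description via CLAP embeddings. A rough sketch of that computation using the transformers CLAP API (placeholder text/audio; 48 kHz is the rate CLAP's feature extractor expects):

    import numpy as np
    import torch
    from transformers import AutoModel, AutoProcessor

    name = "laion/larger_clap_music_and_speech"
    clap = AutoModel.from_pretrained(name)
    processor = AutoProcessor.from_pretrained(name)

    texts = ["a calm female voice"]                # placeholder description
    audios = [np.zeros(48_000, dtype=np.float32)]  # placeholder waveform
    inputs = processor(text=texts, audios=audios, sampling_rate=48_000, return_tensors="pt", padding=True)
    with torch.no_grad():
        text_emb = clap.get_text_features(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        audio_emb = clap.get_audio_features(inputs["input_features"])
    cosine_sim = torch.nn.functional.cosine_similarity(audio_emb, text_emb, dim=-1).mean()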
@@ -673,14 +671,13 @@ def main():
     if accelerator.is_main_process:
         if training_args.push_to_hub:
-            # Retrieve of infer repo_name
+            api = HfApi(token=training_args.hub_token)
+
+            # Create repo (repo_name from args or inferred)
             repo_name = training_args.hub_model_id
             if repo_name is None:
                 repo_name = Path(training_args.output_dir).absolute().name
-            # Create repo and retrieve repo_id
-            repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-            # Clone repo locally
-            repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
+            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:
@@ -874,7 +871,9 @@ def main():
                 accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                 accelerator.wait_for_everyone()
                 if accelerator.is_main_process:
-                    rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                    rotate_checkpoints(
+                        training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                    )

                     if cur_step == total_train_steps:
                         # un-wrap student model for save
@@ -882,9 +881,11 @@ def main():
                         unwrapped_model.save_pretrained(training_args.output_dir)

                         if training_args.push_to_hub:
-                            repo.push_to_hub(
+                            api.upload_folder(
+                                repo_id=repo_id,
+                                folder_path=training_args.output_dir,
                                 commit_message=f"Saving train state of step {cur_step}",
-                                blocking=False,
+                                run_as_future=True,
                             )

             if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
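
run_as_future=True replaces the old blocking=False push: upload_folder queues the commit on a background thread and returns a Future, so training continues while the checkpoint uploads. Sketch, reusing the api/repo_id from the sketch above:

    future = api.upload_folder(
        repo_id=repo_id,
        folder_path="output",  # placeholder path
        commit_message="Saving train state",
        run_as_future=True,
    )
    # future.result() would block until the commit lands; here training just keeps going.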
@@ -1014,4 +1015,4 @@ def main():

 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
\ No newline at end of file
+    main()
@@ -8,6 +8,7 @@ from typing import Dict, List

 import torch
 from wandb import Audio

+
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
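
list_field exists because dataclasses reject bare mutable defaults; routing the list through default_factory satisfies that restriction (note the factory returns the same underlying list object, which only matters if it is mutated). Minimal illustration with a hypothetical Args class:

    from dataclasses import dataclass, field
    from typing import List

    def list_field(default=None, metadata=None):
        return field(default_factory=lambda: default, metadata=metadata)

    @dataclass
    class Args:  # hypothetical example class
        tags: List[str] = list_field(default=["tts"], metadata={"help": "free-form tags"})

    print(Args().tags)  # ['tts']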
@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
\ No newline at end of file
+    )