Commit 0f6d59d4 authored by Yoach Lacombe

latest changes

parent 11fcc066
......@@ -24,11 +24,11 @@
"description_column_name": "text_description",
"prompt_column_name": "text",
"max_train_samples": 1000,
"max_eval_samples": 200,
"max_train_samples": 20,
"max_eval_samples": 10,
"max_duration_in_seconds": 20,
"max_duration_in_seconds": 30,
"min_duration_in_seconds": 1.0,
"add_audio_samples_to_wandb": true,
......@@ -36,30 +36,36 @@
"preprocessing_num_workers": 1,
"pad_token_id": 2049,
"pad_token_id": 2050,
"decoder_start_token_id": 2048,
"do_train": true,
"num_train_epochs": 1,
"num_train_epochs": 120,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"per_device_train_batch_size": 8,
"learning_rate": 1e-6,
"gradient_checkpointing": false,
"per_device_train_batch_size": 2,
"learning_rate": 1e-3,
"adam_beta1": 0.9,
"adam_beta2": 0.95,
"adam_beta2": 0.999,
"weight_decay": 0.1,
"logging_steps": 25,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.1,
"logging_steps": 1,
"freeze_text_encoder": true,
"do_eval": true,
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "epoch",
"evaluation_strategy": "steps",
"eval_steps": 600,
"per_device_eval_batch_size": 8,
"generation_max_length": 400,
"fp16": true,
"fp16": false,
"seed": 456,
"dataloader_num_workers":8
......
{
"model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
"feature_extractor_name":"facebook/encodec_24khz",
"description_tokenizer_name":"t5-base",
"prompt_tokenizer_name":"t5-base",
"push_to_hub": false,
"hub_model_id": "stable-speech-mini",
"report_to": ["wandb"],
"overwrite_output_dir": true,
"output_dir": "/home/yoach/dataspeech/artefacts/training/",
"train_dataset_name": "blabble-io/libritts_r",
"train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
"train_dataset_config_name": "clean",
"train_split_name": "train.clean.360",
"eval_dataset_name": "blabble-io/libritts_r",
"eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
"eval_dataset_config_name": "clean",
"eval_split_name": "train.clean.360",
"target_audio_column_name": "audio",
"description_column_name": "text_description",
"prompt_column_name": "text",
"max_train_samples": 12,
"max_eval_samples": 12,
"max_duration_in_seconds": 30,
"min_duration_in_seconds": 1.0,
"add_audio_samples_to_wandb": true,
"id_column_name": "id",
"preprocessing_num_workers": 1,
"pad_token_id": 2050,
"decoder_start_token_id": 2048,
"do_train": true,
"num_train_epochs": 20,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": false,
"per_device_train_batch_size": 3,
"learning_rate": 1e-3,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"weight_decay": 0.1,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.1,
"freeze_text_encoder": true,
"do_eval": true,
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "steps",
"eval_steps": 10,
"per_device_eval_batch_size": 3,
"generation_max_length": 400,
"do_sample": true,
"logging_steps": 15,
"dtype": "float32",
"seed": 456,
"dataloader_num_workers":8
}
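Both JSON files above are meant to be passed as the single command-line argument of the training script, which hands them to `HfArgumentParser` (see the `sys.argv[1].endswith(".json")` branch further down). A minimal sketch of that parsing path using a toy dataclass; the real script uses its own `ModelArguments`, `DataTrainingArguments` and `StableSpeechTrainingArguments`, and `config.json` here is a hypothetical path:

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ExampleArguments:
    learning_rate: float = field(default=1e-3)
    num_train_epochs: int = field(default=1)

# Every JSON key is mapped onto the matching dataclass field; unknown keys raise an error.
parser = HfArgumentParser(ExampleArguments)
(example_args,) = parser.parse_json_file(json_file="config.json")  # hypothetical file containing the two keys above
print(example_args.learning_rate, example_args.num_train_epochs)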
......@@ -4,16 +4,16 @@ from transformers import AutoConfig
decoder_config = StableSpeechDecoderConfig(
max_position_embeddings=2048,
num_hidden_layers=2,
ffn_dim=256,
num_attention_heads=4,
num_hidden_layers=4,
ffn_dim=512,
num_attention_heads=8,
layerdrop=0.0,
use_cache=True,
activation_function="gelu",
hidden_size=256,
dropout=0.1,
attention_dropout=0.1,
activation_dropout=0.1,
hidden_size=512,
dropout=0.0,
attention_dropout=0.0,
activation_dropout=0.0,
)
# TODO: how do we make generation stop?
......@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = 2048
model.generation_config.pad_token_id = 2049
model.generation_config.pad_token_id = 2050
model.generation_config.eos_token_id = 2049
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True
model.generation_config.guidance_scale = 3.0
model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/")
\ No newline at end of file
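The id changes in this commit all follow one layout: EnCodec codebook entries 0-2047, then bos (also used as decoder_start) at 2048, eos at 2049, and the padding id at 2050, which sits one past the new `vocab_size=2050` and is covered by the `embed_dim = config.vocab_size + 1` row later in the diff. A small sanity-check sketch of how these values appear to fit together:

ENCODEC_VOCAB = 2048      # codebook entries 0 .. 2047
BOS_TOKEN_ID = 2048       # also the decoder_start_token_id
EOS_TOKEN_ID = 2049
PAD_TOKEN_ID = 2050

vocab_size = ENCODEC_VOCAB + 2   # bos + eos -> 2050, matching the config change
embed_dim = vocab_size + 1       # extra embedding row so PAD_TOKEN_ID can be looked up
assert PAD_TOKEN_ID == embed_dim - 1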
......@@ -22,10 +22,14 @@ import logging
import os
import re
import sys
import shutil
import warnings
import math
import time
import evaluate
from tqdm import tqdm
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
......@@ -36,17 +40,21 @@ from torch.utils.data import DataLoader
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
from huggingface_hub import Repository, create_repo
import transformers
from transformers import (
AutoFeatureExtractor,
AutoModel,
AutoModelWithLMHead,
AutoProcessor,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.trainer_utils import is_main_process
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
......@@ -57,6 +65,8 @@ from accelerate.utils import set_seed
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
if is_wandb_available():
from wandb import Audio
......@@ -72,6 +82,108 @@ logger = logging.getLogger(__name__)
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path
for path in content
if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
for path in glob_checkpoints:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
"""Helper function to delete old checkpoints."""
if save_total_limit is None or save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
if len(checkpoints_sorted) <= save_total_limit:
return
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
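The three helpers above assume checkpoints are saved as `checkpoint-{step}-epoch-{epoch}` directories, matching the save logic later in the script. A minimal usage sketch with a temporary, hypothetical output directory:

import os, tempfile

output_dir = tempfile.mkdtemp()
for step, epoch in [(600, 1), (1200, 2), (1800, 3)]:
    os.makedirs(os.path.join(output_dir, f"checkpoint-{step}-epoch-{epoch}"))

print(get_last_checkpoint(output_dir))            # .../checkpoint-1800-epoch-3 (highest step)
rotate_checkpoints(save_total_limit=2, output_dir=output_dir)
print(sorted_checkpoints(output_dir=output_dir))  # only the two most recent checkpoints remain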
def log_metric(
accelerator,
metrics: Dict,
train_time: float,
step: int,
epoch: int,
learning_rate: float = None,
prefix: str = "train",
):
"""Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
log_metrics = {}
for k, v in metrics.items():
log_metrics[f"{prefix}/{k}"] = v
log_metrics[f"{prefix}/time"] = train_time
log_metrics[f"{prefix}/epoch"] = epoch
if learning_rate is not None:
log_metrics[f"{prefix}/learning_rate"] = learning_rate
accelerator.log(log_metrics, step=step)
def log_pred(
accelerator,
pred_descriptions: List[str],
pred_prompts: List[str],
transcriptions: List[str],
audios: List[torch.Tensor],
sampling_rate: int,
step: int,
prefix: str = "eval",
num_lines: int = 200000,
):
"""Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
if accelerator.is_main_process:
wandb_tracker = accelerator.get_tracker("wandb")
# pretty name for current step: step 50000 -> step 50k
cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
prefix_pretty = prefix.replace("/", "-")
# convert str data to a wandb compatible format
str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
data=str_data[:num_lines],
step=step,
commit=False,
)
# wandb can only log 100 audio samples per step
wandb_tracker.log({
"Speech samples": [
Audio(
audio,
caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
sample_rate=sampling_rate,
)
for (i, audio) in enumerate(audios[:min(len(audios), 100)])
]},
step=step)
#### ARGUMENTS
......@@ -134,10 +246,19 @@ class ModelArguments:
default=False,
metadata={"help": "Whether to freeze the text encoder."},
)
do_sample: bool = field(
default=False,
metadata={"help": "Whether to do sampling or greedy decoding."},
)
max_length: int = field(
default=400, # TODO
metadata={"help": "Whether to do sampling or greedy decoding."},
)
@dataclass
class DataSeq2SeqTrainingArguments:
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
......@@ -305,6 +426,22 @@ class DataSeq2SeqTrainingArguments:
"help": "id column name."
}
)
wandb_project: str = field(
default="stable-speech",
metadata={"help": "The name of the wandb project."},
)
@dataclass
class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"The data type (dtype) in which to run training. One of `float32` (full-precision), "
"`float16` or `bfloat16` (both half-precision)."
)
},
)
@dataclass
class DataCollatorEncodecWithPadding:
......@@ -320,7 +457,6 @@ class DataCollatorEncodecWithPadding:
# different padding methods
audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
len_audio = [len(audio) for audio in audios]
max_audio = max(len_audio)
input_features = {self.feature_extractor_input_name: audios}
batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
......@@ -372,11 +508,15 @@ class DataCollatorStableSpeechWithPadding:
# (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0,1) for feature in features]
# (bsz, seq_len, num_codebooks)
delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask,batch_first=True,padding_value=-100)
input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
batch= {"labels":labels, **input_ids}
batch= {"labels":labels, "label_delay_pattern_mask":delay_pattern_mask, **input_ids}
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
......@@ -554,7 +694,7 @@ def main():
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataSeq2SeqTrainingArguments, Seq2SeqTrainingArguments))
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, StableSpeechTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
......@@ -566,9 +706,24 @@ def main():
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_stable_speech", model_args, data_args)
accelerator = Accelerator()
if training_args.dtype == "float16":
mixed_precision = "fp16"
elif training_args.dtype == "bfloat16":
mixed_precision = "bf16"
else:
mixed_precision = "no"
# Detecting last checkpoint.
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=training_args.report_to,
project_dir=training_args.output_dir,
)
accelerator.init_trackers(project_name=data_args.wandb_project)
# Detecting last checkpoint and eventually continue from last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
......@@ -577,11 +732,12 @@ def main():
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None:
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
# Setup logging
logging.basicConfig(
......@@ -591,14 +747,20 @@ def main():
)
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
# Log on each process the small summary:
# Log a small summary on each process
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if accelerator.is_main_process:
# Set the verbosity to info of the Transformers logger (on main process only)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args)
......@@ -737,7 +899,8 @@ def main():
description_column_name = data_args.description_column_name
prompt_column_name = data_args.prompt_column_name
feature_extractor_input_name = feature_extractor.model_input_names[0]
audio_encoder_eos_token_id = config.decoder.pad_token_id
audio_encoder_pad_token_id = config.decoder.pad_token_id
audio_encoder_eos_token_id = config.decoder.eos_token_id
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks
......@@ -834,6 +997,7 @@ def main():
# (1, codebooks, seq_len) where seq_len=1
eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
def postprocess_dataset(sample, idx):
# (1, codebooks, seq_len)
......@@ -841,19 +1005,23 @@ def main():
len_ = int(all_ratios[idx] * all_lens[idx])
labels = labels[:, :, :len_]
# add eos token column
labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
# TODO: remove, only for test
labels = labels[:, :, :(len_)%10+20]
# add bos and eos token column
labels = torch.cat([bos_labels,labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
audio_encoder_bos_token_id,
audio_encoder_eos_token_id,
max_length + num_codebooks)
bos_token_id=audio_encoder_bos_token_id,
pad_token_id=audio_encoder_pad_token_id,
max_length=labels.shape[-1] + num_codebooks)
labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
# the first timestamp is associated to a row full of BOS, let's get rid of it
sample["labels"] = labels[:, 1:]
sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
return sample
# TODO: done multiple times, how to deal with it.
......@@ -869,7 +1037,6 @@ def main():
accelerator.free_memory()
del generate_labels
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
if is_wandb_available():
......@@ -888,20 +1055,41 @@ def main():
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
# 6. Next, we can prepare the training.
# Let's use CLAP similarity as our evaluation metric,
# instantiate a data collator and the trainer
# enable gradient checkpointing if necessary
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# Define evaluation metrics during training, *i.e.* CLAP similarity
# TODO: allow using another CLAP
# Let's use CLAP similarity and WER as our evaluation metrics,
# Define evaluation metrics during training, *i.e.* CLAP similarity. TODO: allow using another CLAP
clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
# TODO add wer with lightweight asr model
metric = evaluate.load("wer")
def clap_similarity(texts, audios):
clap_inputs = clap_processor(text=texts, audios=audios.squeeze(1), padding=True, return_tensors="pt")
def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
audio_features = clap.get_audio_features(clap_inputs["input_features"])
......@@ -909,32 +1097,73 @@ def main():
return cosine_sim.mean()
eval_metrics = {"clap": clap_similarity}
def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
transcriptions = asr_pipeline([{'raw': audio, 'sampling_rate': sampling_rate} for audio in audios])
word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts])
return word_error, [t["text"] for t in transcriptions]
eval_methods = {"clap": clap_similarity, "wer": wer}
def compute_metrics(pred):
input_ids = pred.inputs
def compute_metrics(audios, descriptions, prompts, device="cpu"):
input_ids = descriptions
input_ids[input_ids==-100] = description_tokenizer.pad_token_id
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
audios = pred.predictions
results = {key: metric(texts, audios) for (key, metric) in eval_metrics.items()}
prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios]
results = {
"clap": eval_methods["clap"](texts, audios, device)
}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
results["wer"] = word_error
return results
return results, texts, prompts, audios, transcriptions
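The CLAP part of these metrics can be exercised on its own. A stand-alone sketch of the same similarity computation (random placeholder audio instead of generated speech; CLAP expects 48 kHz input), not the script's exact wrapper:

import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")

texts = ["A calm female voice with a slight echo."]
audios = [np.random.randn(48_000).astype(np.float32)]  # placeholder 1-second clip

inputs = clap_processor(text=texts, audios=audios, sampling_rate=48_000, padding=True, return_tensors="pt")
with torch.no_grad():
    text_features = clap.get_text_features(inputs["input_ids"], attention_mask=inputs.get("attention_mask"))
    audio_features = clap.get_audio_features(inputs["input_features"])
print(torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1).mean().item())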
# Define Training Schedule
# Store some constants
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
train_batch_size = per_device_train_batch_size * accelerator.num_processes
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
if training_args.max_steps < 0:
num_epochs = int(training_args.num_train_epochs)
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
total_train_steps = steps_per_epoch * num_epochs
elif training_args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
total_train_steps = int(training_args.max_steps)
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_epochs = sys.maxsize
steps_per_epoch = total_train_steps
if training_args.eval_steps is None:
logger.info(
f"eval_steps is not set, evaluating at the end of each epoch"
)
eval_steps = steps_per_epoch
else:
eval_steps = training_args.eval_steps
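With the second config above (12 training samples, per-device batch size 3, gradient accumulation 1, 20 epochs) and a single process, this schedule works out as below; a small arithmetic sketch of the same computation:

num_train_samples = 12
per_device_train_batch_size = 3
gradient_accumulation_steps = 1
num_processes = 1
num_train_epochs = 20

train_batch_size = per_device_train_batch_size * num_processes                           # 3
steps_per_epoch = num_train_samples // (train_batch_size * gradient_accumulation_steps)  # 4
total_train_steps = steps_per_epoch * num_train_epochs                                   # 80
eval_steps = 10  # taken from the config; would default to steps_per_epoch if unset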
# Define optimizer, LR scheduler, collator
optimizer = torch.optim.AdamW(
params=model.parameters(),
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
)
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
# LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
lr_scheduler = get_scheduler(
name=training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
num_training_steps=total_train_steps * accelerator.num_processes,
)
# Instantiate custom data collator
data_collator = DataCollatorStableSpeechWithPadding(
......@@ -943,7 +1172,293 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
logger.info(
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {total_train_steps}")
# ======================== Training ================================
train_time = 0
train_start = time.time()
steps_trained_progress_bar = tqdm(
range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
)
continue_training = True
epochs_trained = 0
cur_step = 0
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
if accelerator.is_main_process:
if training_args.push_to_hub:
# Retrieve or infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
if "wandb" not in gitignore:
gitignore.write("wandb\n")
elif training_args.output_dir is not None:
os.makedirs(training_args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
if checkpoint is not None:
accelerator.load_state(checkpoint)
# Find num steps and epoch from saved state string pattern
pattern = r"checkpoint-(\d+)-epoch-(\d+)"
match = re.search(pattern, checkpoint)
cur_step = int(match.group(1))
epochs_trained = int(match.group(2))
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {cur_step}")
steps_trained_progress_bar.update(cur_step)
for epoch in range(0, epochs_trained):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
if training_args.max_steps < 0:
# we know exactly the number of steps per epoch, so can skip through the required number of batches
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
else:
# Currently we don't know how many steps we've taken in the current epoch
# So we just shuffle the dataset one extra time and start from a fresh epoch
# This is "good enough" for our purposes but not fully correct
resume_step = None
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
else:
resume_step = None
gen_kwargs = {
"do_sample": model_args.do_sample,
"max_length": model_args.max_length,
}
# TODO: add max_length
# Define gradient update step fn
def train_step(
batch,
):
model.train()
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
# TODO: add CE per codebook
metrics = {"loss": ce_loss}
return ce_loss, metrics
# Define eval fn
def eval_step(batch):
model.eval()
with torch.no_grad():
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
metrics = {"loss": ce_loss}
return metrics
def generate_step(batch):
model.eval()
output_audios = accelerator.unwrap_model(model).generate(**batch, **gen_kwargs)
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
batch_size=per_device_train_batch_size,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
train_dataloader = accelerator.prepare(train_dataloader)
if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
train_dataloader.dataset.set_epoch(epoch)
if resume_step is not None:
# Skip the first N batches in the dataloader when resuming from a checkpoint
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
resume_step = None
for batch in train_dataloader:
with accelerator.accumulate(model):
loss, train_metric = train_step(batch)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Check if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
steps_trained_progress_bar.update(1)
cur_step += 1
if cur_step % training_args.logging_steps == 0:
steps_trained_progress_bar.write(
f"Step... ({cur_step} / {total_train_steps} | Loss:"
f" {train_metric['loss']}, Learning Rate:"
f" {lr_scheduler.get_last_lr()[0]})"
)
log_metric(
accelerator,
metrics=train_metric,
learning_rate=lr_scheduler.get_last_lr()[0],
train_time=train_time + time.time() - train_start,
step=cur_step,
epoch=epoch,
prefix="train",
)
# save checkpoint and weights after each save_steps and at the end of training
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
accelerator.save_state(output_dir=intermediate_dir)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
if cur_step == total_train_steps:
# un-wrap the model for saving
model = accelerator.unwrap_model(model)
model.save_pretrained(training_args.output_dir)
# re-wrap the model for final eval
model = accelerator.prepare(model)
if training_args.push_to_hub:
repo.push_to_hub(
commit_message=f"Saving train state of step {cur_step}",
blocking=False,
)
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
train_time += time.time() - train_start
model.eval()
# ======================== Evaluating ==============================
eval_metrics = []
eval_preds = []
eval_descriptions = []
eval_prompts = []
eval_start = time.time()
validation_dataloader = DataLoader(
vectorized_datasets["eval"],
collate_fn=data_collator,
batch_size=per_device_eval_batch_size,
drop_last=False,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
validation_dataloader = accelerator.prepare(validation_dataloader)
for batch in tqdm(
validation_dataloader,
desc=f"Evaluating...",
position=2,
disable=not accelerator.is_local_main_process,
):
# Model forward
eval_metric = eval_step(batch)
eval_metric = accelerator.gather_for_metrics(eval_metric)
eval_metrics.append(eval_metric)
# generation
if training_args.predict_with_generate:
generated_audios = generate_step(batch)
# Gather all predictions and targets
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"])
)
eval_preds.extend(generated_audios)
eval_descriptions.extend(input_ids)
eval_prompts.extend(prompts)
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = {
key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
}
# compute metrics
metrics_desc = ""
if training_args.predict_with_generate:
metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics(
eval_preds, eval_descriptions, eval_prompts, accelerator.device
)
eval_metrics.update(metric_values)
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
log_pred(
accelerator,
pred_descriptions,
pred_prompts,
transcriptions,
audios,
sampling_rate=sampling_rate,
step=cur_step,
prefix="eval",
)
# Print metrics and update progress bar
steps_trained_progress_bar.write(
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
f" {metrics_desc})"
)
log_metric(
accelerator,
metrics=eval_metrics,
train_time=eval_time,
step=cur_step,
epoch=epoch,
prefix="eval",
)
# flush the train metrics
train_start = time.time()
# break condition
if cur_step == total_train_steps:
continue_training = False
break
if not continue_training:
break
accelerator.end_training()
###########################################################################
# Initialize StableSpeechTrainer
trainer = StableSpeechTrainer(
......@@ -956,10 +1471,16 @@ def main():
tokenizer=feature_extractor,
)
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to and training_args.do_eval:
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
)
def decode_predictions(predictions):
audios = predictions.predictions
return {"audio": np.array(audios.squeeze(1))}
return {"audio": np.array(audios)}
class WandbPredictionProgressCallback(WandbCallback):
......@@ -982,8 +1503,8 @@ def main():
self.description_tokenizer = description_tokenizer
self.sample_dataset = val_dataset.select(range(num_samples))
def on_train_end(self, args, state, control, **kwargs):
super().on_train_end(args, state, control, **kwargs)
def on_evaluate(self, args, state, control, **kwargs):
super().on_evaluate(args, state, control, **kwargs)
predictions = self.trainer.predict(self.sample_dataset)
......@@ -1004,7 +1525,7 @@ def main():
trainer=trainer,
val_dataset=vectorized_datasets["eval"],
description_tokenizer=description_tokenizer,
num_samples=8, # TODO: add to args
num_samples=max_eval_samples,
)
# Add the callback to the trainer
......
......@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2049):
vocab_size (`int`, *optional*, defaults to 2050):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
......@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
def __init__(
self,
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos token)
vocab_size=2050, # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
max_position_embeddings=2048,
num_hidden_layers=24,
ffn_dim=4096,
......@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
initializer_factor=0.02,
scale_embedding=False,
num_codebooks=4,
pad_token_id=2049,
pad_token_id=2050,
bos_token_id=2048,
eos_token_id=2049,
tie_word_embeddings=False,
......
......@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_dim = config.vocab_size + 1
# TODO: not right dim
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
)
......@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
label_delay_pattern_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
......@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
Returns:
# TODO: delay_pattern_mask
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
loss = None
if labels is not None:
loss = torch.zeros([], device=self.device)
# since the encoder hidden states have been concatenated to the decoder hidden states, take the last logits corresponding to the labels
logits = lm_logits[:,:,-labels.shape[1]:]
loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device)
# per codebook cross-entropy
# -100 labels are ignored
# (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
loss = loss_fct(logits.transpose(1,3), labels)
# loss = loss_fct(logits.transpose(1,3), labels)
# -100 labels are ignored
# TODO: probably no need for label_delay_pattern_mask
# mask = label_delay_pattern_mask[:, :labels.shape[1]]
# mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
mask = (labels != -100)
# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_mask = mask[..., codebook].contiguous().view(-1)
codebook_labels = labels[..., codebook].contiguous().view(-1)
codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss += codebook_loss
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
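A shape sketch of the per-codebook loss above with random tensors (dimensions are illustrative; labels laid out as `(bsz, seq_len, num_codebooks)`, logits as `(bsz, num_codebooks, seq_len, vocab_size)`):

import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab_size = 2, 4, 6, 2050
logits = torch.randn(bsz, num_codebooks, seq_len, vocab_size)
labels = torch.randint(0, 2048, (bsz, seq_len, num_codebooks))
labels[:, -1, :] = -100                                            # padded positions are ignored

loss_fct = CrossEntropyLoss()
mask = labels != -100
loss = torch.zeros([])
for codebook in range(num_codebooks):
    codebook_logits = logits[:, codebook].reshape(-1, vocab_size)  # (bsz * seq_len, vocab_size)
    codebook_mask = mask[..., codebook].reshape(-1)
    codebook_labels = labels[..., codebook].reshape(-1)
    loss = loss + loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss = loss / num_codebooks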
......@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if delay_pattern_mask is None:
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=self.generation_config.decoder_start_token_id,
eos_token_id=self.generation_config.eos_token_id,
bos_token_id=self.generation_config.bos_token_id,
pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
......@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
}
# Ignore copy
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
- [B, -1, -1, -1, -1, E, E, E]
- [B, B, -1, -1, -1, -1, E, E]
- [B, B, B, -1, -1, -1, -1, E]
- [B, -1, -1, -1, -1, P, P, P]
- [B, B, -1, -1, -1, -1, P, P]
- [B, B, B, -1, -1, -1, -1, P]
- [B, B, B, B, -1, -1, -1, -1]
where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt:
- [B, a, b, -1, -1, E, E, E]
- [B, B, c, d, -1, -1, E, E]
- [B, B, B, e, f, -1, -1, E]
- [B, a, b, -1, -1, P, P, P]
- [B, B, c, d, -1, -1, P, P]
- [B, B, B, e, f, -1, -1, P]
- [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction.
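For a concrete feel of the pattern described above, a hand-built sketch with 2 codebooks and a max length of 6 (B = bos id 2048, P = pad id 2050); this constructs the pattern by hand rather than calling `build_delay_pattern_mask`:

import torch

bos, pad, free = 2048, 2050, -1
pattern = torch.tensor([
    [bos, free, free, free, free, pad],   # codebook 0: starts immediately, padded at the end
    [bos, bos,  free, free, free, free],  # codebook 1: delayed by one step
])
# -1 marks positions the decoder may predict; apply_delay_pattern_mask keeps the
# generated token where the mask is -1 and forces the mask's value everywhere else.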
......@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
bos_mask = ~(bos_delay_pattern).to(input_ids.device)
eos_mask = ~(eos_delay_pattern).to(input_ids.device)
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
......@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.decoder_start_token_id,
eos_token_id=generation_config.eos_token_id,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length,
)
......@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
labels: Optional[torch.LongTensor] = None,
label_delay_pattern_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
......@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# TODO: verify prompt_attention_mask
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
# TODO: verify it does what's expected
decoder_input_ids = shift_tokens_right(
labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id
)
labels, self.config.pad_token_id, self.config.decoder_start_token_id
).transpose(1,2)
elif decoder_input_ids is None and decoder_inputs_embeds is None:
audio_encoder_outputs = self.audio_encoder(
......@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
label_delay_pattern_mask=label_delay_pattern_mask,
**kwargs_decoder,
)
......@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if decoder_delay_pattern_mask is None:
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
decoder_input_ids,
bos_token_id=self.generation_config.decoder_start_token_id,
eos_token_id=self.generation_config.eos_token_id,
bos_token_id=self.generation_config.bos_token_id,
pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length,
)
......@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id)
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
def resize_token_embeddings(self, *args, **kwargs):
# TODO: now it's possible with prompt_embeddings
......@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.decoder_start_token_id,
eos_token_id=generation_config.eos_token_id,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
......@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the delay pattern mask by filtering the bos and pad token ids from the delay pattern mask
output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
# TODO: probably won't work...
output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
batch_size, self.decoder.num_codebooks, -1
)
......@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values
).audio_values.squeeze(1)
else:
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id]
sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)).sum(dim=(0,1)) == 0)
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0,2))
sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)|(sample == generation_config.pad_token_id)).sum(dim=(0,1)) == 0)
if sample_mask.sum()>0:
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0,2))
else:
output_values.append(torch.zeros((1,1,1)).to(self.device))
# TODO: we should keep track of output length as well. Not really straightforward tbh
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1)
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1).squeeze(1)
if generation_config.return_dict_in_generate:
......