Commit 0f6d59d4 authored by Yoach Lacombe

latest changes

parent 11fcc066
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
"description_column_name": "text_description", "description_column_name": "text_description",
"prompt_column_name": "text", "prompt_column_name": "text",
"max_train_samples": 1000, "max_train_samples": 20,
"max_eval_samples": 200, "max_eval_samples": 10,
"max_duration_in_seconds": 20, "max_duration_in_seconds": 30,
"min_duration_in_seconds": 1.0, "min_duration_in_seconds": 1.0,
"add_audio_samples_to_wandb": true, "add_audio_samples_to_wandb": true,
...@@ -36,30 +36,36 @@ ...@@ -36,30 +36,36 @@
"preprocessing_num_workers": 1, "preprocessing_num_workers": 1,
"pad_token_id": 2049, "pad_token_id": 2050,
"decoder_start_token_id": 2048, "decoder_start_token_id": 2048,
"do_train": true, "do_train": true,
"num_train_epochs": 1, "num_train_epochs": 120,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"gradient_checkpointing": true, "gradient_checkpointing": false,
"per_device_train_batch_size": 8, "per_device_train_batch_size": 2,
"learning_rate": 1e-6, "learning_rate": 1e-3,
"adam_beta1": 0.9, "adam_beta1": 0.9,
"adam_beta2": 0.95, "adam_beta2": 0.999,
"weight_decay": 0.1, "weight_decay": 0.1,
"logging_steps": 25, "lr_scheduler_type": "cosine",
"warmup_ratio": 0.1,
"logging_steps": 1,
"freeze_text_encoder": true,
"do_eval": true, "do_eval": true,
"predict_with_generate": true, "predict_with_generate": true,
"include_inputs_for_metrics": true, "include_inputs_for_metrics": true,
"evaluation_strategy": "epoch", "evaluation_strategy": "steps",
"eval_steps": 600,
"per_device_eval_batch_size": 8, "per_device_eval_batch_size": 8,
"generation_max_length": 400, "generation_max_length": 400,
"fp16": true, "fp16": false,
"seed": 456, "seed": 456,
"dataloader_num_workers":8 "dataloader_num_workers":8
......
{
"model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
"feature_extractor_name":"facebook/encodec_24khz",
"description_tokenizer_name":"t5-base",
"prompt_tokenizer_name":"t5-base",
"push_to_hub": false,
"hub_model_id": "stable-speech-mini",
"report_to": ["wandb"],
"overwrite_output_dir": true,
"output_dir": "/home/yoach/dataspeech/artefacts/training/",
"train_dataset_name": "blabble-io/libritts_r",
"train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
"train_dataset_config_name": "clean",
"train_split_name": "train.clean.360",
"eval_dataset_name": "blabble-io/libritts_r",
"eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
"eval_dataset_config_name": "clean",
"eval_split_name": "train.clean.360",
"target_audio_column_name": "audio",
"description_column_name": "text_description",
"prompt_column_name": "text",
"max_train_samples": 12,
"max_eval_samples": 12,
"max_duration_in_seconds": 30,
"min_duration_in_seconds": 1.0,
"add_audio_samples_to_wandb": true,
"id_column_name": "id",
"preprocessing_num_workers": 1,
"pad_token_id": 2050,
"decoder_start_token_id": 2048,
"do_train": true,
"num_train_epochs": 20,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": false,
"per_device_train_batch_size": 3,
"learning_rate": 1e-3,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"weight_decay": 0.1,
"lr_scheduler_type": "cosine",
"warmup_ratio": 0.1,
"freeze_text_encoder": true,
"do_eval": true,
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "steps",
"eval_steps": 10,
"per_device_eval_batch_size": 3,
"generation_max_length": 400,
"do_sample": true,
"logging_steps": 15,
"dtype": "float32",
"seed": 456,
"dataloader_num_workers":8
}
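The new "dtype" key in this config is what the updated script maps onto Accelerate's mixed-precision setting (see the StableSpeechTrainingArguments and Accelerator changes further down). A minimal sketch of that mapping, assuming only the three documented values:

def to_mixed_precision(dtype: str) -> str:
    # "dtype" training argument -> accelerate mixed_precision flag, as wired up later in the diff
    if dtype == "float16":
        return "fp16"
    elif dtype == "bfloat16":
        return "bf16"
    return "no"  # "float32" runs in full precision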
...@@ -4,16 +4,16 @@ from transformers import AutoConfig ...@@ -4,16 +4,16 @@ from transformers import AutoConfig
decoder_config = StableSpeechDecoderConfig( decoder_config = StableSpeechDecoderConfig(
max_position_embeddings=2048, max_position_embeddings=2048,
num_hidden_layers=2, num_hidden_layers=4,
ffn_dim=256, ffn_dim=512,
num_attention_heads=4, num_attention_heads=8,
layerdrop=0.0, layerdrop=0.0,
use_cache=True, use_cache=True,
activation_function="gelu", activation_function="gelu",
hidden_size=256, hidden_size=512,
dropout=0.1, dropout=0.0,
attention_dropout=0.1, attention_dropout=0.0,
activation_dropout=0.1, activation_dropout=0.0,
) )
# TODO: ?? how to make it stop ? # TODO: ?? how to make it stop ?
...@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained( ...@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
# set the appropriate bos/pad token ids # set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = 2048 model.generation_config.decoder_start_token_id = 2048
model.generation_config.pad_token_id = 2049 model.generation_config.pad_token_id = 2050
model.generation_config.eos_token_id = 2049 model.generation_config.eos_token_id = 2049
# set other default generation config params # set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate) model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
model.generation_config.do_sample = True model.generation_config.do_sample = False # True
model.generation_config.guidance_scale = 3.0 model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/") model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/")
\ No newline at end of file
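As a sanity check on the default generation length set above: assuming the stock facebook/encodec_24khz configuration (24 kHz sampling rate, hop length 320), the frame rate reported by model.audio_encoder.config.frame_rate is 24000 / 320 = 75 Hz, so max_length = int(30 * 75) = 2250 decoder steps for 30 seconds of audio.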
...@@ -22,10 +22,14 @@ import logging ...@@ -22,10 +22,14 @@ import logging
import os import os
import re import re
import sys import sys
import shutil
import warnings import warnings
import math import math
import time
import evaluate
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
...@@ -36,17 +40,21 @@ from torch.utils.data import DataLoader ...@@ -36,17 +40,21 @@ from torch.utils.data import DataLoader
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
from huggingface_hub import Repository, create_repo
import transformers import transformers
from transformers import ( from transformers import (
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModel, AutoModel,
AutoModelWithLMHead,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
HfArgumentParser, HfArgumentParser,
Seq2SeqTrainer, Seq2SeqTrainer,
Seq2SeqTrainingArguments, Seq2SeqTrainingArguments,
) )
from transformers.trainer_utils import get_last_checkpoint, is_main_process from transformers.trainer_utils import is_main_process
from transformers import pipeline
from transformers.optimization import get_scheduler
from transformers.utils import check_min_version, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available from transformers.integrations import is_wandb_available
...@@ -57,6 +65,8 @@ from accelerate.utils import set_seed ...@@ -57,6 +65,8 @@ from accelerate.utils import set_seed
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
if is_wandb_available():
from wandb import Audio
...@@ -72,6 +82,108 @@ logger = logging.getLogger(__name__) ...@@ -72,6 +82,108 @@ logger = logging.getLogger(__name__)
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata) return field(default_factory=lambda: default, metadata=metadata)
_RE_CHECKPOINT = re.compile(r"^checkpoint-(\d+)-epoch-(\d+)$")
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path
for path in content
if _RE_CHECKPOINT.search(path) is not None and os.path.isdir(os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(folder, max(checkpoints, key=lambda x: int(_RE_CHECKPOINT.search(x).groups()[0])))
def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[str]:
"""Helper function to sort saved checkpoints from oldest to newest."""
ordering_and_checkpoint_path = []
glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
for path in glob_checkpoints:
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
if regex_match is not None and regex_match.groups() is not None:
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
return checkpoints_sorted
def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint") -> None:
"""Helper function to delete old checkpoints."""
if save_total_limit is None or save_total_limit <= 0:
return
# Check if we should delete older checkpoint(s)
checkpoints_sorted = sorted_checkpoints(output_dir=output_dir, checkpoint_prefix=checkpoint_prefix)
if len(checkpoints_sorted) <= save_total_limit:
return
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
for checkpoint in checkpoints_to_be_deleted:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
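A minimal usage sketch for the checkpoint helpers above (the output directory is hypothetical):

# Hypothetical run directory containing checkpoint-<step>-epoch-<epoch> folders
output_dir = "/tmp/stable-speech-run"
last = get_last_checkpoint(output_dir)  # e.g. ".../checkpoint-400-epoch-3", or None if nothing matches
# keep only the 2 most recent checkpoints, deleting older ones
rotate_checkpoints(save_total_limit=2, output_dir=output_dir)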
def log_metric(
accelerator,
metrics: Dict,
train_time: float,
step: int,
epoch: int,
learning_rate: float = None,
prefix: str = "train",
):
"""Helper function to log all training/evaluation metrics with the correct prefixes and styling."""
log_metrics = {}
for k, v in metrics.items():
log_metrics[f"{prefix}/{k}"] = v
log_metrics[f"{prefix}/time"] = train_time
log_metrics[f"{prefix}/epoch"] = epoch
if learning_rate is not None:
log_metrics[f"{prefix}/learning_rate"] = learning_rate
accelerator.log(log_metrics, step=step)
def log_pred(
accelerator,
pred_descriptions: List[str],
pred_prompts: List[str],
transcriptions: List[str],
audios: List[torch.Tensor],
sampling_rate: int,
step: int,
prefix: str = "eval",
num_lines: int = 200000,
):
"""Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
if accelerator.is_main_process:
wandb_tracker = accelerator.get_tracker("wandb")
# pretty name for current step: step 50000 -> step 50k
cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
prefix_pretty = prefix.replace("/", "-")
# convert str data to a wandb compatible format
str_data = [[pred_descriptions[i], pred_prompts[i], transcriptions[i]] for i in range(len(pred_descriptions))]
# log as a table with the appropriate headers
wandb_tracker.log_table(
table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
columns=["Target descriptions", "Target prompts", "Predicted transcriptions"],
data=str_data[:num_lines],
step=step,
commit=False,
)
# wandb can only log 100 audio samples per step
wandb_tracker.log({
"Speech samples": [
Audio(
audio,
caption=f"{pred_prompts[i]} --- DESCRIPTION: {pred_descriptions[i]}",
sample_rate=sampling_rate,
)
for (i, audio) in enumerate(audios[:min(len(audios), 100)])
]},
step=step)
#### ARGUMENTS #### ARGUMENTS
...@@ -134,10 +246,19 @@ class ModelArguments: ...@@ -134,10 +246,19 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Whether to freeze the text encoder."}, metadata={"help": "Whether to freeze the text encoder."},
) )
do_sample: bool = field(
default=False,
metadata={"help": "Whether to do sampling or greedy decoding."},
)
max_length: int = field(
default=400, # TODO
metadata={"help": "Whether to do sampling or greedy decoding."},
)
@dataclass @dataclass
class DataSeq2SeqTrainingArguments: class DataTrainingArguments:
""" """
Arguments pertaining to what data we are going to input our model for training and eval. Arguments pertaining to what data we are going to input our model for training and eval.
...@@ -305,6 +426,22 @@ class DataSeq2SeqTrainingArguments: ...@@ -305,6 +426,22 @@ class DataSeq2SeqTrainingArguments:
"help": "id column name." "help": "id column name."
} }
) )
wandb_project: str = field(
default="stable-speech",
metadata={"help": "The name of the wandb project."},
)
@dataclass
class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"The data type (dtype) in which to run training. One of `float32` (full-precision), "
"`float16` or `bfloat16` (both half-precision)."
)
},
)
@dataclass @dataclass
class DataCollatorEncodecWithPadding: class DataCollatorEncodecWithPadding:
...@@ -320,7 +457,6 @@ class DataCollatorEncodecWithPadding: ...@@ -320,7 +457,6 @@ class DataCollatorEncodecWithPadding:
# different padding methods # different padding methods
audios = [torch.tensor(feature["labels"]).squeeze() for feature in features] audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
len_audio = [len(audio) for audio in audios] len_audio = [len(audio) for audio in audios]
max_audio = max(len_audio)
input_features = {self.feature_extractor_input_name: audios} input_features = {self.feature_extractor_input_name: audios}
batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True) batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
...@@ -372,11 +508,15 @@ class DataCollatorStableSpeechWithPadding: ...@@ -372,11 +508,15 @@ class DataCollatorStableSpeechWithPadding:
# (bsz, seq_len, num_codebooks) # (bsz, seq_len, num_codebooks)
labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100) labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0,1) for feature in features]
# (bsz, seq_len, num_codebooks)
delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask,batch_first=True,padding_value=-100)
input_ids = [{"input_ids": feature["input_ids"]} for feature in features] input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
batch= {"labels":labels, **input_ids} batch= {"labels":labels, "label_delay_pattern_mask":delay_pattern_mask, **input_ids}
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features] prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of) prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
...@@ -554,7 +694,7 @@ def main(): ...@@ -554,7 +694,7 @@ def main():
# or by passing the --help flag to this script. # or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns. # We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataSeq2SeqTrainingArguments, Seq2SeqTrainingArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, StableSpeechTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file, # If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments. # let's parse it to get our arguments.
...@@ -566,9 +706,24 @@ def main(): ...@@ -566,9 +706,24 @@ def main():
# information sent is the one passed as arguments along with your Python/PyTorch versions. # information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_stable_speech", model_args, data_args) send_example_telemetry("run_stable_speech", model_args, data_args)
accelerator = Accelerator() if training_args.dtype == "float16":
mixed_precision = "fp16"
elif training_args.dtype == "bfloat16":
mixed_precision = "bf16"
else:
mixed_precision = "no"
# Detecting last checkpoint. accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=training_args.report_to,
project_dir=training_args.output_dir,
)
accelerator.init_trackers(project_name=data_args.wandb_project)
# Detect the last checkpoint and, if one exists, resume from it
last_checkpoint = None last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir) last_checkpoint = get_last_checkpoint(training_args.output_dir)
...@@ -577,11 +732,12 @@ def main(): ...@@ -577,11 +732,12 @@ def main():
f"Output directory ({training_args.output_dir}) already exists and is not empty. " f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome." "Use --overwrite_output_dir to overcome."
) )
elif last_checkpoint is not None: elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info( logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch." "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
) )
# Setup logging # Setup logging
logging.basicConfig( logging.basicConfig(
...@@ -591,14 +747,20 @@ def main(): ...@@ -591,14 +747,20 @@ def main():
) )
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN) logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
# Log on each process the small summary: # Log a small summary on each process
logger.warning( logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
) )
# Set the verbosity to info of the Transformers logger (on main process only):
if accelerator.is_main_process: # Set the verbosity to info of the Transformers logger (on main process only)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info() transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args) logger.info("Training/evaluation parameters %s", training_args)
...@@ -737,7 +899,8 @@ def main(): ...@@ -737,7 +899,8 @@ def main():
description_column_name = data_args.description_column_name description_column_name = data_args.description_column_name
prompt_column_name = data_args.prompt_column_name prompt_column_name = data_args.prompt_column_name
feature_extractor_input_name = feature_extractor.model_input_names[0] feature_extractor_input_name = feature_extractor.model_input_names[0]
audio_encoder_eos_token_id = config.decoder.pad_token_id audio_encoder_pad_token_id = config.decoder.pad_token_id
audio_encoder_eos_token_id = config.decoder.eos_token_id
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
max_length = model.generation_config.max_length max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks num_codebooks = model.decoder.config.num_codebooks
...@@ -834,6 +997,7 @@ def main(): ...@@ -834,6 +997,7 @@ def main():
# (1, codebooks, seq_len) where seq_len=1 # (1, codebooks, seq_len) where seq_len=1
eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
def postprocess_dataset(sample, idx): def postprocess_dataset(sample, idx):
# (1, codebooks, seq_len) # (1, codebooks, seq_len)
...@@ -841,19 +1005,23 @@ def main(): ...@@ -841,19 +1005,23 @@ def main():
len_ = int(all_ratios[idx] * all_lens[idx]) len_ = int(all_ratios[idx] * all_lens[idx])
labels = labels[:, :, :len_] labels = labels[:, :, :len_]
# add eos token column # TODO: remove, only for test
labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1) labels = labels[:, :, :(len_)%10+20]
# add bos and eos token column
labels = torch.cat([bos_labels,labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels, labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
audio_encoder_bos_token_id, bos_token_id=audio_encoder_bos_token_id,
audio_encoder_eos_token_id, pad_token_id=audio_encoder_pad_token_id,
max_length + num_codebooks) max_length=labels.shape[-1] + num_codebooks)
labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask) labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
# the first timestamp is associated to a row full of BOS, let's get rid of it # the first timestamp is associated to a row full of BOS, let's get rid of it
sample["labels"] = labels[:, 1:] sample["labels"] = labels[:, 1:]
sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
return sample return sample
# TODO: done multiple times, how to deal with it. # TODO: done multiple times, how to deal with it.
...@@ -869,7 +1037,6 @@ def main(): ...@@ -869,7 +1037,6 @@ def main():
accelerator.free_memory() accelerator.free_memory()
del generate_labels del generate_labels
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to: if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
if is_wandb_available(): if is_wandb_available():
...@@ -888,20 +1055,41 @@ def main(): ...@@ -888,20 +1055,41 @@ def main():
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}") logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return return
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if (model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name is None) or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
# 6. Next, we can prepare the training. # 6. Next, we can prepare the training.
# Let's use CLAP similarity as our evaluation metric,
# instantiate a data collator and the trainer # enable gradient checkpointing if necessary
if training_args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# Define evaluation metrics during training, *i.e.* CLAP similarity
# TODO: allow using another CLAP # Let's use CLAP similarity and WER as our evaluation metrics,
# Define evaluation metrics during training, *i.e.* CLAP similarity. TODO: allow using another CLAP
clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech") clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech") clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
metric = evaluate.load("wer")
# TODO add wer with lightweight asr model
def clap_similarity(texts, audios): def clap_similarity(texts, audios, device):
clap_inputs = clap_processor(text=texts, audios=audios.squeeze(1), padding=True, return_tensors="pt") clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
clap.to(device)
text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)) text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
audio_features = clap.get_audio_features(clap_inputs["input_features"]) audio_features = clap.get_audio_features(clap_inputs["input_features"])
...@@ -909,32 +1097,73 @@ def main(): ...@@ -909,32 +1097,73 @@ def main():
return cosine_sim.mean() return cosine_sim.mean()
eval_metrics = {"clap": clap_similarity} def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
transcriptions = asr_pipeline([{'raw': audio, 'sampling_rate': sampling_rate} for audio in audios])
word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts])
return word_error, [t["text"] for t in transcriptions]
eval_methods = {"clap": clap_similarity, "wer": wer}
def compute_metrics(pred): def compute_metrics(audios, descriptions, prompts, device="cpu"):
input_ids = pred.inputs input_ids = descriptions
input_ids[input_ids==-100] = description_tokenizer.pad_token_id input_ids[input_ids==-100] = description_tokenizer.pad_token_id
texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True) texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
audios = pred.predictions prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
audios = [a.cpu().numpy() for a in audios]
results = {key: metric(texts, audios) for (key, metric) in eval_metrics.items()} results = {
"clap": eval_methods["clap"](texts, audios, device)
}
word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
results["wer"] = word_error
return results return results, texts, prompts, audios, transcriptions
# Define Training Schedule
# Store some constants
per_device_train_batch_size = int(training_args.per_device_train_batch_size)
train_batch_size = per_device_train_batch_size * accelerator.num_processes
gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
if training_args.max_steps < 0:
num_epochs = int(training_args.num_train_epochs)
steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
total_train_steps = steps_per_epoch * num_epochs
elif training_args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
total_train_steps = int(training_args.max_steps)
# Setting a very large number of epochs so we go as many times as necessary over the iterator.
num_epochs = sys.maxsize
steps_per_epoch = total_train_steps
if training_args.eval_steps is None:
logger.info(
f"eval_steps is not set, evaluating at the end of each epoch"
)
eval_steps = steps_per_epoch
else:
eval_steps = training_args.eval_steps
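For reference, with the tiny debug config above (12 train samples, per-device batch size 3, gradient accumulation 1) on a single process and assuming no samples are dropped by the duration filter, this works out to train_batch_size = 3, steps_per_epoch = 12 // 3 = 4 and total_train_steps = 4 * 20 = 80 over 20 epochs; since eval_steps = 10 is set in the config, evaluation runs every 10 optimizer steps rather than once per epoch.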
# Define optimizer, LR scheduler, collator
optimizer = torch.optim.AdamW(
params=model.parameters(),
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
)
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with accelerator.main_process_first():
# only the main process saves them
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
prompt_tokenizer.save_pretrained(training_args.output_dir)
else:
logger.warning("Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer.")
prompt_tokenizer.save_pretrained(training_args.output_dir)
feature_extractor.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
# LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
lr_scheduler = get_scheduler(
name=training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
num_training_steps=total_train_steps * accelerator.num_processes,
)
# Instantiate custom data collator # Instantiate custom data collator
data_collator = DataCollatorStableSpeechWithPadding( data_collator = DataCollatorStableSpeechWithPadding(
...@@ -943,7 +1172,293 @@ def main(): ...@@ -943,7 +1172,293 @@ def main():
# Freeze Encoders # Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder) model.freeze_encoders(model_args.freeze_text_encoder)
# Prepare everything with accelerate
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
logger.info("***** Running training *****")
logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
logger.info(
f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {total_train_steps}")
# ======================== Training ================================
train_time = 0
train_start = time.time()
steps_trained_progress_bar = tqdm(
range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
)
continue_training = True
epochs_trained = 0
cur_step = 0
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
if accelerator.is_main_process:
if training_args.push_to_hub:
# Retrieve or infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Clone repo locally
repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
if "wandb" not in gitignore:
gitignore.write("wandb\n")
elif training_args.output_dir is not None:
os.makedirs(training_args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
if checkpoint is not None:
accelerator.load_state(checkpoint)
# Find num steps and epoch from saved state string pattern
pattern = r"checkpoint-(\d+)-epoch-(\d+)"
match = re.search(pattern, checkpoint)
cur_step = int(match.group(1))
epochs_trained = int(match.group(2))
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {cur_step}")
steps_trained_progress_bar.update(cur_step)
for epoch in range(0, epochs_trained):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
if training_args.max_steps < 0:
# we know exactly the number of steps per epoch, so can skip through the required number of batches
resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
else:
# Currently we don't know how many steps we've taken in the current epoch
# So we just shuffle the dataset one extra time and start from a fresh epoch
# This is "good enough" for our purposes but not fully correct
resume_step = None
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
else:
resume_step = None
gen_kwargs = {
"do_sample": model_args.do_sample,
"max_length": model_args.max_length,
}
# TODO: add max_length
# Define gradient update step fn
def train_step(
batch,
):
model.train()
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
# TODO: add CE per codebook
metrics = {"loss": ce_loss}
return ce_loss, metrics
# Define eval fn
def eval_step(batch):
model.eval()
with torch.no_grad():
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
metrics = {"loss": ce_loss}
return metrics
def generate_step(batch):
model.eval()
output_audios = accelerator.unwrap_model(model).generate(**batch, **gen_kwargs)
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
train_dataloader = DataLoader(
vectorized_datasets["train"],
collate_fn=data_collator,
batch_size=per_device_train_batch_size,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
train_dataloader = accelerator.prepare(train_dataloader)
if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
train_dataloader.dataset.set_epoch(epoch)
if resume_step is not None:
# Skip the first N batches in the dataloader when resuming from a checkpoint
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
resume_step = None
for batch in train_dataloader:
with accelerator.accumulate(model):
loss, train_metric = train_step(batch)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Check if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
steps_trained_progress_bar.update(1)
cur_step += 1
if cur_step % training_args.logging_steps == 0:
steps_trained_progress_bar.write(
f"Step... ({cur_step} / {total_train_steps} | Loss:"
f" {train_metric['loss']}, Learning Rate:"
f" {lr_scheduler.get_last_lr()[0]})"
)
log_metric(
accelerator,
metrics=train_metric,
learning_rate=lr_scheduler.get_last_lr()[0],
train_time=train_time + time.time() - train_start,
step=cur_step,
epoch=epoch,
prefix="train",
)
# save checkpoint and weights after each save_steps and at the end of training
if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
accelerator.save_state(output_dir=intermediate_dir)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
if cur_step == total_train_steps:
# un-wrap the model for saving
model = accelerator.unwrap_model(model)
model.save_pretrained(training_args.output_dir)
# re-wrap the model for final eval
model = accelerator.prepare(model)
if training_args.push_to_hub:
repo.push_to_hub(
commit_message=f"Saving train state of step {cur_step}",
blocking=False,
)
if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
train_time += time.time() - train_start
model.eval()
# ======================== Evaluating ==============================
eval_metrics = []
eval_preds = []
eval_descriptions = []
eval_prompts = []
eval_start = time.time()
validation_dataloader = DataLoader(
vectorized_datasets["eval"],
collate_fn=data_collator,
batch_size=per_device_eval_batch_size,
drop_last=False,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
validation_dataloader = accelerator.prepare(validation_dataloader)
for batch in tqdm(
validation_dataloader,
desc=f"Evaluating...",
position=2,
disable=not accelerator.is_local_main_process,
):
# Model forward
eval_metric = eval_step(batch)
eval_metric = accelerator.gather_for_metrics(eval_metric)
eval_metrics.append(eval_metric)
# generation
if training_args.predict_with_generate:
generated_audios = generate_step(batch)
# Gather all predictions and targets
# TODO: also add prompt ids
# TODO: better gather
generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
(generated_audios, batch["input_ids"], batch["prompt_input_ids"])
)
eval_preds.extend(generated_audios)
eval_descriptions.extend(input_ids)
eval_prompts.extend(prompts)
eval_time = time.time() - eval_start
# normalize eval metrics
eval_metrics = {
key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
}
# compute metrics
metrics_desc = ""
if training_args.predict_with_generate:
metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics(
eval_preds, eval_descriptions, eval_prompts, accelerator.device
)
eval_metrics.update(metric_values)
metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
log_pred(
accelerator,
pred_descriptions,
pred_prompts,
transcriptions,
audios,
sampling_rate=sampling_rate,
step=cur_step,
prefix="eval",
)
# Print metrics and update progress bar
steps_trained_progress_bar.write(
f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
f" {metrics_desc})"
)
log_metric(
accelerator,
metrics=eval_metrics,
train_time=eval_time,
step=cur_step,
epoch=epoch,
prefix="eval",
)
# flush the train metrics
train_start = time.time()
# break condition
if cur_step == total_train_steps:
continue_training = False
break
if not continue_training:
break
accelerator.end_training()
###########################################################################
# Initialize StableSpeechTrainer # Initialize StableSpeechTrainer
trainer = StableSpeechTrainer( trainer = StableSpeechTrainer(
...@@ -956,10 +1471,16 @@ def main(): ...@@ -956,10 +1471,16 @@ def main():
tokenizer=feature_extractor, tokenizer=feature_extractor,
) )
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to and training_args.do_eval:
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
)
def decode_predictions(predictions): def decode_predictions(predictions):
audios = predictions.predictions audios = predictions.predictions
return {"audio": np.array(audios.squeeze(1))} return {"audio": np.array(audios)}
class WandbPredictionProgressCallback(WandbCallback): class WandbPredictionProgressCallback(WandbCallback):
...@@ -982,8 +1503,8 @@ def main(): ...@@ -982,8 +1503,8 @@ def main():
self.description_tokenizer = description_tokenizer self.description_tokenizer = description_tokenizer
self.sample_dataset = val_dataset.select(range(num_samples)) self.sample_dataset = val_dataset.select(range(num_samples))
def on_train_end(self, args, state, control, **kwargs): def on_evaluate(self, args, state, control, **kwargs):
super().on_train_end(args, state, control, **kwargs) super().on_evaluate(args, state, control, **kwargs)
predictions = self.trainer.predict(self.sample_dataset) predictions = self.trainer.predict(self.sample_dataset)
...@@ -1004,7 +1525,7 @@ def main(): ...@@ -1004,7 +1525,7 @@ def main():
trainer=trainer, trainer=trainer,
val_dataset=vectorized_datasets["eval"], val_dataset=vectorized_datasets["eval"],
description_tokenizer=description_tokenizer, description_tokenizer=description_tokenizer,
num_samples=8, # TODO: add to args num_samples=max_eval_samples,
) )
# Add the callback to the trainer # Add the callback to the trainer
......
...@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig): ...@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Args: Args:
vocab_size (`int`, *optional*, defaults to 2049): vocab_size (`int`, *optional*, defaults to 2050):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`]. represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
hidden_size (`int`, *optional*, defaults to 1024): hidden_size (`int`, *optional*, defaults to 1024):
...@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig): ...@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos token) vocab_size=2050, # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
max_position_embeddings=2048, max_position_embeddings=2048,
num_hidden_layers=24, num_hidden_layers=24,
ffn_dim=4096, ffn_dim=4096,
...@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig): ...@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
initializer_factor=0.02, initializer_factor=0.02,
scale_embedding=False, scale_embedding=False,
num_codebooks=4, num_codebooks=4,
pad_token_id=2049, pad_token_id=2050,
bos_token_id=2048, bos_token_id=2048,
eos_token_id=2049, eos_token_id=2049,
tie_word_embeddings=False, tie_word_embeddings=False,
......
...@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel): ...@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
self.num_codebooks = config.num_codebooks self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_dim = config.vocab_size + 1 # TODO: not right dim
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList( self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)] [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
) )
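For the updated special-token layout this adds up as follows (assuming ids 0-2047 are the EnCodec codebook entries): 2048 is the BOS / decoder-start token, 2049 the EOS token and 2050 the new pad token, so vocab_size=2050 plus the extra `+ 1` row gives an embedding table of 2051 entries, just enough for pad id 2050 to index the last row.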
...@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
label_delay_pattern_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
...@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
Returns: # TODO: delay_pattern_mask
Returns:
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss = torch.zeros([], device=self.device)
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
logits = lm_logits[:,:,-labels.shape[1]:] logits = lm_logits[:,:,-labels.shape[1]:]
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = torch.zeros([], device=self.device) loss = torch.zeros([], device=self.device)
# per codebook cross-entropy
# -100 labels are ignored
# (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks) # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
labels = labels.masked_fill(labels == self.config.pad_token_id, -100) labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
loss = loss_fct(logits.transpose(1,3), labels) # loss = loss_fct(logits.transpose(1,3), labels)
# -100 labels are ignored
# TODO: probably no need for label_delay_pattern_mask
# mask = label_delay_pattern_mask[:, :labels.shape[1]]
# mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
mask = (labels != -100)
# per codebook cross-entropy
for codebook in range(self.config.num_codebooks):
codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
codebook_mask = mask[..., codebook].contiguous().view(-1)
codebook_labels = labels[..., codebook].contiguous().view(-1)
codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss += codebook_loss
loss = loss / self.config.num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size) # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:]) lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
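In toy form, the per-codebook loss introduced above amounts to averaging a masked cross-entropy over the codebooks; a minimal, self-contained sketch with made-up shapes (not the model's actual tensors):

import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab = 2, 4, 5, 2050  # hypothetical sizes
logits = torch.randn(bsz, num_codebooks, seq_len, vocab)
labels = torch.randint(0, vocab, (bsz, seq_len, num_codebooks))
labels[0, -1, :] = -100  # padded / ignored positions

loss_fct = CrossEntropyLoss()
mask = labels != -100
loss = torch.zeros([])
for codebook in range(num_codebooks):
    # flatten (bsz, seq_len) for this codebook and keep only the unmasked positions
    codebook_logits = logits[:, codebook].reshape(-1, vocab)
    codebook_mask = mask[..., codebook].reshape(-1)
    codebook_labels = labels[..., codebook].reshape(-1)
    loss += loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
loss = loss / num_codebooks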
...@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if delay_pattern_mask is None: if delay_pattern_mask is None:
input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=self.generation_config.decoder_start_token_id, bos_token_id=self.generation_config.bos_token_id,
eos_token_id=self.generation_config.eos_token_id, pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length, max_length=self.generation_config.max_length,
) )
...@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
} }
# Ignore copy # Ignore copy
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None): def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`: seq_len)`:
- [B, -1, -1, -1, -1, E, E, E] - [B, -1, -1, -1, -1, P, P, P]
- [B, B, -1, -1, -1, -1, E, E] - [B, B, -1, -1, -1, -1, P, P]
- [B, B, B, -1, -1, -1, -1, E] - [B, B, B, -1, -1, -1, -1, P]
- [B, B, B, B, -1, -1, -1, -1] - [B, B, B, B, -1, -1, -1, -1]
where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt: mask is set to the value in the prompt:
- [B, a, b, -1, -1, E, E, E] - [B, a, b, -1, -1, P, P, P]
- [B, B, c, d, -1, -1, E, E] - [B, B, c, d, -1, -1, P, P]
- [B, B, B, e, f, -1, -1, E] - [B, B, B, e, f, -1, -1, P]
- [B, B, B, B, g, h, -1, -1] - [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1 where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction. tokens in our prediction.
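As the docstring says, only the -1 positions are overridden by the model's predictions; a minimal sketch of that override with a hypothetical 2-codebook mask (B=2048 for BOS and P=2050 for pad, following the updated convention):

import torch

B, P = 2048, 2050  # bos and pad ids from the updated config
mask = torch.tensor([[B, -1, -1, P],
                     [B,  B, -1, -1]])
predicted = torch.tensor([[B, 17, 23, 99],
                          [B,  B, 42,  7]])  # hypothetical decoder tokens
# keep predictions only where the mask is -1, force B / P everywhere else
output = torch.where(mask == -1, predicted, mask)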
...@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
bos_mask = ~(bos_delay_pattern).to(input_ids.device) bos_mask = ~(bos_delay_pattern).to(input_ids.device)
eos_mask = ~(eos_delay_pattern).to(input_ids.device) eos_mask = ~(eos_delay_pattern).to(input_ids.device)
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device) mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
# find the first position to start generating - this is the first place we have the -1 token # find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset) # and will always be in the first codebook (since it has no codebook offset)
...@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel): ...@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech) # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id,
eos_token_id=generation_config.eos_token_id, pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length, max_length=generation_config.max_length,
) )
...@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
label_delay_pattern_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
...@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# TODO: verify prompt_attention_mask # TODO: verify prompt_attention_mask
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
# TODO: verify it does what's expected
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) ).transpose(1,2)
elif decoder_input_ids is None and decoder_inputs_embeds is None: elif decoder_input_ids is None and decoder_inputs_embeds is None:
audio_encoder_outputs = self.audio_encoder( audio_encoder_outputs = self.audio_encoder(
...@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values=past_key_values, past_key_values=past_key_values,
return_dict=return_dict, return_dict=return_dict,
labels=labels, labels=labels,
label_delay_pattern_mask=label_delay_pattern_mask,
**kwargs_decoder, **kwargs_decoder,
) )
...@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if decoder_delay_pattern_mask is None: if decoder_delay_pattern_mask is None:
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
decoder_input_ids, decoder_input_ids,
bos_token_id=self.generation_config.decoder_start_token_id, bos_token_id=self.generation_config.bos_token_id,
eos_token_id=self.generation_config.eos_token_id, pad_token_id=self.generation_config.pad_token_id,
max_length=self.generation_config.max_length, max_length=self.generation_config.max_length,
) )
...@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return model_kwargs return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id) return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)
def resize_token_embeddings(self, *args, **kwargs): def resize_token_embeddings(self, *args, **kwargs):
# TODO: now it's possible with prompt_embeddings # TODO: now it's possible with prompt_embeddings
...@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids, input_ids,
bos_token_id=generation_config.decoder_start_token_id, bos_token_id=generation_config.bos_token_id,
eos_token_id=generation_config.eos_token_id, pad_token_id=generation_config.pad_token_id,
max_length=generation_config.max_length, max_length=generation_config.max_length,
) )
# stash the delay mask so that we don't have to recompute in each forward pass # stash the delay mask so that we don't have to recompute in each forward pass
...@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape( # TODO: probably won't work...
output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
batch_size, self.decoder.num_codebooks, -1 batch_size, self.decoder.num_codebooks, -1
) )
...@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel): ...@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_values = self.audio_encoder.decode( output_values = self.audio_encoder.decode(
output_ids, output_ids,
audio_scales=audio_scales, audio_scales=audio_scales,
).audio_values ).audio_values.squeeze(1)
else: else:
output_values = [] output_values = []
for sample_id in range(batch_size): for sample_id in range(batch_size):
sample = output_ids[:, sample_id] sample = output_ids[:, sample_id]
sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)).sum(dim=(0,1)) == 0) sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)|(sample == generation_config.pad_token_id)).sum(dim=(0,1)) == 0)
sample = sample[:, :, sample_mask] if sample_mask.sum()>0:
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values sample = sample[:, :, sample_mask]
output_values.append(sample.transpose(0,2)) sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0,2))
else:
output_values.append(torch.zeros((1,1,1)).to(self.device))
# TODO: we should keep track of output length as well. Not really straightforward tbh # TODO: we should keep track of output length as well. Not really straightforward tbh
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1) output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1).squeeze(1)
if generation_config.return_dict_in_generate: if generation_config.return_dict_in_generate:
......