Commit 11fcc066 authored by Yoach Lacombe

working training + generation

parent 997bf5e6
......@@ -36,14 +36,14 @@
"preprocessing_num_workers": 1,
"pad_token_id": 2048,
"pad_token_id": 2049,
"decoder_start_token_id": 2048,
"do_train": true,
"num_train_epochs": 20,
"num_train_epochs": 1,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": true,
"per_device_train_batch_size": 16,
"per_device_train_batch_size": 8,
"learning_rate": 1e-6,
"adam_beta1": 0.9,
"adam_beta2": 0.95,
......@@ -56,7 +56,7 @@
"predict_with_generate": true,
"include_inputs_for_metrics": true,
"evaluation_strategy": "epoch",
"per_device_eval_batch_size": 16,
"per_device_eval_batch_size": 8,
"generation_max_length": 400,
"fp16": true,
......
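For reference, a flat hyperparameter JSON like the one above is typically consumed with HfArgumentParser; a minimal sketch (the file path argument and the use of allow_extra_keys for the non-trainer keys are assumptions, not taken from this repo):

import sys
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

# Sketch only: the real run script presumably also defines model/data argument
# dataclasses; here we keep just the Seq2SeqTrainingArguments fields and let the
# parser ignore keys such as pad_token_id and decoder_start_token_id.
parser = HfArgumentParser(Seq2SeqTrainingArguments)
(training_args,) = parser.parse_json_file(sys.argv[1], allow_extra_keys=True)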
......@@ -35,7 +35,8 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
# set the appropriate bos/pad token ids
model.generation_config.decoder_start_token_id = 2048
model.generation_config.pad_token_id = 2048
model.generation_config.pad_token_id = 2049
model.generation_config.eos_token_id = 2049
# set other default generation config params
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
......
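After this hunk, the generation config separates the decoder start (BOS-like) token 2048 from the new pad/EOS token 2049. A consolidated sketch of the resulting values (the frame rate is an assumption; the real script reads it from model.audio_encoder.config.frame_rate):

from transformers import GenerationConfig

frame_rate = 75  # assumed; taken from the audio encoder config in the real script
generation_config = GenerationConfig(
    decoder_start_token_id=2048,      # BOS in the audio-token space
    pad_token_id=2049,                # was 2048 before this commit
    eos_token_id=2049,                # newly set; shared with padding
    max_length=int(30 * frame_rate),  # roughly 30 seconds of audio tokens
)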
......@@ -77,16 +77,10 @@ def list_field(default=None, metadata=None):
class StableSpeechTrainer(Seq2SeqTrainer):
def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
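Reconstructed from the new lines above, the simplified helper reads roughly as follows (the final copy of the unpadded values is not visible in the hunk and is assumed to match the stock Seq2SeqTrainer behaviour):

import torch

def _pad_tensors_to_max_len(self, tensor, max_length):
    # The tokenizer fallback is gone: the pad id for audio tokens must come from
    # the model config (2049 after this commit).
    if self.model.config.pad_token_id is not None:
        pad_token_id = self.model.config.pad_token_id
    else:
        raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
    # Labels are 3D (batch, seq_len, codebooks), hence the extra last dimension.
    padded_tensor = pad_token_id * torch.ones(
        (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
    )
    padded_tensor[:, : tensor.shape[1]] = tensor  # assumed continuation
    return padded_tensor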
......@@ -387,6 +381,7 @@ class DataCollatorStableSpeechWithPadding:
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
# TODO: check it's been padded on the left
batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
if "attention_mask" in prompt_input_ids:
batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
......@@ -676,6 +671,7 @@ def main():
)
# update pad token id and decoder_start_token_id
# TODO: verify if this makes sense, maybe should do it for model.decoder
config.update({
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
......@@ -700,6 +696,7 @@ def main():
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states
)
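The left-padding requirement is easy to check in isolation; a minimal sketch with a stand-in tokenizer (the checkpoint name is an assumption, not this repo's prompt tokenizer):

from transformers import AutoTokenizer

# Prompt embeddings are prepended to the codebook hidden states, so any padding has
# to sit on the left: the real prompt tokens then end exactly where audio tokens begin.
tok = AutoTokenizer.from_pretrained("t5-base", padding_side="left")  # stand-in checkpoint
batch = tok(["a prompt", "a much longer prompt here"], padding=True, return_tensors="pt")
assert batch["attention_mask"][0, 0].item() == 0  # the shorter sample is padded at the front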
# load description tokenizer
......@@ -740,6 +737,10 @@ def main():
description_column_name = data_args.description_column_name
prompt_column_name = data_args.prompt_column_name
feature_extractor_input_name = feature_extractor.model_input_names[0]
audio_encoder_eos_token_id = config.decoder.pad_token_id
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks
# resample target audio
raw_datasets = raw_datasets.cast_column(
......@@ -794,7 +795,6 @@ def main():
# no need to prepare audio_decoder because used for inference without mixed precision
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
# TODO: load another model
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
......@@ -832,18 +832,28 @@ def main():
all_ratios.extend(generate_labels["ratio"].cpu())
all_lens.extend(generate_labels["len_audio"].cpu())
# (1, codebooks, seq_len) where seq_len=1
eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
def postprocess_dataset(sample, idx):
# (1, seq_len, codebooks, bsz)
# (1, codebooks, seq_len)
labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
len_ = int(all_ratios[idx] * all_lens[idx])
labels = labels[:, :, :len_]
# add eos token column
labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
model.generation_config.decoder_start_token_id,
model.generation_config.max_length + model.decoder.config.num_codebooks)
audio_encoder_bos_token_id,
audio_encoder_eos_token_id,
max_length + num_codebooks)
labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
len_ = int(all_ratios[idx] * all_lens[idx])
# the first timestamp is associated to a row full of BOS, let's get rid of it
sample["labels"] = labels[:, 1:len_]
sample["labels"] = labels[:, 1:]
return sample
# TODO: done multiple times, how to deal with it.
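Putting the new label pipeline together: keep only the valid frames, append a single EOS column across all codebooks, delay the codebooks with distinct BOS/EOS ids, and drop the leading all-BOS timestep. A minimal sketch with stand-in shapes (the delay helpers belong to the model and are only referenced in comments):

import torch

num_codebooks, eos_id = 4, 2049
labels = torch.randint(0, 2048, (1, num_codebooks, 120))  # stand-in for (1, codebooks, seq_len)
len_ = 100                                                # int(all_ratios[idx] * all_lens[idx]) in the script
labels = labels[:, :, :len_]                              # drop frames that only exist because of padding

# append one EOS column so every codebook ends with an explicit end-of-audio marker
eos_labels = torch.full((1, num_codebooks, 1), eos_id, dtype=labels.dtype)
labels = torch.cat([labels, eos_labels], dim=-1)          # (1, 4, 101)

# the script then calls model.decoder.build_delay_pattern_mask(labels, bos_id, eos_id,
# max_length + num_codebooks) and apply_delay_pattern_mask(...), and finally drops the
# first timestep, which the delay pattern fills entirely with BOS.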
......@@ -956,7 +966,7 @@ def main():
"""Custom WandbCallback to log model predictions during training.
"""
def __init__(self, trainer, val_dataset,
def __init__(self, trainer, val_dataset, description_tokenizer, # TODO: add
num_samples=8):
"""Initializes the WandbPredictionProgressCallback instance.
......@@ -969,6 +979,7 @@ def main():
"""
super().__init__()
self.trainer = trainer
self.description_tokenizer = description_tokenizer
self.sample_dataset = val_dataset.select(range(num_samples))
def on_train_end(self, args, state, control, **kwargs):
......@@ -992,6 +1003,7 @@ def main():
progress_callback = WandbPredictionProgressCallback(
trainer=trainer,
val_dataset=vectorized_datasets["eval"],
description_tokenizer=description_tokenizer,
num_samples=8, # TODO: add to args
)
......
......@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2048):
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
......@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
def __init__(
self,
vocab_size=2048,
vocab_size=2049, # vocab size = 2048 (encodec vocab size) + 1 (eos token)
max_position_embeddings=2048,
num_hidden_layers=24,
ffn_dim=4096,
......@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
initializer_factor=0.02,
scale_embedding=False,
num_codebooks=4,
pad_token_id=2048,
pad_token_id=2049,
bos_token_id=2048,
eos_token_id=None,
eos_token_id=2049,
tie_word_embeddings=False,
**kwargs,
):
......
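Summarising the new decoder defaults, the audio-token space now looks like this (values copied from the diff above):

# ids 0..2047: EnCodec codebook entries
AUDIO_CODE_IDS = range(0, 2048)
BOS_TOKEN_ID = 2048    # bos_token_id, reused as decoder_start_token_id
EOS_TOKEN_ID = 2049    # newly introduced; also used as pad_token_id
VOCAB_SIZE = 2049      # 2048 EnCodec codes + 1 EOS, per the inline comment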
......@@ -731,7 +731,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
# TODO: verify if prompt attention mask is required
# TODO: verify if prompt attention mask is required and has to be
# As it is, the masked ids from the prompt will still count in the position embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
......@@ -754,6 +754,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
)
# embed positions
# TODO: As it is, the masked ids from the prompt will still count in the position embeddings
# maybe should modify position embeddings
positions = self.embed_positions(inputs_embeds, past_key_values_length)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
......@@ -1064,7 +1066,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if delay_pattern_mask is None:
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=self.generation_config.pad_token_id,
bos_token_id=self.generation_config.decoder_start_token_id,
eos_token_id=self.generation_config.eos_token_id,
max_length=self.generation_config.max_length,
)
......@@ -1108,22 +1111,22 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
}
# Ignore copy
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
- [P, -1, -1, -1, -1, P, P, P]
- [P, P, -1, -1, -1, -1, P, P]
- [P, P, P, -1, -1, -1, -1, P]
- [P, P, P, P, -1, -1, -1, -1]
where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
- [B, -1, -1, -1, -1, E, E, E]
- [B, B, -1, -1, -1, -1, E, E]
- [B, B, B, -1, -1, -1, -1, E]
- [B, B, B, B, -1, -1, -1, -1]
where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt:
- [P, a, b, -1, -1, P, P, P]
- [P, P, c, d, -1, -1, P, P]
- [P, P, P, e, f, -1, -1, P]
- [P, P, P, P, g, h, -1, -1]
- [B, a, b, -1, -1, E, E, E]
- [B, B, c, d, -1, -1, E, E]
- [B, B, B, e, f, -1, -1, E]
- [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction.
"""
......@@ -1147,14 +1150,16 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# construct a pattern mask that indicates the positions of padding tokens for each codebook
# first fill the upper triangular part (the EOS padding)
delay_pattern = torch.triu(
eos_delay_pattern = torch.triu(
torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
)
# then fill the lower triangular part (the BOS padding)
delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
mask = ~delay_pattern.to(input_ids.device)
input_ids = mask * input_ids_shifted + ~mask * pad_token_id
bos_mask = ~(bos_delay_pattern).to(input_ids.device)
eos_mask = ~(eos_delay_pattern).to(input_ids.device)
mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
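The triangular construction can be checked on the docstring's toy size (4 codebooks, max length 8); a self-contained sketch reproducing where the BOS and EOS ids land:

import torch

num_codebooks, max_length = 4, 8
bos_token_id, eos_token_id = 2048, 2049

# upper triangle marks the trailing EOS positions, lower triangle the leading BOS positions
eos_delay_pattern = torch.triu(
    torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
)
bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

bos_mask = ~bos_delay_pattern
eos_mask = ~eos_delay_pattern
mask = ~(bos_delay_pattern + eos_delay_pattern)

# a dummy shifted prompt: -1 everywhere means "free to predict"
input_ids_shifted = torch.full((num_codebooks, max_length), -1, dtype=torch.long)
pattern = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
print(pattern)
# row k: k+1 leading 2048s (BOS), then -1s, then (num_codebooks - 1 - k) trailing 2049s (EOS)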
......@@ -1334,7 +1339,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.decoder_start_token_id,
eos_token_id=generation_config.eos_token_id,
max_length=generation_config.max_length,
)
......@@ -1436,9 +1442,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
batch_size, self.num_codebooks, -1
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
output_ids = output_ids[(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
batch_size, self.decoder.num_codebooks, -1
)
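The important change here: valid positions are located via the stashed delay pattern mask (entries that are neither BOS nor EOS) rather than by testing the generated ids, which would break the reshape as soon as the model legitimately samples an EOS. A toy sketch of the filtering:

import torch

bos_token_id, eos_token_id = 2048, 2049
batch_size, num_codebooks = 1, 4

# toy delay pattern mask, shape (batch * codebooks, seq_len); -1 marks generated positions
delay_pattern_mask = torch.tensor([
    [2048,   -1,   -1,   -1,   -1, 2049, 2049, 2049],
    [2048, 2048,   -1,   -1,   -1,   -1, 2049, 2049],
    [2048, 2048, 2048,   -1,   -1,   -1,   -1, 2049],
    [2048, 2048, 2048, 2048,   -1,   -1,   -1,   -1],
])
output_ids = torch.randint(0, 2048, delay_pattern_mask.shape)

keep = (delay_pattern_mask != bos_token_id) & (delay_pattern_mask != eos_token_id)
# every codebook keeps the same number of positions (seq_len - num_codebooks),
# so the flat boolean selection reshapes cleanly back to (batch, codebooks, -1)
output_ids = output_ids[keep].reshape(batch_size, num_codebooks, -1)
print(output_ids.shape)  # torch.Size([1, 4, 4])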
if generation_config.return_dict_in_generate:
......@@ -1919,7 +1925,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
# TODO: do we do something with prompt_attention_mask ? e.g multiply it to prompt_hidden_states?
# TODO: verify prompt_attention_mask
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
......@@ -1999,7 +2005,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if decoder_delay_pattern_mask is None:
decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
decoder_input_ids,
self.generation_config.pad_token_id,
bos_token_id=self.generation_config.decoder_start_token_id,
eos_token_id=self.generation_config.eos_token_id,
max_length=self.generation_config.max_length,
)
......@@ -2428,7 +2435,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
input_ids,
pad_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.decoder_start_token_id,
eos_token_id=generation_config.eos_token_id,
max_length=generation_config.max_length,
)
# stash the delay mask so that we don't have to recompute in each forward pass
......@@ -2531,8 +2539,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# apply the pattern mask to the final ids
output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
# revert the pattern delay mask by filtering the pad token id
output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
batch_size, self.decoder.num_codebooks, -1
)
......@@ -2543,10 +2551,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if audio_scales is None:
audio_scales = [None] * batch_size
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values
decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
if decode_in_batch.item():
output_values = self.audio_encoder.decode(
output_ids,
audio_scales=audio_scales,
).audio_values
else:
output_values = []
for sample_id in range(batch_size):
sample = output_ids[:, sample_id]
sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)).sum(dim=(0,1)) == 0)
sample = sample[:, :, sample_mask]
sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
output_values.append(sample.transpose(0,2))
# TODO: we should keep track of output length as well. Not really straightforward tbh
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_values
......
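For the fallback branch above: once each sample has been decoded separately, the waveforms have different lengths and are padded back into a single batch. A simplified, self-contained sketch of that padding step (EnCodec-style decoders are assumed to return (1, channels, samples) per sample):

import torch

# stand-in decoded waveforms of different lengths, each (1, channels, samples)
decoded = [torch.randn(1, 1, 16000), torch.randn(1, 1, 12000)]

# pad_sequence pads along the first dimension, hence the transposes around it
output_values = [wav.transpose(0, 2) for wav in decoded]            # (samples, channels, 1)
output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
output_values = output_values.transpose(1, 2).squeeze(-1)           # (batch, channels, samples)
print(output_values.shape)  # torch.Size([2, 1, 16000])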