Commit d0140745 authored by Yoach Lacombe

update code: fix accelerate, fix delay pattern mask, improve generation

parent ee12a812
@@ -36,8 +36,8 @@
 "preprocessing_num_workers": 1,
-"pad_token_id": 2050,
-"decoder_start_token_id": 2048,
+"pad_token_id": 2048,
+"decoder_start_token_id": 2049,
 "do_train": true,
 "num_train_epochs": 120,
...
@@ -24,8 +24,8 @@
 "description_column_name": "text_description",
 "prompt_column_name": "text",
-"max_train_samples": 12,
-"max_eval_samples": 12,
+"max_train_samples": 4,
+"max_eval_samples": 4,
 "max_duration_in_seconds": 30,
@@ -36,14 +36,14 @@
 "preprocessing_num_workers": 1,
-"pad_token_id": 2050,
-"decoder_start_token_id": 2048,
+"pad_token_id": 2048,
+"decoder_start_token_id": 2049,
 "do_train": true,
-"num_train_epochs": 20,
+"num_train_epochs": 120,
 "gradient_accumulation_steps": 1,
 "gradient_checkpointing": false,
-"per_device_train_batch_size": 3,
+"per_device_train_batch_size": 2,
 "learning_rate": 1e-3,
 "adam_beta1": 0.9,
 "adam_beta2": 0.999,
@@ -60,10 +60,10 @@
 "predict_with_generate": true,
 "include_inputs_for_metrics": true,
 "evaluation_strategy": "steps",
-"eval_steps": 10,
-"per_device_eval_batch_size": 3,
+"eval_steps": 30,
+"per_device_eval_batch_size": 2,
 "generation_max_length": 400,
-"do_sample": true,
+"do_sample": false,
 "logging_steps": 15,
...
@@ -34,9 +34,9 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
...
@@ -18,7 +18,7 @@ decoder_config = StableSpeechDecoderConfig(
 decoder = StableSpeechForCausalLM(decoder_config)
-decoder.save_pretrained("/raid/yoach/tmp/decoder/")
+decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/")
 t5 = AutoConfig.from_pretrained("t5-base")
@@ -26,18 +26,18 @@ t5 = AutoConfig.from_pretrained("t5-base")
 model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path="t5-base",
     audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
-    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/decoder/",
+    decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/",
     vocab_size = t5.vocab_size
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048
 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False # True
 model.generation_config.guidance_scale = 1 # 3.0
-model.save_pretrained("/raid/yoach/tmp/small-stable-speech-untrained/")
\ No newline at end of file
+model.save_pretrained("/home/yoach/dataspeech/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
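The id remapping above is the heart of the commit: the old scheme reserved three extra ids on top of the 2048 EnCodec codes (bos 2048, eos 2049, pad 2050), while the new scheme keeps a single in-vocabulary special id, 2048, shared by eos and pad, and uses 2049 only as the decoder-start (bos) token. A minimal sketch of the resulting convention; the `vocab_size + 1` embedding detail is an assumption carried over from MusicGen-style decoders, not something shown in this diff:

    # Token-id convention after this commit (ids copied from the configs above).
    ENCODEC_CODEBOOK_SIZE = 2048            # EnCodec codes use ids 0..2047

    vocab_size = ENCODEC_CODEBOOK_SIZE + 1  # 2049: codec ids + one special id
    eos_token_id = 2048                     # the single in-vocab special id...
    pad_token_id = 2048                     # ...shared by eos and pad
    decoder_start_token_id = 2049           # bos sits one past the vocab,
    bos_token_id = 2049                     # assuming vocab_size + 1 embedding rows

    assert eos_token_id < vocab_size
    assert bos_token_id == vocab_size       # out-of-vocab: only ever fed in, never predicted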
@@ -26,6 +26,8 @@ import shutil
 import warnings
 import math
 import time
+from multiprocess import set_start_method
 import evaluate
 from tqdm import tqdm
@@ -63,7 +65,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed
-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
+from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
 if is_wandb_available():
     from wandb import Audio
@@ -516,15 +518,10 @@ class DataCollatorStableSpeechWithPadding:
 # (bsz, seq_len, num_codebooks)
 labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
-delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0, 1) for feature in features]
-# (bsz, seq_len, num_codebooks)
-delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask, batch_first=True, padding_value=-100)
 input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
 input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
-batch = {"labels": labels, "label_delay_pattern_mask": delay_pattern_mask, **input_ids}
+batch = {"labels": labels, **input_ids}
 prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
 prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
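For context on the surviving `pad_sequence` call: each sample's labels arrive as `(seq_len, num_codebooks)` after the transpose, and batching pads with `-100`, the id that PyTorch's cross-entropy ignores by default. A standalone sketch with hypothetical shapes:

    import torch

    # two samples of different lengths, 4 codebooks each (hypothetical values)
    a = torch.zeros(5, 4, dtype=torch.long)  # (seq_len=5, num_codebooks=4)
    b = torch.ones(3, 4, dtype=torch.long)   # (seq_len=3, num_codebooks=4)

    batch = torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=-100)
    print(batch.shape)    # torch.Size([2, 5, 4])
    print(batch[1, 3:])   # two rows of -100: ignored by the loss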
@@ -1014,23 +1011,30 @@ def main():
 len_ = int(all_ratios[idx] * all_lens[idx])
 labels = labels[:, :, :len_]
-# TODO: remove, only for test
-labels = labels[:, :, :(len_) % 10 + 20]
-# add bos and eos token column
-labels = torch.cat([bos_labels, labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
+labels = labels[:, :, :(len_) % 10 + 20]  # TODO: change
+# add bos
+labels = torch.cat([bos_labels, labels], dim=-1)
-labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
+labels, delay_pattern_mask = build_delay_pattern_mask(labels,
     bos_token_id=audio_encoder_bos_token_id,
-    pad_token_id=audio_encoder_pad_token_id,
-    max_length=labels.shape[-1] + num_codebooks)
+    pad_token_id=audio_encoder_eos_token_id,
+    max_length=labels.shape[-1] + num_codebooks,
+    num_codebooks=num_codebooks)
-labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
+# the first ids of the delay pattern mask are precisely the labels; we use the
+# rest of the mask to take care of EOS
+# we want labels to look like this:
+#    - [B, a, b, E, E, E, E]
+#    - [B, B, c, d, E, E, E]
+#    - [B, B, B, e, f, E, E]
+#    - [B, B, B, B, g, h, E]
+labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
 # the first timestamp is associated to a row full of BOS, let's get rid of it
-sample["labels"] = labels[:, 1:]
-sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
+# we also remove the last timestamps (full of PAD)
+sample["labels"] = labels[:, 1:].cpu()
 return sample
 # TODO: done multiple times, how to deal with it.
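The new label pipeline can be exercised in isolation. A hedged sketch, using the `build_delay_pattern_mask` free function that this commit exports from `stable_speech`, with hypothetical stand-ins for the audio-encoder special ids:

    import torch
    from stable_speech import build_delay_pattern_mask

    B, E = 2049, 2048                      # hypothetical bos / eos ids
    num_codebooks, seq_len = 4, 2

    # fake codec labels for one sample: (num_codebooks, seq_len), plus a bos column
    codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)
    labels = torch.cat([torch.full((num_codebooks, 1), B), codes], dim=-1)

    _, mask = build_delay_pattern_mask(
        labels, bos_token_id=B, pad_token_id=E,
        max_length=labels.shape[-1] + num_codebooks, num_codebooks=num_codebooks,
    )
    labels = torch.where(mask == -1, E, mask)  # remaining -1 slots become EOS
    print(labels[0])  # tensor([2049, 0, 1, 2048, 2048, 2048, 2048]) == [B, a, b, E, E, E, E]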
@@ -1047,12 +1051,6 @@ def main():
 del generate_labels
-if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
-    if is_wandb_available():
-        from transformers.integrations import WandbCallback
-    else:
-        raise ValueError("`args.add_audio_samples_to_wandb=True` but wandb is not installed. See https://docs.wandb.ai/quickstart to install.")
 # for large datasets it is advised to run the preprocessing on a
 # single machine first with ``args.preprocessing_only`` since there will most likely
@@ -1467,4 +1465,5 @@ def main():
 if __name__ == "__main__":
+    set_start_method("spawn")
     main()
\ No newline at end of file
 from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
\ No newline at end of file
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
 Args:
-    vocab_size (`int`, *optional*, defaults to 2050):
+    vocab_size (`int`, *optional*, defaults to 2049):
         Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
         represented by the `input_ids` passed when calling [`StableSpeechDecoder`].
     hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
 def __init__(
     self,
-    vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
+    vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
     max_position_embeddings=2048,
     num_hidden_layers=24,
     ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     initializer_factor=0.02,
     scale_embedding=False,
     num_codebooks=4,
-    pad_token_id=2050,
-    bos_token_id=2048,
-    eos_token_id=2049,
+    pad_token_id=2048,
+    bos_token_id=2049,
+    eos_token_id=2048,
     tie_word_embeddings=False,
     **kwargs,
 ):
...
@@ -60,6 +60,77 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
 ]
+
+
+def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+    the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+    seq_len = input_ids.shape[-1]
+    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+    return input_ids
+
+
+def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+    """Build a delayed pattern mask for the input_ids. Each codebook is offset from the previous
+    codebook by one, giving a delayed pattern mask at the start of sequence and end of sequence.
+    Take the example where there are 4 codebooks and a max sequence length of 8; we have the
+    delayed pattern mask of shape `(codebooks, seq_len)`:
+    - [B, -1, -1, -1, -1,  P,  P,  P]
+    - [B,  B, -1, -1, -1, -1,  P,  P]
+    - [B,  B,  B, -1, -1, -1, -1,  P]
+    - [B,  B,  B,  B, -1, -1, -1, -1]
+    where P is the special padding token id and -1 indicates that the token is valid for prediction.
+    If we include a prompt (decoder input ids), the -1 positions indicate where new tokens should be
+    predicted; otherwise, the mask is set to the corresponding value in the prompt:
+    - [B,  a,  b, -1, -1,  P,  P,  P]
+    - [B,  B,  c,  d, -1, -1,  P,  P]
+    - [B,  B,  B,  e,  f, -1, -1,  P]
+    - [B,  B,  B,  B,  g,  h, -1, -1]
+    where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only
+    override the -1 tokens in our prediction.
+    """
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
+    bsz, num_codebooks, seq_len = input_ids.shape
+
+    input_ids_shifted = (
+        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+    )
+
+    # we only apply the mask if we have a large enough seq len - otherwise we return as is
+    if max_length < 2 * num_codebooks - 1:
+        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+    # fill the shifted ids with the prompt entries, offset by the codebook idx
+    for codebook in range(num_codebooks):
+        # mono channel - loop over the codebooks one-by-one
+        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+
+    # construct a pattern mask that indicates the positions of padding tokens for each codebook
+    # first fill the upper triangular part (the EOS padding)
+    eos_delay_pattern = torch.triu(
+        torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+    )
+    # then fill the lower triangular part (the BOS padding)
+    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+
+    bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+    eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+    mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+    input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
+
+    # find the first position to start generating - this is the first place we have the -1 token
+    # and will always be in the first codebook (since it has no codebook offset)
+    first_codebook_ids = input_ids[:, 0, :]
+    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+    if len(start_ids) > 0:
+        first_start_id = min(start_ids)
+    else:
+        # we have no tokens that need to be filled - return entire matrix of input ids
+        first_start_id = seq_len
+
+    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
+    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+    return input_ids, pattern_mask
 @dataclass
 class StableSpeechUnconditionalInput(ModelOutput):
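A quick usage sketch of the two new helpers together, mirroring how generation uses them (hypothetical ids; the "model output" is faked with random codes):

    import torch
    from stable_speech import apply_delay_pattern_mask, build_delay_pattern_mask

    B, P = 2049, 2048  # hypothetical bos / pad ids
    num_codebooks = 4

    # a bos-only decoder prompt for one sample: (bsz * num_codebooks, 1)
    prompt = torch.full((num_codebooks, 1), B, dtype=torch.long)
    prompt, pattern_mask = build_delay_pattern_mask(
        prompt, bos_token_id=B, pad_token_id=P, max_length=8, num_codebooks=num_codebooks
    )

    # pretend the decoder then generated 8 ids per codebook...
    generated = torch.randint(0, 2048, (num_codebooks, 8))
    # ...and force the frozen BOS/PAD positions back onto them
    generated = apply_delay_pattern_mask(generated, pattern_mask)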
@@ -982,7 +1053,6 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
     labels: Optional[torch.LongTensor] = None,
-    label_delay_pattern_mask: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
@@ -1031,16 +1101,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
 # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
 labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
 labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-# loss = loss_fct(logits.transpose(1,3), labels)
-# -100 labels are ignored
-# TODO: probably no need for label_delay_pattern_mask
-# mask = label_delay_pattern_mask[:, :labels.shape[1]]
-# mask = (labels != self.generation_config.bos_token_id) & (labels != -100)
-mask = (labels != -100)
+# we use every codebook's tokens AND a single EOS at the end of each codebook
+mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & (labels != -100)

 # per codebook cross-entropy
 for codebook in range(self.config.num_codebooks):
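The hunk cuts off at the loop header, so here is a hedged reconstruction of what a per-codebook cross-entropy under this mask could look like (everything beyond `logits`, `labels` and `mask` is an assumption, not the diff's code):

    import torch.nn.functional as F

    loss = 0.0
    for codebook in range(self.config.num_codebooks):
        # logits: (bsz, vocab, seq_len, num_codebooks); labels / mask: (bsz, seq_len, num_codebooks)
        cb_logits = logits[..., codebook].transpose(1, 2)  # (bsz, seq_len, vocab)
        cb_labels = labels[..., codebook]
        cb_mask = mask[..., codebook]
        # boolean-index away everything the mask excludes, then average over codebooks
        loss = loss + F.cross_entropy(cb_logits[cb_mask], cb_labels[cb_mask])
    loss = loss / self.config.num_codebooks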
@@ -1152,60 +1215,14 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
     where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only
     override the -1 tokens in our prediction.
     """
-    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-    input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
-    bsz, num_codebooks, seq_len = input_ids.shape
     max_length = max_length if max_length is not None else self.generation_config.max_length
-    input_ids_shifted = (
-        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-    )
-    # we only apply the mask if we have a large enough seq len - otherwise we return as is
-    if max_length < 2 * num_codebooks - 1:
-        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
-    # fill the shifted ids with the prompt entries, offset by the codebook idx
-    for codebook in range(num_codebooks):
-        # mono channel - loop over the codebooks one-by-one
-        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
-    # construct a pattern mask that indicates the positions of padding tokens for each codebook
-    # first fill the upper triangular part (the EOS padding)
-    eos_delay_pattern = torch.triu(
-        torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
-    )
-    # then fill the lower triangular part (the BOS padding)
-    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
-    bos_mask = ~(bos_delay_pattern).to(input_ids.device)
-    eos_mask = ~(eos_delay_pattern).to(input_ids.device)
-    mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-    input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
-    # find the first position to start generating - this is the first place we have the -1 token
-    # and will always be in the first codebook (since it has no codebook offset)
-    first_codebook_ids = input_ids[:, 0, :]
-    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
-    if len(start_ids) > 0:
-        first_start_id = min(start_ids)
-    else:
-        # we have no tokens that need to be filled - return entire matrix of input ids
-        first_start_id = seq_len
-    # (bsz, num_codebooks, max_length) -> (bsz * num_codebooks, max_length)
-    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
-    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
-    return input_ids, pattern_mask
+    return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks)

 @staticmethod
 def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
     """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
     the mask is set to -1, and otherwise setting to the value detailed in the mask."""
-    seq_len = input_ids.shape[-1]
-    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
-    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
-    return input_ids
+    return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask)

 @torch.no_grad()
 def generate(
@@ -1219,7 +1236,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
     **kwargs,
 ):
     """
+    # TODO: adapt this generate with latest change
     Generates sequences of token ids for models with a language modeling head.

     <Tip warning={true}>
@@ -1868,7 +1885,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
     prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
     prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
     labels: Optional[torch.LongTensor] = None,
-    label_delay_pattern_mask: Optional[torch.LongTensor] = None,
     use_cache: Optional[bool] = None,
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
@@ -1991,7 +2007,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
     past_key_values=past_key_values,
     return_dict=return_dict,
     labels=labels,
-    label_delay_pattern_mask=label_delay_pattern_mask,
     **kwargs_decoder,
 )
@@ -2074,6 +2089,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
     "head_mask": head_mask,
     "decoder_head_mask": decoder_head_mask,
     "cross_attn_head_mask": cross_attn_head_mask,
+    "prompt_hidden_states": prompt_hidden_states,
+    "prompt_attention_mask": prompt_attention_mask,
     "use_cache": use_cache,
 }
@@ -2564,9 +2581,17 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
 # apply the pattern mask to the final ids
 output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

 # revert the delay pattern mask by filtering the bos and pad token ids from the delay pattern mask
-# TODO: probably won't work...
-output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id) & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
+_, mask = self.decoder.build_delay_pattern_mask(
+    input_ids,
+    bos_token_id=generation_config.bos_token_id,
+    pad_token_id=generation_config.pad_token_id,
+    max_length=output_ids.shape[1],
+)
+mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+output_ids = output_ids[mask].reshape(
     batch_size, self.decoder.num_codebooks, -1
 )
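Why rebuilding the mask works: at the final output length, every position the delay pattern forces to BOS or PAD is known in advance, and each codebook row contains the same number of "real" positions, so the boolean filter followed by a flat `reshape` recovers an aligned `(batch, num_codebooks, new_len)` tensor. A toy illustration with hypothetical ids:

    import torch

    B, P = 2049, 2048  # hypothetical bos / pad ids
    mask = torch.tensor([[B, -1, -1, P],      # 2 codebooks, output length 4
                         [B,  B, -1, -1]])
    output_ids = torch.tensor([[B, 10, 11, P],
                               [B,  B, 12, 13]])

    keep = (mask != B) & (mask != P)
    print(output_ids[keep].reshape(1, 2, -1))  # tensor([[[10, 11], [12, 13]]])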
@@ -2577,8 +2602,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
 if audio_scales is None:
     audio_scales = [None] * batch_size

-decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
-if decode_in_batch.item():
+decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+if not decode_sequentially:
     output_values = self.audio_encoder.decode(
         output_ids,
         audio_scales=audio_scales,
...
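When any special id does survive into `output_ids` (e.g. one batch item hit EOS early), batch decoding would hand invalid codes to the codec, hence the new `decode_sequentially` guard. The `else` branch is not shown in this diff; a heavily hedged sketch of what a per-sample fallback could look like, reusing the batched call above with a one-element batch:

    # hedged sketch only: the real else-branch is not part of this diff, and the
    # codebook-size constant 2048 is an assumption taken from the configs above
    output_values = []
    for sample_id in range(batch_size):
        sample = output_ids[sample_id : sample_id + 1]  # keep a batch dim: (1, num_codebooks, len)
        valid = (sample[0] < 2048).all(dim=0)           # frames holding only real codec ids
        sample = sample[..., valid]
        output_values.append(
            self.audio_encoder.decode(sample, audio_scales=[audio_scales[sample_id]])
        )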