Commit 0f6d59d4 authored by Yoach Lacombe

latest changes

parent 11fcc066
@@ -24,11 +24,11 @@
     "description_column_name": "text_description",
     "prompt_column_name": "text",
-    "max_train_samples": 1000,
-    "max_eval_samples": 200,
+    "max_train_samples": 20,
+    "max_eval_samples": 10,
-    "max_duration_in_seconds": 20,
+    "max_duration_in_seconds": 30,
     "min_duration_in_seconds": 1.0,
     "add_audio_samples_to_wandb": true,
@@ -36,30 +36,36 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2049,
+    "pad_token_id": 2050,
     "decoder_start_token_id": 2048,
     "do_train": true,
-    "num_train_epochs": 1,
+    "num_train_epochs": 120,
     "gradient_accumulation_steps": 1,
-    "gradient_checkpointing": true,
-    "per_device_train_batch_size": 8,
-    "learning_rate": 1e-6,
+    "gradient_checkpointing": false,
+    "per_device_train_batch_size": 2,
+    "learning_rate": 1e-3,
     "adam_beta1": 0.9,
-    "adam_beta2": 0.95,
+    "adam_beta2": 0.999,
     "weight_decay": 0.1,
-    "logging_steps": 25,
     "lr_scheduler_type": "cosine",
     "warmup_ratio": 0.1,
+    "logging_steps": 1,
     "freeze_text_encoder": true,
     "do_eval": true,
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
-    "evaluation_strategy": "epoch",
+    "evaluation_strategy": "steps",
+    "eval_steps": 600,
     "per_device_eval_batch_size": 8,
     "generation_max_length": 400,
-    "fp16": true,
+    "fp16": false,
     "seed": 456,
     "dataloader_num_workers": 8
{
    "model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
    "feature_extractor_name": "facebook/encodec_24khz",
    "description_tokenizer_name": "t5-base",
    "prompt_tokenizer_name": "t5-base",
    "push_to_hub": false,
    "hub_model_id": "stable-speech-mini",
    "report_to": ["wandb"],
    "overwrite_output_dir": true,
    "output_dir": "/home/yoach/dataspeech/artefacts/training/",
    "train_dataset_name": "blabble-io/libritts_r",
    "train_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "train_dataset_config_name": "clean",
    "train_split_name": "train.clean.360",
    "eval_dataset_name": "blabble-io/libritts_r",
    "eval_metadata_dataset_name": "stable-speech/libritts-r-tags-and-text-generated",
    "eval_dataset_config_name": "clean",
    "eval_split_name": "train.clean.360",
    "target_audio_column_name": "audio",
    "description_column_name": "text_description",
    "prompt_column_name": "text",
    "max_train_samples": 12,
    "max_eval_samples": 12,
    "max_duration_in_seconds": 30,
    "min_duration_in_seconds": 1.0,
    "add_audio_samples_to_wandb": true,
    "id_column_name": "id",
    "preprocessing_num_workers": 1,
    "pad_token_id": 2050,
    "decoder_start_token_id": 2048,
    "do_train": true,
    "num_train_epochs": 20,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": false,
    "per_device_train_batch_size": 3,
    "learning_rate": 1e-3,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "weight_decay": 0.1,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
    "freeze_text_encoder": true,
    "do_eval": true,
    "predict_with_generate": true,
    "include_inputs_for_metrics": true,
    "evaluation_strategy": "steps",
    "eval_steps": 10,
    "per_device_eval_batch_size": 3,
    "generation_max_length": 400,
    "do_sample": true,
    "logging_steps": 15,
    "dtype": "float32",
    "seed": 456,
    "dataloader_num_workers": 8
}
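This second config is the debug counterpart pointing at the locally built tiny model. Configs in this shape are usually fed to the training script as a single JSON file via HfArgumentParser; a minimal sketch of that loading pattern, with a hypothetical trimmed-down dataclass standing in for the script's real argument classes:

from dataclasses import dataclass
from transformers import HfArgumentParser

# TinyArgs is a hypothetical stand-in; the real script defines its own
# ModelArguments / DataTrainingArguments / training-argument classes.
@dataclass
class TinyArgs:
    model_name_or_path: str = ""
    learning_rate: float = 1e-3
    num_train_epochs: int = 20

parser = HfArgumentParser(TinyArgs)
# allow_extra_keys (recent transformers versions) skips config keys the dataclass lacks
(args,) = parser.parse_json_file("tiny_training_config.json", allow_extra_keys=True)
print(args.model_name_or_path, args.learning_rate)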
@@ -4,16 +4,16 @@ from transformers import AutoConfig
 decoder_config = StableSpeechDecoderConfig(
     max_position_embeddings=2048,
-    num_hidden_layers=2,
-    ffn_dim=256,
-    num_attention_heads=4,
+    num_hidden_layers=4,
+    ffn_dim=512,
+    num_attention_heads=8,
     layerdrop=0.0,
     use_cache=True,
     activation_function="gelu",
-    hidden_size=256,
-    dropout=0.1,
-    attention_dropout=0.1,
-    activation_dropout=0.1,
+    hidden_size=512,
+    dropout=0.0,
+    attention_dropout=0.0,
+    activation_dropout=0.0,
 )
 # TODO: ?? how to make it stop ?
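For scale, a back-of-envelope parameter count for the enlarged tiny decoder, assuming a Musicgen-style block with self-attention, cross-attention and an FFN per layer, plus one embedding table and one LM head per codebook, and ignoring biases and layer norms (a rough estimate only, not a measurement):

hidden, layers, ffn, codebooks, vocab = 512, 4, 512, 4, 2050

attn = 4 * hidden * hidden                  # q, k, v and output projections
per_layer = 2 * attn + 2 * hidden * ffn     # self-attn + cross-attn + FFN
embeddings = codebooks * (vocab + 1) * hidden
lm_heads = codebooks * vocab * hidden
total = layers * per_layer + embeddings + lm_heads
print(f"{total / 1e6:.1f}M parameters (rough)")  # ~18.9M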
@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2049
+model.generation_config.pad_token_id = 2050
+model.generation_config.eos_token_id = 2049

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
-model.generation_config.do_sample = True
-model.generation_config.guidance_scale = 3.0
+model.generation_config.do_sample = False # True
+model.generation_config.guidance_scale = 1 # 3.0

 model.save_pretrained("/home/yoach/dataspeech/artefacts/tiny-model/")
\ No newline at end of file
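The max_length line caps generation at 30 seconds of audio expressed in EnCodec frames. For facebook/encodec_24khz the published frame rate is 75 codebook frames per second (that value is the EnCodec model's, not something set in this repo), so:

frame_rate = 75               # facebook/encodec_24khz: 24000 Hz / 320-sample hop
max_length = int(30 * frame_rate)
print(max_length)             # 2250 decoder steps per codebook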
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2049):
+        vocab_size (`int`, *optional*, defaults to 2050):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `input_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
+        vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2049,
+        pad_token_id=2050,
         bos_token_id=2048,
         eos_token_id=2049,
         tie_word_embeddings=False,
@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         self.num_codebooks = config.num_codebooks
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

-        embed_dim = config.vocab_size + 1
+        # TODO: not right dim
+        embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
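The extra row matches the token layout implied by the config changes: EnCodec codes occupy ids 0-2047, bos is 2048, eos is 2049, and the new pad id 2050 equals vocab_size, so each codebook's embedding table needs vocab_size + 1 rows to index it. A quick check (sketch only; ids mirror the defaults above):

import torch
import torch.nn as nn

vocab_size, pad_token_id = 2050, 2050     # pad sits one past the LM-head range
embed = nn.Embedding(vocab_size + 1, 512)
ids = torch.tensor([0, 2047, 2048, 2049, 2050])  # code, code, bos, eos, pad
print(embed(ids).shape)                   # torch.Size([5, 512]), no index error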
@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
-        Returns:
+        # TODO: delay_pattern_mask
+        Returns:
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         loss = None
         if labels is not None:
-            loss = torch.zeros([], device=self.device)
             # since encoder hidden states have been concatenated to the hidden states, take the last hidden states corresponding to labels
             logits = lm_logits[:,:,-labels.shape[1]:]
             loss_fct = CrossEntropyLoss()
-            # per codebook cross-entropy
-            # -100 labels are ignored
-            # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
-            loss = loss_fct(logits.transpose(1,3), labels)
+            loss = torch.zeros([], device=self.device)
+            labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+            # loss = loss_fct(logits.transpose(1,3), labels)
+            # -100 labels are ignored
+            # TODO: probably no need for label_delay_pattern_mask
+            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
+            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
+            mask = (labels != -100)
+            # per codebook cross-entropy
+            for codebook in range(self.config.num_codebooks):
+                codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
+                codebook_mask = mask[..., codebook].contiguous().view(-1)
+                codebook_labels = labels[..., codebook].contiguous().view(-1)
+                codebook_loss = loss_fct(codebook_logits[codebook_mask], codebook_labels[codebook_mask])
+                loss += codebook_loss
+            loss = loss / self.config.num_codebooks

         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
         lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
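The reworked loss replaces the single 4-D CrossEntropyLoss call with a masked, per-codebook average, so bos and pad positions never contribute. A self-contained sketch of the same computation on random tensors (shapes and special ids follow the code above; an illustration, not the repo's code):

import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab = 2, 4, 6, 2050
bos_id, pad_id = 2048, 2050

logits = torch.randn(bsz, num_codebooks, seq_len, vocab)
labels = torch.randint(0, 2048, (bsz, seq_len, num_codebooks))
labels[:, 0, :] = bos_id    # pretend the first step is bos
labels[:, -1, :] = pad_id   # and the last step is padding

labels = labels.masked_fill((labels == bos_id) | (labels == pad_id), -100)
mask = labels != -100

loss_fct = CrossEntropyLoss()
loss = torch.zeros([])
for cb in range(num_codebooks):
    cb_logits = logits[:, cb].reshape(-1, vocab)  # (bsz*seq_len, vocab)
    cb_mask = mask[..., cb].reshape(-1)
    cb_labels = labels[..., cb].reshape(-1)
    loss += loss_fct(cb_logits[cb_mask], cb_labels[cb_mask])
print(loss / num_codebooks)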
@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
         seq_len)`:
-        - [B, -1, -1, -1, -1, E, E, E]
-        - [B, B, -1, -1, -1, -1, E, E]
-        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, -1, -1, -1, -1, P, P, P]
+        - [B, B, -1, -1, -1, -1, P, P]
+        - [B, B, B, -1, -1, -1, -1, P]
         - [B, B, B, B, -1, -1, -1, -1]
-        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
+        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [B, a, b, -1, -1, E, E, E]
-        - [B, B, c, d, -1, -1, E, E]
-        - [B, B, B, e, f, -1, -1, E]
+        - [B, a, b, -1, -1, P, P, P]
+        - [B, B, c, d, -1, -1, P, P]
+        - [B, B, B, e, f, -1, -1, P]
         - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
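A toy reconstruction of the pattern described in the docstring (illustration only; the real build_delay_pattern_mask also folds the prompt into the mask and returns the shifted input_ids):

import torch

def toy_delay_pattern(num_codebooks, max_length, bos_id, pad_id):
    # -1 marks positions free to predict; codebook k gets k+1 leading bos
    # tokens and (num_codebooks - 1 - k) trailing pad tokens
    mask = torch.full((num_codebooks, max_length), -1, dtype=torch.long)
    for k in range(num_codebooks):
        mask[k, : k + 1] = bos_id
        if num_codebooks - 1 - k > 0:
            mask[k, max_length - (num_codebooks - 1 - k):] = pad_id
    return mask

print(toy_delay_pattern(4, 8, bos_id=2048, pad_id=2050))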
@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         bos_mask = ~(bos_delay_pattern).to(input_ids.device)
         eos_mask = ~(eos_delay_pattern).to(input_ids.device)
         mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id

         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )
@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
+        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # TODO: verify prompt_attention_mask
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
-                labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id
-            )
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            ).transpose(1,2)

         elif decoder_input_ids is None and decoder_inputs_embeds is None:
             audio_encoder_outputs = self.audio_encoder(
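The reordering matters because shift_tokens_right shifts along dim 1. Labels arrive as (bsz, seq_len, num_codebooks): shifting first moves everything one step in time and only then transposes to the decoder's (bsz, num_codebooks, seq_len) layout, whereas the old transpose-first version shifted across codebooks. A minimal demonstration (the helper is re-implemented inline in the usual transformers style so the snippet stands alone):

import torch

def shift_tokens_right(x, pad_id, start_id):
    shifted = x.new_zeros(x.shape)
    shifted[:, 1:] = x[:, :-1].clone()  # shift along dim 1
    shifted[:, 0] = start_id
    shifted.masked_fill_(shifted == -100, pad_id)
    return shifted

labels = torch.arange(6).reshape(1, 3, 2)  # (bsz, seq_len=3, num_codebooks=2)
good = shift_tokens_right(labels, 2050, 2048).transpose(1, 2)  # shifts in time
bad = shift_tokens_right(labels.transpose(1, 2), 2050, 2048)   # shifts across codebooks
print(good)
print(bad)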
@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
+            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                bos_token_id=self.generation_config.decoder_start_token_id,
-                eos_token_id=self.generation_config.eos_token_id,
+                bos_token_id=self.generation_config.bos_token_id,
+                pad_token_id=self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         return model_kwargs

     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-        return shift_tokens_right(labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id)
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1,2)

     def resize_token_embeddings(self, *args, **kwargs):
         # TODO: now it's possible with prompt_embeddings
@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            bos_token_id=generation_config.decoder_start_token_id,
-            eos_token_id=generation_config.eos_token_id,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
             max_length=generation_config.max_length,
         )

         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)].reshape(
+        # TODO: probably won't work...
+        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)&(model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
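The TODO is warranted: boolean indexing flattens output_ids, so the reshape to (batch_size, num_codebooks, -1) only works if every sample keeps the same number of positions. Once sequences can stop at different steps per sample, the kept counts diverge and the reshape either raises or silently misaligns samples. A tiny repro of the failure mode (illustrative shapes only):

import torch

output_ids = torch.randint(0, 2048, (2, 4, 8))  # (bsz, codebooks, seq_len)
mask = torch.ones_like(output_ids, dtype=torch.bool)
mask[0, :, -3:] = False  # sample 0 keeps 5 steps
mask[1, :, -2:] = False  # sample 1 keeps 6, so counts differ

try:
    output_ids[mask].reshape(2, 4, -1)  # 44 elements, not divisible into (2, 4, n)
except RuntimeError as err:
    print("reshape fails:", err)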
@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
-            ).audio_values
+            ).audio_values.squeeze(1)
         else:
             output_values = []
             for sample_id in range(batch_size):
                 sample = output_ids[:, sample_id]
-                sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)).sum(dim=(0,1)) == 0)
-                sample = sample[:, :, sample_mask]
-                sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
-                output_values.append(sample.transpose(0,2))
+                sample_mask = (((sample == generation_config.bos_token_id)|(sample == generation_config.eos_token_id)|(sample == generation_config.pad_token_id)).sum(dim=(0,1)) == 0)
+                if sample_mask.sum()>0:
+                    sample = sample[:, :, sample_mask]
+                    sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
+                    output_values.append(sample.transpose(0,2))
+                else:
+                    output_values.append(torch.zeros((1,1,1)).to(self.device))

-        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1)
+        # TODO: we should keep track of output length as well. Not really straightforward tbh
+        output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1,2).squeeze(-1).squeeze(1)

         if generation_config.return_dict_in_generate:
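In the non-batched branch each sample is masked and decoded on its own, with a one-sample zero waveform standing in when every frame still contains a special token, and pad_sequence then right-pads the variable-length results. A shape-level sketch of that collation (toy lengths, random audio in place of the EnCodec decode output):

import torch

batch = []
for length in (2400, 1200, 1):  # per-sample audio lengths; 1 is the zero placeholder
    wav = torch.zeros((1, 1, 1)) if length == 1 else torch.randn(1, 1, length)
    batch.append(wav.transpose(0, 2))  # (length, 1, 1): pad_sequence wants time first

out = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)
out = out.transpose(1, 2).squeeze(-1).squeeze(1)
print(out.shape)  # torch.Size([3, 2400]) -> (batch, padded_audio_length)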