Commit fc66e60b authored by Yoach Lacombe

fix some bugs

parent b6341055
@@ -30,7 +30,6 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import datasets
import evaluate
import numpy as np
import torch
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
@@ -315,7 +314,7 @@ class DataSeq2SeqTrainingArguments:
@dataclass
-class DataCollatorMusicGenWithPadding:
+class DataCollatorStableSpeechWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
@@ -360,16 +359,14 @@ class DataCollatorMusicGenWithPadding:
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+       batch = {"labels": labels, **input_ids}
        prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
        batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
        if "attention_mask" in prompt_input_ids:
            batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
-       batch = {"labels": labels, **input_ids}
        if self.feature_extractor_input_name in features[0]:
            # TODO: verify that it works
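The moved line is the bug fix here: in the old ordering, `batch["prompt_input_ids"]` was assigned before `batch` was ever created, raising a `NameError` (and the later dict literal would have discarded the prompt keys anyway). A minimal sketch of the corrected collator flow, using stand-in tokenizers and toy features rather than the script's real ones:

```python
# Sketch only: stand-in tokenizers and toy features, not the script's real inputs.
import torch
from transformers import AutoTokenizer

description_tokenizer = AutoTokenizer.from_pretrained("t5-small")  # stand-in
prompt_tokenizer = description_tokenizer                           # stand-in

features = [
    {"input_ids": [5, 6, 7], "prompt_input_ids": [1, 2]},
    {"input_ids": [8, 9], "prompt_input_ids": [3]},
]
labels = torch.full((2, 4, 1), -100)  # placeholder for padded audio codes

input_ids = description_tokenizer.pad(
    [{"input_ids": f["input_ids"]} for f in features], return_tensors="pt"
)
# Create the batch dict *before* adding the prompt entries to it.
batch = {"labels": labels, **input_ids}

prompt_input_ids = prompt_tokenizer.pad(
    [{"input_ids": f["prompt_input_ids"]} for f in features], return_tensors="pt"
)
batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
if "attention_mask" in prompt_input_ids:
    batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
```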
@@ -485,29 +482,30 @@ def load_multiple_datasets(
            **kwargs,
        )
-       if id_column_name is not None and id_column_name not in dataset:
+       if id_column_name is not None and id_column_name not in dataset.column_names:
            raise ValueError(
                f"id_column_name={id_column_name} but has not been found in the dataset columns"
-               f"- one of {', '.join(list(dataset.columns))}."
+               f"- one of {', '.join(list(dataset.column_names))}."
            )
-       if id_column_name is not None and id_column_name not in metadata_dataset:
+       if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
            raise ValueError(
                f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
-               f"- one of {', '.join(list(metadata_dataset.columns))}."
+               f"- one of {', '.join(list(metadata_dataset.column_names))}."
            )
        elif id_column_name is not None:
            metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-       metadata_columns_to_keep = set(metadata_dataset.columns).intersection(set(dataset.column_names))
-       metadata_dataset = metadata_dataset.remove_columns(set(metadata_dataset.columns)-metadata_columns_to_keep)
+       metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
+       metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
        dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
        if id_column_name is not None:
            if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
-               raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict["name"]}")
+               raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
        dataset_features = dataset.features.keys()
        if columns_to_keep is not None:
            dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
        all_datasets.append(dataset)
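For context on this hunk: `concatenate_datasets(..., axis=1)` joins two datasets column-wise by row position, so overlapping column names must be dropped from the metadata split first, and the id check afterwards is what catches misaligned rows. A small self-contained sketch of the pattern with toy columns (not the script's real ones):

```python
from datasets import Dataset, concatenate_datasets

dataset = Dataset.from_dict({"id": ["a", "b"], "audio_path": ["a.wav", "b.wav"]})
metadata = Dataset.from_dict({"id": ["a", "b"], "speaker": ["spk0", "spk1"]})

# Rename the join key so both copies survive the column-wise concat,
# then drop any other columns that would collide with the main dataset.
metadata = metadata.rename_column("id", "metadata_id")
overlap = set(metadata.column_names).intersection(dataset.column_names)
metadata = metadata.remove_columns(list(overlap))

merged = concatenate_datasets([dataset, metadata], axis=1)

# axis=1 aligns rows by position, not by key, so verify the ids line up.
mismatched = merged.filter(lambda id1, id2: id1 != id2, input_columns=["id", "metadata_id"])
assert len(mismatched) == 0
```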
@@ -586,9 +584,14 @@ def main():
    # 1. First, let's load the dataset
    raw_datasets = DatasetDict()
    num_workers = data_args.preprocessing_num_workers
-   if training_args.do_train:
+   columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
+   if data_args.description_column_name is not None:
+       columns_to_keep.append(data_args.description_column_name)
+   if data_args.conditional_audio_column_name is not None:
+       columns_to_keep.append(data_args.conditional_audio_column_name)
+   if training_args.do_train:
        raw_datasets["train"] = load_multiple_datasets(
            data_args.train_dataset_name,
            data_args.train_dataset_config_name,
@@ -597,10 +600,9 @@ def main():
            dataset_samples=data_args.train_dataset_samples,
            seed=training_args.seed,
            cache_dir=model_args.cache_dir,
-           token=True if model_args.token else None,
            trust_remote_code=model_args.trust_remote_code,
            num_proc=data_args.preprocessing_num_workers,
            id_column_name=data_args.id_column_name,
+           columns_to_keep=columns_to_keep,
            # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
        )
......@@ -645,10 +647,9 @@ def main():
data_args.eval_metadata_dataset_name,
splits=data_args.eval_split_name,
cache_dir=model_args.cache_dir,
token=True if model_args.token else None,
trust_remote_code=model_args.trust_remote_code,
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
@@ -752,11 +753,11 @@ def main():
        if description_column_name is not None:
            text = batch[description_column_name]
-           batch["input_ids"] = description_tokenizer(text)["input_ids"]
+           batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
        if prompt_column_name is not None:
            text = batch[prompt_column_name]
-           batch["prompt_input_ids"] = prompt_tokenizer(text)["input_ids"]
+           batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
        # load audio
        target_sample = batch[target_audio_column_name]
@@ -878,8 +879,8 @@ def main():
    config.save_pretrained(training_args.output_dir)
    # Instantiate custom data collator
-   data_collator = DataCollatorMusicGenWithPadding(
-       feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
+   data_collator = DataCollatorStableSpeechWithPadding(
+       audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
    )
    # Freeze Encoders
@@ -956,8 +957,9 @@ def main():
    # use last checkpoint if exist
    if last_checkpoint is not None:
        checkpoint = last_checkpoint
-   elif os.path.isdir(model_args.model_name_or_path):
-       checkpoint = model_args.model_name_or_path
+   # TODO: resuming the trainer from model_name_or_path doesn't work if the config has been saved
+   # elif os.path.isdir(model_args.model_name_or_path):
+   #     checkpoint = model_args.model_name_or_path
    else:
        checkpoint = None
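Resuming from an explicit model directory is disabled here because, per the TODO, rebuilding the Trainer state from `model_name_or_path` breaks once the config has been saved separately. What survives is presumably the standard HF Trainer resume pattern, sketched below (`os`, `training_args`, and `trainer` are assumed to exist as in the surrounding script):

```python
# Sketch of the standard Trainer resume flow this block feeds into.
from transformers.trainer_utils import get_last_checkpoint

last_checkpoint = None
if os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)

checkpoint = last_checkpoint if last_checkpoint is not None else None
train_result = trainer.train(resume_from_checkpoint=checkpoint)
```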
@@ -137,8 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
    documentation from [`PretrainedConfig`] for more information.
    Args:
-       prompt_embed_dim (`int`, *optional*, defaults to 1024):
-           Dimensionality of the prompt embedding layer.
+       vocab_size (`int`, *optional*, defaults to 1024):
+           Vocabulary size of the prompt # TODO.
        kwargs (*optional*):
            Dictionary of keyword arguments. Notably:
@@ -189,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
    model_type = "stable_speech"
    is_composition = True
-   def __init__(self, prompt_embed_dim=1024, **kwargs):
+   def __init__(self, vocab_size=1024, **kwargs):
        super().__init__(**kwargs)
        if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
            raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -202,7 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
        decoder_config = kwargs.pop("decoder")
-       self.prompt_embed_dim = prompt_embed_dim
+       self.vocab_size = vocab_size
        self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
        self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
        self.decoder = StableSpeechDecoderConfig(**decoder_config)
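The rename means the composite config now carries the prompt tokenizer's vocabulary size rather than an embedding dimension, and the model's prompt embedding table (changed further down in this commit) is sized by it. A hedged sketch of assembling such a config from sub-configs, following the MusicGen-style composition shown above; the checkpoint names and the `vocab_size` value are illustrative, not from the commit:

```python
# Illustrative only: composing the config from sub-config dicts.
from transformers import AutoConfig

text_encoder_config = AutoConfig.from_pretrained("t5-base").to_dict()                  # stand-in
audio_encoder_config = AutoConfig.from_pretrained("facebook/encodec_24khz").to_dict()  # stand-in
decoder_config = StableSpeechDecoderConfig().to_dict()  # assumes MusicGen-style defaults

config = StableSpeechConfig(
    vocab_size=32128,  # assumption: the prompt tokenizer's vocabulary size, e.g. len(prompt_tokenizer)
    text_encoder=text_encoder_config,
    audio_encoder=audio_encoder_config,
    decoder=decoder_config,
)
```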
@@ -730,7 +730,6 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
        # if prompt_hidden_states, fuse to inputs_embeds and update input shape
        if prompt_hidden_states is not None:
            inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
-           input_shape = inputs_embeds.size()[:-1]
        # TODO: verify if prompt attention mask is required
        # As it is, the masked ids from the prompt will still count in the positions embeddings
@@ -740,9 +739,9 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
                logger.warning_once(
                    "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
                )
-               attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)])
+               attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+       input_shape = inputs_embeds.size()[:-1]
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
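These two hunks fix the same bug from opposite ends: `input_shape` used to be recomputed immediately after the prompt was concatenated, so the `torch.ones` fallback covered the prompt positions a second time, and `torch.cat` defaulted to `dim=0` (the batch dimension) instead of the sequence dimension. A small torch sketch of the corrected shape arithmetic, with illustrative dimensions:

```python
import torch

bsz, prompt_len, seq_len, hidden = 2, 4, 10, 16
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)
inputs_embeds = torch.randn(bsz, seq_len, hidden)
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)

# input_shape still describes only the *non-prompt* tokens at this point...
input_shape = inputs_embeds.size()[:-1]          # (2, 10)
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

# ...so the ones-mask covers just those tokens, and the concat must run
# along the sequence dimension (dim=1), not the batch dimension (dim=0).
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(input_shape, dtype=prompt_attention_mask.dtype)], dim=1
)
assert attention_mask.shape == (bsz, prompt_len + seq_len)   # (2, 14)

# Only now is input_shape recomputed over the fused sequence for the 4D causal mask.
input_shape = inputs_embeds.size()[:-1]          # (2, 14)
```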
@@ -1538,7 +1537,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
        self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
        # prompt embeddings
-       self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)
+       self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
        if self.text_encoder.get_output_embeddings() is not None:
@@ -1557,7 +1556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
        self.post_init()
    def _init_weights(self, module):
-       std = self.config.initializer_factor
+       std = self.decoder.config.initializer_factor
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
@@ -1787,7 +1786,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
            )
        if "config" not in kwargs_decoder:
-           decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+           # TODO: restore AutoConfig once the model is added to transformers
+           decoder_config, kwargs_decoder = StableSpeechDecoderConfig.from_pretrained(
                decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
            )
@@ -1923,7 +1923,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
-               labels, self.config.pad_token_id, self.config.decoder_start_token_id
+               labels.transpose(1, 2), self.config.pad_token_id, self.config.decoder_start_token_id
            )
        elif decoder_input_ids is None and decoder_inputs_embeds is None:
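The transpose (here and in `prepare_decoder_input_ids_from_labels` below) presumably reflects labels arriving as `(batch, seq_len, num_codebooks)` while the MusicGen-style `shift_tokens_right` shifts along the last dimension, so the codebook axis has to be swapped out of the way first. A hedged sketch under that shape assumption, with a local stand-in for the helper:

```python
import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Local stand-in for the MusicGen-style helper: shift right along the
    # last (sequence) dimension and replace label padding with pad_token_id.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

bsz, seq_len, num_codebooks = 2, 6, 4
labels = torch.randint(0, 1024, (bsz, seq_len, num_codebooks))

# Without the transpose, the shift would run over the codebook axis;
# transposing first makes it run over time, as intended.
decoder_input_ids = shift_tokens_right(labels.transpose(1, 2), 1024, 1025)
assert decoder_input_ids.shape == (bsz, num_codebooks, seq_len)
```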
@@ -2190,7 +2190,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
        return model_kwargs
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
-       return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+       return shift_tokens_right(labels.transpose(1, 2), self.config.pad_token_id, self.config.decoder_start_token_id)
    def resize_token_embeddings(self, *args, **kwargs):
        # TODO: now it's possible with prompt_embeddings