Commit fc66e60b authored by Yoach Lacombe

fix some bugs

parent b6341055
@@ -30,7 +30,6 @@ from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 import datasets
-import evaluate
 import numpy as np
 import torch
 from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
@@ -315,7 +314,7 @@ class DataSeq2SeqTrainingArguments:
 @dataclass
-class DataCollatorMusicGenWithPadding:
+class DataCollatorStableSpeechWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
@@ -360,16 +359,14 @@ class DataCollatorMusicGenWithPadding:
         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
         input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+        batch = {"labels": labels, **input_ids}
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
-        batch = {"labels": labels, **input_ids}
         if self.feature_extractor_input_name in features[0]:
             # TODO: verify that it works
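Why the reordering above matters: in the old version, `batch["prompt_input_ids"] = ...` wrote into a dictionary that was only created afterwards, so the assignment either raised a `NameError` or, if an older `batch` was still in scope, was wiped out by the later rebuild. A minimal sketch of the corrected flow (illustrative helper, not the actual collator class):

```python
# Sketch only: assumes two Hugging Face tokenizers and pre-padded labels.
def collate(features, description_tokenizer, prompt_tokenizer, labels):
    input_ids = [{"input_ids": f["input_ids"]} for f in features]
    input_ids = description_tokenizer.pad(input_ids, return_tensors="pt")
    batch = {"labels": labels, **input_ids}  # dict created before any key assignment

    prompt = [{"input_ids": f["prompt_input_ids"]} for f in features]
    prompt = prompt_tokenizer.pad(prompt, return_tensors="pt")
    batch["prompt_input_ids"] = prompt["input_ids"]  # now writes into the final dict
    if "attention_mask" in prompt:
        batch["prompt_attention_mask"] = prompt["attention_mask"]
    return batch
```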
@@ -485,29 +482,30 @@ def load_multiple_datasets(
             **kwargs,
         )
-        if id_column_name is not None and id_column_name not in dataset:
+        if id_column_name is not None and id_column_name not in dataset.column_names:
             raise ValueError(
                 f"id_column_name={id_column_name} but has not been found in the dataset columns"
-                f"- one of {', '.join(list(dataset.columns))}."
+                f"- one of {', '.join(list(dataset.column_names))}."
             )
-        if id_column_name is not None and id_column_name not in metadata_dataset:
+        if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
             raise ValueError(
                 f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
-                f"- one of {', '.join(list(metadata_dataset.columns))}."
+                f"- one of {', '.join(list(metadata_dataset.column_names))}."
             )
         elif id_column_name is not None:
             metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-        metadata_columns_to_keep = set(metadata_dataset.columns).intersection(set(dataset.column_names))
-        metadata_dataset = metadata_dataset.remove_columns(set(metadata_dataset.columns)-metadata_columns_to_keep)
+        metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
+        metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
         dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
         if id_column_name is not None:
             if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
-                raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict["name"]}")
+                raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
+        dataset_features = dataset.features.keys()
         if columns_to_keep is not None:
             dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
         all_datasets.append(dataset)
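The `.columns` → `.column_names` fixes address `AttributeError`s: a `datasets.Dataset` exposes its columns through `column_names` (or `features`), not a pandas-style `columns` attribute, and membership is best tested against that list rather than the dataset object itself. A quick check:

```python
from datasets import Dataset

ds = Dataset.from_dict({"id": [1, 2], "text": ["a", "b"]})
print(ds.column_names)          # ['id', 'text']
print("id" in ds.column_names)  # True — explicit test against the column list
# ds.columns would raise AttributeError: 'Dataset' object has no attribute 'columns'
```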
@@ -586,9 +584,14 @@ def main():
     # 1. First, let's load the dataset
     raw_datasets = DatasetDict()
     num_workers = data_args.preprocessing_num_workers
+    columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
+    if data_args.description_column_name is not None:
+        columns_to_keep.append(data_args.description_column_name)
+    if data_args.conditional_audio_column_name is not None:
+        columns_to_keep.append(data_args.conditional_audio_column_name)
     if training_args.do_train:
         raw_datasets["train"] = load_multiple_datasets(
             data_args.train_dataset_name,
             data_args.train_dataset_config_name,
@@ -597,10 +600,9 @@ def main():
             dataset_samples=data_args.train_dataset_samples,
             seed=training_args.seed,
             cache_dir=model_args.cache_dir,
-            token=True if model_args.token else None,
-            trust_remote_code=model_args.trust_remote_code,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
+            columns_to_keep=columns_to_keep,
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )
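`columns_to_keep` is now assembled once from the data arguments and passed to `load_multiple_datasets`, which prunes every other column with `remove_columns`. A toy illustration of that pruning step:

```python
from datasets import Dataset

ds = Dataset.from_dict({"audio": [0.1], "text": ["hi"], "speaker_id": ["a"]})
columns_to_keep = {"audio", "text"}
# Same set arithmetic as in load_multiple_datasets: drop everything not requested.
ds = ds.remove_columns(set(ds.features.keys()) - columns_to_keep)
print(ds.column_names)  # ['audio', 'text']
```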
@@ -645,10 +647,9 @@ def main():
             data_args.eval_metadata_dataset_name,
             splits=data_args.eval_split_name,
             cache_dir=model_args.cache_dir,
-            token=True if model_args.token else None,
-            trust_remote_code=model_args.trust_remote_code,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
+            columns_to_keep=columns_to_keep,
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )
@@ -752,11 +753,11 @@ def main():
         if description_column_name is not None:
             text = batch[description_column_name]
-            batch["input_ids"] = description_tokenizer(text)["input_ids"]
+            batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
         if prompt_column_name is not None:
             text = batch[prompt_column_name]
-            batch["prompt_input_ids"] = prompt_tokenizer(text)["input_ids"]
+            batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
         # load audio
         target_sample = batch[target_audio_column_name]
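The added `.strip()` calls guard against stray leading or trailing whitespace in the text columns, which would otherwise flow into the token ids. Illustrated with an arbitrary pretrained tokenizer (T5 here, purely as an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
print(tok("hello world \n")["input_ids"])  # stray whitespace can alter the encoding
print(tok("hello world")["input_ids"])     # stripped text gives the clean encoding
```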
@@ -878,8 +879,8 @@ def main():
     config.save_pretrained(training_args.output_dir)
     # Instantiate custom data collator
-    data_collator = DataCollatorMusicGenWithPadding(
-        feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
+    data_collator = DataCollatorStableSpeechWithPadding(
+        audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
     )
     # Freeze Encoders
@@ -956,8 +957,9 @@ def main():
     # use last checkpoint if exist
     if last_checkpoint is not None:
         checkpoint = last_checkpoint
-    elif os.path.isdir(model_args.model_name_or_path):
-        checkpoint = model_args.model_name_or_path
+    # TODO: loading the trainer from model_name_or_path doesn't work if saving config
+    # elif os.path.isdir(model_args.model_name_or_path):
+    #     checkpoint = model_args.model_name_or_path
     else:
         checkpoint = None
...
@@ -137,8 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
     Args:
-        prompt_embed_dim (`int`, *optional*, defaults to 1024):
-            Dimensionality of the prompt embedding layer.
+        vocab_size (`int`, *optional*, defaults to 1024):
+            Vocabulary size of the prompt # TODO.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
@@ -189,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
     model_type = "stable_speech"
     is_composition = True
-    def __init__(self, prompt_embed_dim=1024, **kwargs):
+    def __init__(self, vocab_size=1024, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -202,7 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
         decoder_config = kwargs.pop("decoder")
-        self.prompt_embed_dim = prompt_embed_dim
+        self.vocab_size = vocab_size
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = StableSpeechDecoderConfig(**decoder_config)
...
@@ -730,7 +730,6 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         # if prompt_hidden_states, fuse to inputs_embeds and update input shape
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
-            input_shape = inputs_embeds.size()[:-1]
         # TODO: verify if prompt attention mask is required
         # As it is, the masked ids from the prompt will still count in the positions embeddings
@@ -740,9 +739,9 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
             logger.warning_once(
                 "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
             )
-            attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)])
+            attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+        input_shape = inputs_embeds.size()[:-1]
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
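Two shape fixes land here. `torch.cat` defaults to `dim=0` (the batch axis), so the prompt mask must be concatenated along `dim=1` (the time axis); and `input_shape` is recomputed only after the prompt is fused into `inputs_embeds` (hence the removal in the previous hunk), so the 4D causal mask covers the lengthened sequence. A worked shape check, assuming `input_shape` still holds the pre-fusion length inside the branch:

```python
import torch

bsz, prompt_len, seq_len, hidden = 2, 5, 7, 16
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)
inputs_embeds = torch.rand(bsz, prompt_len + seq_len, hidden)  # prompt already fused

input_shape = (bsz, seq_len)  # pre-fusion shape at this point in the decoder
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(input_shape, dtype=torch.long)], dim=1
)
print(attention_mask.shape)  # torch.Size([2, 12]); with dim=0 the cat fails (5 != 7)

input_shape = inputs_embeds.size()[:-1]  # (2, 12): now matches the full sequence
```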
@@ -1538,7 +1537,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
         # prompt embeddings
-        self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)
+        self.embed_prompts = nn.Embedding(config.vocab_size, self.decoder.config.hidden_size)
         if self.text_encoder.get_output_embeddings() is not None:
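The `prompt_embed_dim` → `vocab_size` rename matches how the value is used: it is the first argument of `nn.Embedding`, i.e. the number of distinct prompt token ids (rows of the lookup table), while the embedding width comes from the decoder's hidden size. Schematically (sizes illustrative):

```python
import torch
import torch.nn as nn

# nn.Embedding(num_embeddings, embedding_dim): a vocabulary of 1024 prompt ids,
# each mapped to a hidden_size-dimensional vector.
embed_prompts = nn.Embedding(1024, 768)
prompt_ids = torch.randint(0, 1024, (2, 5))
print(embed_prompts(prompt_ids).shape)  # torch.Size([2, 5, 768])
```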
@@ -1557,7 +1556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         self.post_init()
     def _init_weights(self, module):
-        std = self.config.initializer_factor
+        std = self.decoder.config.initializer_factor
         if isinstance(module, (nn.Linear, nn.Conv1d)):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -1787,7 +1786,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         )
         if "config" not in kwargs_decoder:
-            decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+            # TODO: reput AutoConfig once added to transformers
+            decoder_config, kwargs_decoder = StableSpeechDecoderConfig.from_pretrained(
                 decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
             )
@@ -1923,7 +1923,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
-                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id
             )
         elif decoder_input_ids is None and decoder_inputs_embeds is None:
return model_kwargs return model_kwargs
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) return shift_tokens_right(labels.transpose(1,2), self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs): def resize_token_embeddings(self, *args, **kwargs):
# TODO: now it's possible with prompt_embeddings # TODO: now it's possible with prompt_embeddings
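Both `shift_tokens_right` call sites now transpose the labels first. A plausible reading, since the helper itself is outside this diff: labels are stored as `(batch, seq_len, num_codebooks)`, while the shift must move tokens along the time axis, which the helper treats as the last dimension. A sketch under those assumed shapes:

```python
import torch

# Assumed: labels are (batch, seq_len, num_codebooks) and the shift helper
# operates along the last axis; both assumptions are inferred, not from the diff.
labels = torch.randint(0, 1024, (2, 50, 9))
labels_t = labels.transpose(1, 2)            # (2, 9, 50): time axis moved last

decoder_start_token_id = 2048                # illustrative value
shifted = labels_t.new_full(labels_t.shape, decoder_start_token_id)
shifted[..., 1:] = labels_t[..., :-1]        # shift right along time
print(shifted.shape)                         # torch.Size([2, 9, 50])
```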
...