Commit 997bf5e6 authored by Yoach Lacombe

make audio encoding multi-GPU compatible
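The audio-label encoding step now runs as distributed inference with accelerate instead of a CPU-only datasets.map: a DataLoader is prepared by the Accelerator, each process encodes its shard, and the results are padded and gathered back. Below is a minimal, self-contained sketch of that pattern; the toy clips, the collate helper, and the doubling step are stand-ins for the real EnCodec features and audio encoder, and only the accelerate calls mirror the script in this diff.

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()

# toy variable-length "audio" clips; the real script pads via the feature extractor
clips = [torch.randn(n) for n in (8, 6, 7, 5)]

def collate(batch):
    # right-pad every clip to the longest one in the batch
    max_len = max(x.shape[0] for x in batch)
    return torch.stack([torch.nn.functional.pad(x, (0, max_len - x.shape[0])) for x in batch])

loader = accelerator.prepare(DataLoader(clips, batch_size=2, collate_fn=collate))

encoded = []
for batch in loader:
    with torch.no_grad():
        codes = batch * 2  # stand-in for audio_encoder.encode(...)
    # equalize seq_len across processes, then collect every process's shard
    codes = accelerator.pad_across_processes(codes, dim=1, pad_index=0)
    encoded.extend(accelerator.gather_for_metrics(codes).cpu())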

parent e25b8ba0
@@ -32,6 +32,8 @@ from typing import Dict, List, Optional, Union
import datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
import transformers
@@ -43,13 +45,15 @@ from transformers import (
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
from accelerate import PartialState
from accelerate import Accelerator
from accelerate.utils import set_seed
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
@@ -214,10 +218,6 @@ class DataSeq2SeqTrainingArguments:
default="audio",
metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
)
conditional_audio_column_name: str = field( # TODO
default=None,
metadata={"help": "The name of the dataset column containing the conditional audio data. Defaults to 'audio'"},
)
description_column_name: str = field( # TODO
default=None,
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
@@ -311,6 +311,29 @@ class DataSeq2SeqTrainingArguments:
"help": "id column name."
}
)
@dataclass
class DataCollatorEncodecWithPadding:
"""
Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
"""
feature_extractor: AutoFeatureExtractor
feature_extractor_input_name: Optional[str] = "input_values"
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# pad the audio inputs to the longest sequence in the batch and keep
# track of the original lengths
audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
len_audio = [len(audio) for audio in audios]
max_audio = max(len_audio)
input_features = {self.feature_extractor_input_name: audios}
batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1) # add mono-channel
batch["padding_mask"] = batch.pop("attention_mask")
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
return batch
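# e.g. features = [{"labels": 1-D waveform}, ...] yields
# {"input_values": (bsz, 1, max_len), "padding_mask": (bsz, max_len),
#  "len_audio": (bsz, 1)} - shapes follow from the padding logic above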
@dataclass
@@ -437,6 +460,7 @@ def convert_dataset_str_to_list(
def load_multiple_datasets(
accelerator: Accelerator,
dataset_names: Union[List, str],
dataset_config_names: Union[List, str],
metadata_dataset_names: Optional[str]=None,
@@ -463,51 +487,52 @@ def load_multiple_datasets(
all_datasets = []
# iterate over the datasets we want to interleave
for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
dataset = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
streaming=streaming,
**kwargs,
)
dataset_features = dataset.features.keys()
metadata_dataset_name = dataset_dict["metadata_dataset_name"]
if metadata_dataset_name is not None:
metadata_dataset = load_dataset(
metadata_dataset_name,
with accelerator.main_process_first():
dataset = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
streaming=streaming,
**kwargs,
)
dataset_features = dataset.features.keys()
metadata_dataset_name = dataset_dict["metadata_dataset_name"]
if metadata_dataset_name is not None:
metadata_dataset = load_dataset(
metadata_dataset_name,
dataset_dict["config"],
split=dataset_dict["split"],
streaming=streaming,
**kwargs,
)
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
f"- one of {', '.join(list(dataset.column_names))}."
)
if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
f"- one of {', '.join(list(metadata_dataset.column_names))}."
)
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
f"- one of {', '.join(list(dataset.column_names))}."
)
if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
f"- one of {', '.join(list(metadata_dataset.column_names))}."
)
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None:
if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenation didn't work. Some ids don't match in dataset {dataset_dict['name']}")
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
dataset_features = dataset.features.keys()
if id_column_name is not None:
if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenation didn't work. Some ids don't match in dataset {dataset_dict['name']}")
dataset_features = dataset.features.keys()
if columns_to_keep is not None:
dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
if columns_to_keep is not None:
dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
all_datasets.append(dataset)
if len(all_datasets) == 1:
@@ -522,7 +547,8 @@ def load_multiple_datasets(
seed=seed,
)
else:
interleaved_dataset = concatenate_datasets(all_datasets)
with accelerator.main_process_first():
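# the main process runs concatenate_datasets first and populates the datasets
# cache; the remaining processes then read the cached result instead of recomputing it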
interleaved_dataset = concatenate_datasets(all_datasets)
return interleaved_dataset
@@ -544,6 +570,8 @@ def main():
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_stable_speech", model_args, data_args)
accelerator = Accelerator()
# Detecting last checkpoint.
last_checkpoint = None
@@ -566,7 +594,7 @@ def main():
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
# Log on each process the small summary:
logger.warning(
@@ -574,8 +602,9 @@ def main():
f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
if accelerator.is_main_process:
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
@@ -585,71 +614,51 @@ def main():
raw_datasets = DatasetDict()
num_workers = data_args.preprocessing_num_workers
columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
columns_to_keep = {
"target_audio_column_name": data_args.target_audio_column_name,
"prompt_column_name": data_args.prompt_column_name
}
if data_args.description_column_name is not None:
columns_to_keep.append(data_args.description_column_name)
if data_args.conditional_audio_column_name is not None:
columns_to_keep.append(data_args.conditional_audio_column_name)
columns_to_keep["description_column_nam"] = data_args.description_column_name
if training_args.do_train:
raw_datasets["train"] = load_multiple_datasets(
accelerator,
data_args.train_dataset_name,
data_args.train_dataset_config_name,
data_args.train_metadata_dataset_name,
metadata_dataset_names=data_args.train_metadata_dataset_name,
splits=data_args.train_split_name,
dataset_samples=data_args.train_dataset_samples,
seed=training_args.seed,
cache_dir=model_args.cache_dir,
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep,
columns_to_keep=columns_to_keep.values(),
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
if data_args.target_audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
f"--target_audio_column_name '{data_args.target_audio_column_name}' not found in dataset '{data_args.train_dataset_name}'."
" Make sure to set `--target_audio_column_name` to the correct audio column - one of"
f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.description_column_name is not None and data_args.description_column_name not in raw_datasets["train"].column_names:
raise ValueError(
f"--description_column_name {data_args.description_column_name} not found in dataset '{data_args.train_dataset_name}'. "
"Make sure to set `--description_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
if data_args.prompt_column_name not in raw_datasets["train"].column_names:
raise ValueError(
f"--description_column_name {data_args.prompt_column_name} not found in dataset '{data_args.train_dataset_name}'. "
"Make sure to set `--prompt_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
if data_args.conditional_audio_column_name is not None and data_args.conditional_audio_column_name not in raw_datasets["train"].column_names:
raise ValueError(
f"--conditional_audio_column_name {data_args.conditional_audio_column_name} not found in dataset '{data_args.train_dataset_name}'. "
"Make sure to set `--conditional_audio_column_name` to the correct text column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
for key in columns_to_keep:
if columns_to_keep[key] not in raw_datasets["train"].column_names:
raise ValueError(
f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
f" Make sure to set `--{key}` to the correct audio column - one of"
f" {', '.join(raw_datasets['train'].column_names)}."
)
if data_args.max_train_samples is not None:
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
if training_args.do_eval:
raw_datasets["eval"] = load_multiple_datasets(
accelerator,
data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
data_args.eval_dataset_config_name
if data_args.eval_dataset_config_name
else data_args.train_dataset_config_name,
data_args.eval_metadata_dataset_name,
data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
metadata_dataset_names=data_args.eval_metadata_dataset_name,
splits=data_args.eval_split_name,
cache_dir=model_args.cache_dir,
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep,
columns_to_keep=columns_to_keep.values(),
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
@@ -657,9 +666,8 @@ def main():
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
# TODO: is this the right way to do it?
# 3. Next, let's load the config as we might need it to create
# load config
# 2. Next, let's load the config as we might need it to create
# load config. TODO: add the option to create the config from scratch
config = StableSpeechConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@@ -673,8 +681,7 @@ def main():
"decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
})
# 4. Now we can instantiate the feature extractor, tokenizers and model
# 3. Now we can instantiate the feature extractor, tokenizers and model
# Note for distributed training, the .from_pretrained methods guarantee that only
# one local process can concurrently download model & vocab.
@@ -692,16 +699,24 @@ def main():
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
)
# load description tokenizer
description_tokenizer = AutoTokenizer.from_pretrained(
model_args.description_tokenizer_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
)
if model_args.use_fast_tokenizer:
logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
# create model. TODO: probably should not use from_pretrained here
model = StableSpeechForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
@@ -711,46 +726,31 @@ def main():
trust_remote_code=data_args.trust_remote_code,
)
# take audio_encoder_feature_extractor
audio_encoder_feature_extractor = AutoFeatureExtractor.from_pretrained(
model.config.audio_encoder._name_or_path,
)
# 5. Now we preprocess the datasets including loading the audio, resampling and normalization
# 4. Now we preprocess the datasets including loading the audio, resampling and normalization
# Thankfully, `datasets` takes care of automatically loading and resampling the audio,
# so that we just need to set the correct target sampling rate and normalize the input
# via the `feature_extractor`
# resample target audio
raw_datasets = raw_datasets.cast_column(
data_args.target_audio_column_name, datasets.features.Audio(sampling_rate=audio_encoder_feature_extractor.sampling_rate)
)
if data_args.conditional_audio_column_name is not None:
raw_datasets = raw_datasets.cast_column(
data_args.conditional_audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
# derive max & min input length for sample rate & max duration
max_target_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
min_target_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
sampling_rate = feature_extractor.sampling_rate
max_target_length = data_args.max_duration_in_seconds * sampling_rate
min_target_length = data_args.min_duration_in_seconds * sampling_rate
target_audio_column_name = data_args.target_audio_column_name
conditional_audio_column_name = data_args.conditional_audio_column_name
description_column_name = data_args.description_column_name
prompt_column_name = data_args.prompt_column_name
feature_extractor_input_name = feature_extractor.model_input_names[0]
# resample target audio
raw_datasets = raw_datasets.cast_column(
target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
)
# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
# We need to read the audio files as arrays and tokenize the texts.
def pass_through_processors(batch):
# load audio
if conditional_audio_column_name is not None:
sample = batch[target_audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
if description_column_name is not None:
text = batch[description_column_name]
batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
@@ -761,14 +761,14 @@ def main():
# load audio
target_sample = batch[target_audio_column_name]
labels = audio_encoder_feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
labels = feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
batch["labels"] = labels["input_values"]
# take length of raw audio waveform
batch["target_length"] = len(target_sample["array"].squeeze())
return batch
with training_args.main_process_first(desc="dataset map preprocessing"):
with accelerator.main_process_first():
vectorized_datasets = raw_datasets.map(
pass_through_processors,
remove_columns=next(iter(raw_datasets.values())).column_names,
@@ -785,34 +785,81 @@ def main():
num_proc=num_workers,
input_columns=["target_length"],
)
# 5. Now we encode the audio labels with encodec.
# We use Accelerate to perform distributed inference
logger.info("*** Encode target audio with encodec ***")
# no need to prepare audio_decoder because it is only used for inference, without mixed precision
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
# TODO: load another model
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
def apply_audio_decoder(batch):
labels = audio_decoder.encode(torch.tensor(batch["labels"]).to(audio_decoder.device))["audio_codes"]
labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
model.generation_config.decoder_start_token_id,
model.generation_config.max_length + 1)
labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
# the first timestep is associated with a row full of BOS, let's get rid of it
batch["labels"] = labels[:, 1:]
return batch
with training_args.main_process_first(desc="audio target preprocessing"):
# for now on CPU
# TODO: enrich for GPU
vectorized_datasets = vectorized_datasets.map(
apply_audio_decoder,
num_proc=num_workers,
desc="preprocess datasets",
len_audio = batch.pop("len_audio")
audio_decoder.to(batch["input_values"].device).eval()
labels = audio_decoder.encode(**batch)["audio_codes"]
output = {}
output["len_audio"] = len_audio
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
output["labels"] = labels.squeeze(0).transpose(1,2)
output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
return output
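# ratio is encoded frames per padded input sample: since every clip in the batch is
# padded to len_audio.max(), int(ratio * len_audio) later recovers each clip's true
# number of codec frames (see postprocess_dataset below)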
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=training_args.per_device_eval_batch_size,
collate_fn=encoder_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
# TODO: will it work on GPU? Unmerged for now: https://github.com/huggingface/accelerate/pull/2433
# for split in vectorized_datasets:
# with distributed_state.split_between_processes(vectorized_datasets[split]["labels"]) as input_labels:
# result = audio_decoder(input_labels)
data_loader = accelerator.prepare(data_loader)
all_generated_labels = []
all_ratios = []
all_lens = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
generate_labels = apply_audio_decoder(batch)
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
generate_labels = accelerator.gather_for_metrics(generate_labels)
all_generated_labels.extend(generate_labels["labels"].cpu())
all_ratios.extend(generate_labels["ratio"].cpu())
all_lens.extend(generate_labels["len_audio"].cpu())
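# gather_for_metrics drops the duplicated samples accelerate adds to even out the
# last batch across processes, so the gathered lists stay aligned with the dataset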
def postprocess_dataset(sample, idx):
# (seq_len, codebooks) -> (1, codebooks, seq_len)
labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
model.generation_config.decoder_start_token_id,
model.generation_config.max_length + model.decoder.config.num_codebooks)
labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
len_ = int(all_ratios[idx] * all_lens[idx])
# the first timestep is associated with a row full of BOS, let's get rid of it
sample["labels"] = labels[:, 1:len_]
return sample
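# build_delay_pattern_mask presumably follows the MusicGen codebook delay: codebook k
# is shifted right by k steps and the gaps are filled with BOS, hence the num_codebooks
# extra positions requested and the all-BOS first timestep stripped in the slice above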
# TODO: this is done multiple times; how should we deal with it?
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=num_workers,
desc="Postprocessing labeling",
with_indices=True,
)
accelerator.free_memory()
del generate_labels
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
if is_wandb_available():
@@ -827,6 +874,7 @@ def main():
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
# cached dataset
if data_args.preprocessing_only:
# TODO: should we save to disk at this step instead?
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return
@@ -865,9 +913,9 @@ def main():
# Now save everything to be able to create a single processor later
# make sure all processes wait until data is saved
with training_args.main_process_first():
with accelerator.main_process_first():
# only the main process saves them
if is_main_process(training_args.local_rank):
if accelerator.is_main_process:
# save feature extractor, tokenizer and config
if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
prompt_tokenizer.save_pretrained(training_args.output_dir)
@@ -936,7 +984,7 @@ def main():
audios = predictions["audio"]
# log the table to wandb
self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=audio_encoder_feature_extractor.sampling_rate) for (audio, text) in zip(audios, texts)]})
self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=sampling_rate) for (audio, text) in zip(audios, texts)]})