"docs/vscode:/vscode.git/clone" did not exist on "2e1d2d7e66c33fdd2b58aaf03a9893dbe593a3a3"
Commit 997bf5e6 authored by Yoach Lacombe

make audio encoding multi-gpus compatible

parent e25b8ba0
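The pattern this commit introduces for multi-GPU audio encoding is: pad the raw audio in a collator, shard the resulting DataLoader across processes with accelerator.prepare, run the codec's encode() on each device, then pad_across_processes and gather_for_metrics to collect the codes back in dataset order. Below is a minimal, self-contained sketch of that pattern only; the facebook/encodec_24khz checkpoint and the random waveforms are stand-ins for illustration (the script itself uses model.audio_encoder and the DataCollatorEncodecWithPadding added in this diff).

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import AutoFeatureExtractor, EncodecModel

accelerator = Accelerator()

# Stand-ins: any EnCodec-style codec whose encode() returns "audio_codes" works here.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")
audio_encoder = EncodecModel.from_pretrained("facebook/encodec_24khz").eval()
audio_encoder.to(accelerator.device)

# Toy "dataset": variable-length mono waveforms standing in for the target audio column.
waveforms = [torch.randn(24_000 * n) for n in (1, 2, 3, 4)]

def collate_fn(batch):
    # Pad raw audio to the longest sample in the batch; the feature extractor also
    # returns a padding_mask so encode() can ignore the padded tail.
    return feature_extractor(
        [x.numpy() for x in batch],
        sampling_rate=feature_extractor.sampling_rate,
        padding=True,
        return_tensors="pt",
    )

data_loader = DataLoader(waveforms, batch_size=2, collate_fn=collate_fn)
# Only the dataloader is prepared: it gets sharded across processes, while the codec
# does plain inference on each device (no gradients or mixed precision involved).
data_loader = accelerator.prepare(data_loader)

all_codes = []
for batch in data_loader:
    with torch.no_grad():
        codes = audio_encoder.encode(
            batch["input_values"].to(accelerator.device),
            batch["padding_mask"].to(accelerator.device),
        ).audio_codes
    # (num_chunks, bsz, codebooks, frames) -> (bsz, frames, codebooks)
    codes = codes.squeeze(0).transpose(1, 2)
    # Pad every process's batch to the same frame count before gathering, then
    # collect the per-device shards back into one list in dataset order.
    codes = accelerator.pad_across_processes(codes, dim=1, pad_index=0)
    codes = accelerator.gather_for_metrics(codes)
    all_codes.extend(codes.cpu())

print(len(all_codes), all_codes[0].shape)

Launched with accelerate launch, each process encodes its own shard and the gathered codes end up identical on every rank, which is what the per-split encoding loop in this diff relies on.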
@@ -32,6 +32,8 @@ from typing import Dict, List, Optional, Union
 import datasets
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
 import transformers
@@ -43,13 +45,15 @@ from transformers import (
     HfArgumentParser,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
-    set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
-from accelerate import PartialState
+from accelerate import Accelerator
+from accelerate.utils import set_seed
 from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
@@ -214,10 +218,6 @@ class DataSeq2SeqTrainingArguments:
         default="audio",
         metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
     )
-    conditional_audio_column_name: str = field(  # TODO
-        default=None,
-        metadata={"help": "The name of the dataset column containing the conditional audio data. Defaults to 'audio'"},
-    )
     description_column_name: str = field(  # TODO
         default=None,
         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
@@ -311,6 +311,29 @@ class DataSeq2SeqTrainingArguments:
             "help": "id column name."
         }
     )
+
+
+@dataclass
+class DataCollatorEncodecWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
+    """
+
+    feature_extractor: AutoFeatureExtractor
+    feature_extractor_input_name: Optional[str] = "input_values"
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # split inputs and labels since they have to be of different lengths and need
+        # different padding methods
+        audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
+        len_audio = [len(audio) for audio in audios]
+        max_audio = max(len_audio)
+
+        input_features = {self.feature_extractor_input_name: audios}
+        batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
+        batch["padding_mask"] = batch.pop("attention_mask")
+        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
+        return batch


 @dataclass
@@ -437,6 +460,7 @@ def convert_dataset_str_to_list(


 def load_multiple_datasets(
+    accelerator: Accelerator,
     dataset_names: Union[List, str],
     dataset_config_names: Union[List, str],
     metadata_dataset_names: Optional[str]=None,
@@ -463,51 +487,52 @@ def load_multiple_datasets(
     all_datasets = []
     # iterate over the datasets we want to interleave
     for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
-        dataset = load_dataset(
-            dataset_dict["name"],
-            dataset_dict["config"],
-            split=dataset_dict["split"],
-            streaming=streaming,
-            **kwargs,
-        )
-        dataset_features = dataset.features.keys()
-
-        metadata_dataset_name = dataset_dict["metadata_dataset_name"]
-        if metadata_dataset_name is not None:
-            metadata_dataset = load_dataset(
-                metadata_dataset_name,
-                dataset_dict["config"],
-                split=dataset_dict["split"],
-                streaming=streaming,
-                **kwargs,
-            )
-
-            if id_column_name is not None and id_column_name not in dataset.column_names:
-                raise ValueError(
-                    f"id_column_name={id_column_name} but has not been found in the dataset columns"
-                    f"- one of {', '.join(list(dataset.column_names))}."
-                )
-            if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
-                raise ValueError(
-                    f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
-                    f"- one of {', '.join(list(metadata_dataset.column_names))}."
-                )
-            elif id_column_name is not None:
-                metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-
-            metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
-            metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
-            dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
-
-            if id_column_name is not None:
-                if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
-                    raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
-
-        dataset_features = dataset.features.keys()
+        with accelerator.main_process_first():
+            dataset = load_dataset(
+                dataset_dict["name"],
+                dataset_dict["config"],
+                split=dataset_dict["split"],
+                streaming=streaming,
+                **kwargs,
+            )
+            dataset_features = dataset.features.keys()
+
+            metadata_dataset_name = dataset_dict["metadata_dataset_name"]
+            if metadata_dataset_name is not None:
+                metadata_dataset = load_dataset(
+                    metadata_dataset_name,
+                    dataset_dict["config"],
+                    split=dataset_dict["split"],
+                    streaming=streaming,
+                    **kwargs,
+                )
+
+                if id_column_name is not None and id_column_name not in dataset.column_names:
+                    raise ValueError(
+                        f"id_column_name={id_column_name} but has not been found in the dataset columns"
+                        f"- one of {', '.join(list(dataset.column_names))}."
+                    )
+                if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
+                    raise ValueError(
+                        f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
+                        f"- one of {', '.join(list(metadata_dataset.column_names))}."
+                    )
+                elif id_column_name is not None:
+                    metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
+
+                metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
+                metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
+                dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
+
+                if id_column_name is not None:
+                    if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
+                        raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
+
+            dataset_features = dataset.features.keys()

         if columns_to_keep is not None:
             dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
         all_datasets.append(dataset)

     if len(all_datasets) == 1:
@@ -522,7 +547,8 @@ def load_multiple_datasets(
             seed=seed,
         )
     else:
-        interleaved_dataset = concatenate_datasets(all_datasets)
+        with accelerator.main_process_first():
+            interleaved_dataset = concatenate_datasets(all_datasets)

     return interleaved_dataset
@@ -544,6 +570,8 @@ def main():
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_stable_speech", model_args, data_args)
+
+    accelerator = Accelerator()

     # Detecting last checkpoint.
     last_checkpoint = None
@@ -566,7 +594,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
@@ -574,8 +602,9 @@ def main():
         f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if accelerator.is_main_process:
         transformers.utils.logging.set_verbosity_info()
     logger.info("Training/evaluation parameters %s", training_args)

     # Set seed before initializing model.
@@ -585,71 +614,51 @@ def main():
     raw_datasets = DatasetDict()

     num_workers = data_args.preprocessing_num_workers

-    columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
+    columns_to_keep = {
+        "target_audio_column_name": data_args.target_audio_column_name,
+        "prompt_column_name": data_args.prompt_column_name
+    }
     if data_args.description_column_name is not None:
-        columns_to_keep.append(data_args.description_column_name)
-    if data_args.conditional_audio_column_name is not None:
-        columns_to_keep.append(data_args.conditional_audio_column_name)
+        columns_to_keep["description_column_name"] = data_args.description_column_name

     if training_args.do_train:
         raw_datasets["train"] = load_multiple_datasets(
+            accelerator,
             data_args.train_dataset_name,
             data_args.train_dataset_config_name,
-            data_args.train_metadata_dataset_name,
+            metadata_dataset_names=data_args.train_metadata_dataset_name,
             splits=data_args.train_split_name,
             dataset_samples=data_args.train_dataset_samples,
             seed=training_args.seed,
             cache_dir=model_args.cache_dir,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )

-        if data_args.target_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--target_audio_column_name '{data_args.target_audio_column_name}' not found in dataset '{data_args.train_dataset_name}'."
-                " Make sure to set `--target_audio_column_name` to the correct audio column - one of"
-                f" {', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.description_column_name is not None and data_args.description_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.description_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--description_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.prompt_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.prompt_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--prompt_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.conditional_audio_column_name is not None and data_args.conditional_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--conditional_audio_column_name {data_args.conditional_audio_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--conditional_audio_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
+        for key in columns_to_keep:
+            if columns_to_keep[key] not in raw_datasets["train"].column_names:
+                raise ValueError(
+                    f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
+                    f" Make sure to set `--{key}` to the correct audio column - one of"
+                    f" {', '.join(raw_datasets['train'].column_names)}."
+                )

         if data_args.max_train_samples is not None:
             raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

     if training_args.do_eval:
         raw_datasets["eval"] = load_multiple_datasets(
+            accelerator,
             data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
-            data_args.eval_dataset_config_name
-            if data_args.eval_dataset_config_name
-            else data_args.train_dataset_config_name,
-            data_args.eval_metadata_dataset_name,
+            data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
+            metadata_dataset_names=data_args.eval_metadata_dataset_name,
             splits=data_args.eval_split_name,
             cache_dir=model_args.cache_dir,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )
@@ -657,9 +666,8 @@ def main():
             raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

-    # TODO: is is the right way to do ?
-    # 3. Next, let's load the config as we might need it to create
-    # load config
+    # 2. Next, let's load the config as we might need it to create
+    # load config TODO: add the option to create the config from scratch
     config = StableSpeechConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -673,8 +681,7 @@ def main():
         "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
     })

-    # 4. Now we can instantiate the feature extractor, tokenizers and model
+    # 3. Now we can instantiate the feature extractor, tokenizers and model
     # Note for distributed training, the .from_pretrained methods guarantee that only
     # one local process can concurrently download model & vocab.
@@ -692,16 +699,24 @@ def main():
         cache_dir=model_args.cache_dir,
         token=data_args.token,
         trust_remote_code=data_args.trust_remote_code,
+        use_fast=model_args.use_fast_tokenizer,
     )

     # load description tokenizer
     description_tokenizer = AutoTokenizer.from_pretrained(
         model_args.description_tokenizer_name or model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         token=data_args.token,
         trust_remote_code=data_args.trust_remote_code,
+        use_fast=model_args.use_fast_tokenizer,
     )

+    if model_args.use_fast_tokenizer:
+        logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
+        prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+        description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
     # create model + TODO: not from_pretrained probably
     model = StableSpeechForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
@@ -711,46 +726,31 @@ def main():
         trust_remote_code=data_args.trust_remote_code,
     )

-    # take audio_encoder_feature_extractor
-    audio_encoder_feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model.config.audio_encoder._name_or_path,
-    )
-
-    # 5. Now we preprocess the datasets including loading the audio, resampling and normalization
+    # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
     # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
     # so that we just need to set the correct target sampling rate and normalize the input
     # via the `feature_extractor`

-    # resample target audio
-    raw_datasets = raw_datasets.cast_column(
-        data_args.target_audio_column_name, datasets.features.Audio(sampling_rate=audio_encoder_feature_extractor.sampling_rate)
-    )
-
-    if data_args.conditional_audio_column_name is not None:
-        raw_datasets = raw_datasets.cast_column(
-            data_args.conditional_audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
-        )
-
     # derive max & min input length for sample rate & max duration
-    max_target_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
-    min_target_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
+    sampling_rate = feature_extractor.sampling_rate
+    max_target_length = data_args.max_duration_in_seconds * sampling_rate
+    min_target_length = data_args.min_duration_in_seconds * sampling_rate
     target_audio_column_name = data_args.target_audio_column_name
-    conditional_audio_column_name = data_args.conditional_audio_column_name
     description_column_name = data_args.description_column_name
     prompt_column_name = data_args.prompt_column_name
     feature_extractor_input_name = feature_extractor.model_input_names[0]

+    # resample target audio
+    raw_datasets = raw_datasets.cast_column(
+        target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+    )
+
     # Preprocessing the datasets.
-    # We need to read the audio files as arrays and tokenize the targets.
+    # We need to read the audio files as arrays and tokenize the texts.
     def pass_through_processors(batch):
         # load audio
-        if conditional_audio_column_name is not None:
-            sample = batch[target_audio_column_name]
-            inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-            batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
-
         if description_column_name is not None:
             text = batch[description_column_name]
             batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
@@ -761,14 +761,14 @@ def main():

         # load audio
         target_sample = batch[target_audio_column_name]
-        labels = audio_encoder_feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
+        labels = feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
         batch["labels"] = labels["input_values"]

         # take length of raw audio waveform
         batch["target_length"] = len(target_sample["array"].squeeze())
         return batch

-    with training_args.main_process_first(desc="dataset map preprocessing"):
+    with accelerator.main_process_first():
         vectorized_datasets = raw_datasets.map(
             pass_through_processors,
             remove_columns=next(iter(raw_datasets.values())).column_names,
@@ -785,34 +785,81 @@ def main():
             num_proc=num_workers,
             input_columns=["target_length"],
         )

+    # 5. Now we encode the audio labels with encodec.
+    # We use Accelerate to perform distributed inference
+    logger.info("*** Encode target audio with encodec ***")
+
+    # no need to prepare audio_decoder because used for inference without mixed precision
+    # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
+    # TODO: load another model
     audio_decoder = model.audio_encoder
+    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)

     def apply_audio_decoder(batch):
-        labels = audio_decoder.encode(torch.tensor(batch["labels"]).to(audio_decoder.device))["audio_codes"]
-        labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
-                                                                            model.generation_config.decoder_start_token_id,
-                                                                            model.generation_config.max_length + 1)
-
-        labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
-
-        # the first timestamp is associated to a row full of BOS, let's get rid of it
-        batch["labels"] = labels[:, 1:]
-        return batch
-
-    with training_args.main_process_first(desc="audio target preprocessing"):
-        # for now on CPU
-        # TODO: enrich for GPU
-        vectorized_datasets = vectorized_datasets.map(
-            apply_audio_decoder,
-            num_proc=num_workers,
-            desc="preprocess datasets",
-        )
-
-    # TODO: will it work on GPU ? unmerged for now https://github.com/huggingface/accelerate/pull/2433
-    # for split in vectorized_datasets:
-    #     with distributed_state.split_between_processes(vectorized_datasets[split]["labels"]) as input_labels:
-    #         result = audio_decoder(input_labels)
+        len_audio = batch.pop("len_audio")
+        audio_decoder.to(batch["input_values"].device).eval()
+        labels = audio_decoder.encode(**batch)["audio_codes"]
+        output = {}
+        output["len_audio"] = len_audio
+        # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
+        output["labels"] = labels.squeeze(0).transpose(1,2)
+        output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
+        return output
+
+    for split in vectorized_datasets:
+        data_loader = DataLoader(
+            vectorized_datasets[split],
+            batch_size=training_args.per_device_eval_batch_size,
+            collate_fn=encoder_data_collator,
+            num_workers=training_args.dataloader_num_workers,
+            pin_memory=True,
+        )
+        data_loader = accelerator.prepare(data_loader)
+
+        all_generated_labels = []
+        all_ratios = []
+        all_lens = []
+        for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
+            generate_labels = apply_audio_decoder(batch)
+            generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
+            generate_labels = accelerator.gather_for_metrics(generate_labels)
+
+            all_generated_labels.extend(generate_labels["labels"].cpu())
+            all_ratios.extend(generate_labels["ratio"].cpu())
+            all_lens.extend(generate_labels["len_audio"].cpu())
+
+        def postprocess_dataset(sample, idx):
+            # (1, seq_len, codebooks, bsz)
+            labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
+            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
+                                                                                model.generation_config.decoder_start_token_id,
+                                                                                model.generation_config.max_length + model.decoder.config.num_codebooks)
+
+            labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
+
+            len_ = int(all_ratios[idx] * all_lens[idx])
+            # the first timestamp is associated to a row full of BOS, let's get rid of it
+            sample["labels"] = labels[:, 1:len_]
+            return sample
+
+        # TODO: done multiple times, how to deal with it.
+        with accelerator.main_process_first():
+            vectorized_datasets[split] = vectorized_datasets[split].map(
+                postprocess_dataset,
+                num_proc=num_workers,
+                desc="Postprocessing labeling",
+                with_indices=True,
+            )
+
+    accelerator.free_memory()
+    del generate_labels

     if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
         if is_wandb_available():
@@ -827,6 +874,7 @@ def main():
     # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
     # cached dataset
     if data_args.preprocessing_only:
+        # TODO: save to disk in this step instead of something else ??
        logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
@@ -865,9 +913,9 @@ def main():

     # Now save everything to be able to create a single processor later
     # make sure all processes wait until data is saved
-    with training_args.main_process_first():
+    with accelerator.main_process_first():
         # only the main process saves them
-        if is_main_process(training_args.local_rank):
+        if accelerator.is_main_process:
             # save feature extractor, tokenizer and config
             if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name==model_args.description_tokenizer_name):
                 prompt_tokenizer.save_pretrained(training_args.output_dir)
@@ -936,7 +984,7 @@ def main():
             audios = predictions["audio"]

             # log the table to wandb
-            self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=audio_encoder_feature_extractor.sampling_rate) for (audio, text) in zip(audios, texts)]})
+            self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=sampling_rate) for (audio, text) in zip(audios, texts)]})