Commit 5e2041eb authored by yoach@huggingface.co

Make audio encoding smarter in terms of RAM usage

parent 84e0def5
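In broad strokes, the commit stops holding every intermediate audio label in memory on every rank: gathered outputs are kept only on the main process, written to disk as a temporary datasets.Dataset, then reloaded (memory-mapped) and concatenated back onto the main dataset. The sketch below only illustrates that pattern; the names offload_labels_to_disk, encoded_labels and tmp_dir are hypothetical and not part of the commit.

# Minimal sketch (not from the commit) of the disk-offload pattern this change adopts.
from datasets import Dataset, concatenate_datasets, load_from_disk

def offload_labels_to_disk(encoded_labels, ratios, lengths, tmp_dir, is_main_process=True):
    if is_main_process:
        # materialize the gathered labels once, on a single process
        tmp = Dataset.from_dict(
            {"labels": encoded_labels, "ratios": ratios, "target_length": lengths}
        )
        tmp.save_to_disk(tmp_dir)
    # (in a distributed run, a barrier such as accelerator.wait_for_everyone()
    # belongs between the save and the load, as the script does)
    # every process then reloads the same dataset, memory-mapped from disk,
    # so the caller can `del` the in-memory lists right after this call
    return load_from_disk(tmp_dir)

# usage: append the reloaded label columns to an existing dataset
# merged = concatenate_datasets([base_dataset, offload_labels_to_disk(...)], axis=1)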
@@ -444,6 +444,12 @@ class DataTrainingArguments:
"help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
}
)
temporary_save_to_disk: str = field(
default=None,
metadata={
"help": "Temporarily save audio labels here."
}
)
pad_to_multiple_of: Optional[int] = field(
default=2,
metadata={
@@ -490,6 +496,7 @@ class DataCollatorEncodecWithPadding:
feature_extractor: AutoFeatureExtractor
audio_column_name: str
feature_extractor_input_name: Optional[str] = "input_values"
max_length: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
@@ -497,6 +504,8 @@ class DataCollatorEncodecWithPadding:
# different padding methods
audios = [feature[self.audio_column_name]["array"] for feature in features]
len_audio = [len(audio) for audio in audios]
if self.max_length is not None:
audios = [audio[:min(len(audio), self.max_length + 10)] for audio in audios]
batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
@@ -1030,14 +1039,7 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# TODO: remove when releasing
# Test all gather - used for warmup and avoiding timeout
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)
accelerator.wait_for_everyone()
if not dataset_was_precomputed:
# Filter on text length
if description_column_name is not None:
@@ -1049,53 +1051,28 @@ def main():
input_columns=[description_column_name],
)
# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the texts.
def pass_through_processors(batch):
# load audio
if description_column_name is not None:
text = batch[description_column_name]
batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
if prompt_column_name is not None:
text = batch[prompt_column_name]
batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
# Preprocessing the dataset.
# We need to tokenize the texts.
def pass_through_processors(description, prompt):
batch = {}
batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
# TODO: add possibility to train without description column
batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
# take length of raw audio waveform
batch["target_length"] = len(batch[target_audio_column_name]["array"].squeeze())
return batch
with accelerator.main_process_first():
# this is a trick to avoid rewriting the entire audio column, which takes ages
tmp_datasets = raw_datasets.map(
vectorized_datasets = raw_datasets.map(
pass_through_processors,
remove_columns=next(iter(raw_datasets.values())).column_names,
input_columns=[description_column_name, prompt_column_name],
num_proc=num_workers,
desc="preprocess datasets",
# cache_file_names={"train": "/scratch/train.arrow", "eval":"/scratch/eval.arrow"} , # TODO: remove - specific to cluster
)
# only keep audio column from the raw datasets
# this is a trick to avoid rewriting the entire audio column, which takes ages
cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
for split in raw_datasets:
raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
with accelerator.main_process_first():
def is_audio_in_length_range(length):
return length > min_target_length and length < max_target_length
# filter data that is shorter than min_target_length
vectorized_datasets = raw_datasets.filter(
is_audio_in_length_range,
num_proc=num_workers,
input_columns=["target_length"],
)
# We use Accelerate to perform distributed inference
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
@@ -1118,13 +1095,15 @@ def main():
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
model.text_encoder.to(batch["input_ids"].device)
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
with torch.no_grad():
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
if accelerator.is_main_process:
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
def postprocess_dataset(input_ids, idx):
output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
@@ -1140,6 +1119,7 @@ def main():
with_indices=True,
writer_batch_size=100,
)
accelerator.wait_for_everyone()
accelerator.free_memory()
del data_loader, all_encoder_outputs, all_encoder_lengths
@@ -1153,7 +1133,7 @@ def main():
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name)
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length)
def apply_audio_decoder(batch):
len_audio = batch.pop("len_audio")
@@ -1169,7 +1149,7 @@ def main():
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
raw_datasets[split],
batch_size=training_args.audio_encode_per_device_eval_batch_size,
collate_fn=encoder_data_collator,
num_workers=training_args.dataloader_num_workers,
@@ -1185,22 +1165,31 @@ def main():
generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
generate_labels = accelerator.gather_for_metrics(generate_labels)
all_generated_labels.extend(generate_labels["labels"].cpu())
all_ratios.extend(generate_labels["ratio"].cpu())
all_lens.extend(generate_labels["len_audio"].cpu())
if accelerator.is_main_process:
all_generated_labels.extend(generate_labels["labels"].cpu())
all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
all_lens.extend(generate_labels["len_audio"].cpu().squeeze())
# (1, codebooks, seq_len) where seq_len=1
eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
def postprocess_dataset(input_ids, prompt_input_ids, idx):
if accelerator.is_main_process:
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
accelerator.wait_for_everyone()
del all_generated_labels
tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
with accelerator.main_process_first():
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
def postprocess_dataset(labels, target_length, ratio):
# (1, codebooks, seq_len)
labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
len_ = int(all_ratios[idx] * all_lens[idx])
labels = torch.tensor(labels).transpose(0,1).unsqueeze(0)
len_ = int(ratio * target_length)
labels = labels[:, :, :len_]
# labels = labels[:, :, :(len_)%10+500] # TODO: change
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
@@ -1210,7 +1199,6 @@ def main():
max_length=labels.shape[-1] + num_codebooks,
num_codebooks=num_codebooks)
# the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
# to take care of EOS
# we want labels to look like this:
@@ -1223,29 +1211,38 @@ def main():
# the first timestamp is associated with a row full of BOS, let's get rid of it
# we also remove the last timestamps (full of PAD)
output = {"labels": labels[:, 1:].cpu()}
output["input_ids"] = input_ids
output["prompt_input_ids"] = prompt_input_ids
return output
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=1, # this one is resource-consuming with many processes.
input_columns=["input_ids", "prompt_input_ids"],
num_proc=data_args.preprocessing_num_workers, # this one is resource-consuming with many processes.
input_columns=["labels", "target_length", "ratios"],
remove_columns=["ratios"],
desc="Postprocessing labeling",
with_indices=True,
writer_batch_size=100,
)
accelerator.free_memory()
del generate_labels, all_generated_labels, all_lens, all_ratios
del generate_labels, all_lens, all_ratios
with accelerator.main_process_first():
def is_audio_in_length_range(length):
return length > min_target_length and length < max_target_length
# filter data that is shorter than min_target_length
vectorized_datasets = vectorized_datasets.filter(
is_audio_in_length_range,
num_proc=num_workers,
input_columns=["target_length"],
)
if data_args.save_to_disk is not None and not dataset_was_precomputed:
vectorized_datasets.save_to_disk(data_args.save_to_disk)
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
logger.info(f"Dataset saved at {data_args.save_to_disk}")
# for large datasets it is advised to run the preprocessing on a
......
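Beyond offloading labels to disk, two smaller RAM savings run through the hunks above: the frozen text encoder is called under torch.no_grad(), and gathered tensors are accumulated only on the main process. Below is a minimal sketch of that combination; the encode_batch helper is hypothetical, and the encoder and accelerator objects are assumed to be set up as in the script.

import torch

def encode_batch(encoder, batch, accelerator):
    # frozen encoder: no_grad avoids keeping activations for backprop
    with torch.no_grad():
        outputs = encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    # every rank takes part in the gather, but only the main process keeps the result
    gathered = accelerator.gather_for_metrics(outputs.last_hidden_state)
    return gathered.cpu() if accelerator.is_main_process else None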