Commit 5e2041eb authored by yoach@huggingface.co

Make audio encoding smarter in terms of RAM usage

parent 84e0def5
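The central pattern of this commit, visible throughout the diff below, is to keep gathered encoder outputs only on the main process and to stage them on disk rather than letting every rank hold the full label set in RAM. The following is a minimal, self-contained sketch of that pattern under stated assumptions: the toy dataloader, the encode_batch stand-in, and the "labels_tmp" path are placeholders for illustration, not names from the actual training script.

import torch
from accelerate import Accelerator
from datasets import Dataset, load_from_disk

accelerator = Accelerator()

# toy dataloader standing in for the script's audio/text batches (hypothetical)
dataloader = torch.utils.data.DataLoader(
    torch.arange(32, dtype=torch.float32).reshape(8, 4), batch_size=2
)

def encode_batch(batch):
    # stand-in for the expensive encoder forward pass
    with torch.no_grad():
        return batch.to(accelerator.device) * 2

all_labels = []
for batch in dataloader:
    labels = encode_batch(batch)
    labels = accelerator.pad_across_processes(labels, dim=1, pad_index=0)
    labels = accelerator.gather_for_metrics(labels)
    # only rank 0 accumulates the gathered results, so the other ranks
    # no longer each keep a full copy of the label set in CPU RAM
    if accelerator.is_main_process:
        all_labels.extend(labels.cpu())

# rank 0 spills the accumulated labels to disk as an Arrow dataset ...
if accelerator.is_main_process:
    Dataset.from_dict({"labels": [x.tolist() for x in all_labels]}).save_to_disk("labels_tmp")
accelerator.wait_for_everyone()
del all_labels

# ... and every rank reloads them memory-mapped instead of holding them in memory
tmp_labels = load_from_disk("labels_tmp")
print(tmp_labels)

This mirrors the diff's is_main_process guards around the extend calls and the new temporary_save_to_disk staging step, without reproducing the script's full argument plumbing.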
@@ -444,6 +444,12 @@ class DataTrainingArguments:
             "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
         }
     )
+    temporary_save_to_disk: str = field(
+        default=None,
+        metadata={
+            "help": "Temporarily save audio labels here."
+        }
+    )
     pad_to_multiple_of: Optional[int] = field(
         default=2,
         metadata={
@@ -490,6 +496,7 @@ class DataCollatorEncodecWithPadding:
     feature_extractor: AutoFeatureExtractor
     audio_column_name: str
     feature_extractor_input_name: Optional[str] = "input_values"
+    max_length: Optional[int] = None
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
@@ -497,6 +504,8 @@ class DataCollatorEncodecWithPadding:
         # different padding methods
         audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
+        if self.max_length is not None:
+            audios = [audio[:min(len(audio), self.max_length + 10)] for audio in audios]
         batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
@@ -1030,14 +1039,7 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)
-    # TODO: remove when releasing
-    # Test all gather - used for warmout and avoiding timeout
-    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
-    gathered_tensor = accelerator.gather(test_tensor)
-    print("gathered_tensor", gathered_tensor)
-    accelerator.wait_for_everyone()
     if not dataset_was_precomputed:
         # Filter on text length
         if description_column_name is not None:
@@ -1049,53 +1051,28 @@ def main():
                 input_columns=[description_column_name],
             )
-        # Preprocessing the datasets.
-        # We need to read the audio files as arrays and tokenize the texts.
-        def pass_through_processors(batch):
-            # load audio
-            if description_column_name is not None:
-                text = batch[description_column_name]
-                batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
-            if prompt_column_name is not None:
-                text = batch[prompt_column_name]
-                batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
-            # take length of raw audio waveform
-            batch["target_length"] = len(batch[target_audio_column_name]["array"].squeeze())
+        # Preprocessing the dataset.
+        # We need to tokenize the texts.
+        def pass_through_processors(description, prompt):
+            batch = {}
+            batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
+            # TODO: add possibility to train without description column
+            batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
             return batch
         with accelerator.main_process_first():
             # this is a trick to avoid to rewrite the entire audio column which takes ages
-            tmp_datasets = raw_datasets.map(
+            vectorized_datasets = raw_datasets.map(
                 pass_through_processors,
                 remove_columns=next(iter(raw_datasets.values())).column_names,
+                input_columns=[description_column_name, prompt_column_name],
                 num_proc=num_workers,
                 desc="preprocess datasets",
-                # cache_file_names={"train": "/scratch/train.arrow", "eval":"/scratch/eval.arrow"} , # TODO: remove - specific to cluster
-            )
-            # only keep audio column from the raw datasets
-            # this is a trick to avoid to rewrite the entire audio column which takes ages
-            cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
-            for split in raw_datasets:
-                raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
-        with accelerator.main_process_first():
-            def is_audio_in_length_range(length):
-                return length > min_target_length and length < max_target_length
-            # filter data that is shorter than min_target_length
-            vectorized_datasets = raw_datasets.filter(
-                is_audio_in_length_range,
-                num_proc=num_workers,
-                input_columns=["target_length"],
             )
     # We use Accelerate to perform distributed inference
     # T5 doesn't support fp16
     autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
@@ -1118,13 +1095,15 @@ def main():
         for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
             model.text_encoder.to(batch["input_ids"].device)
             with accelerator.autocast(autocast_handler=autocast_kwargs):
-                encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
+                with torch.no_grad():
+                    encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
             encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
             encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
             lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
-            all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
-            all_encoder_lengths.extend(lengths.to("cpu"))
+            if accelerator.is_main_process:
+                all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
+                all_encoder_lengths.extend(lengths.to("cpu"))
         def postprocess_dataset(input_ids, idx):
             output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
@@ -1140,6 +1119,7 @@ def main():
                 with_indices=True,
                 writer_batch_size=100,
             )
+        accelerator.wait_for_everyone()
         accelerator.free_memory()
         del data_loader, all_encoder_outputs, all_encoder_lengths
@@ -1153,7 +1133,7 @@ def main():
         # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
         audio_decoder = model.audio_encoder
-        encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name)
+        encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length)
         def apply_audio_decoder(batch):
             len_audio = batch.pop("len_audio")
@@ -1169,7 +1149,7 @@ def main():
         for split in vectorized_datasets:
             data_loader = DataLoader(
-                vectorized_datasets[split],
+                raw_datasets[split],
                 batch_size=training_args.audio_encode_per_device_eval_batch_size,
                 collate_fn=encoder_data_collator,
                 num_workers=training_args.dataloader_num_workers,
@@ -1185,22 +1165,31 @@ def main():
                 generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
                 generate_labels = accelerator.gather_for_metrics(generate_labels)
-                all_generated_labels.extend(generate_labels["labels"].cpu())
-                all_ratios.extend(generate_labels["ratio"].cpu())
-                all_lens.extend(generate_labels["len_audio"].cpu())
+                if accelerator.is_main_process:
+                    all_generated_labels.extend(generate_labels["labels"].cpu())
+                    all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
+                    all_lens.extend(generate_labels["len_audio"].cpu().squeeze())
             # (1, codebooks, seq_len) where seq_len=1
-            eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
             bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
-            def postprocess_dataset(input_ids, prompt_input_ids, idx):
+            if accelerator.is_main_process:
+                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
+                tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            accelerator.wait_for_everyone()
+            del all_generated_labels
+            tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
+            with accelerator.main_process_first():
+                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
+            def postprocess_dataset(labels, target_length, ratio):
                 # (1, codebooks, seq_len)
-                labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
-                len_ = int(all_ratios[idx] * all_lens[idx])
+                labels = torch.tensor(labels).transpose(0,1).unsqueeze(0)
+                len_ = int(ratio * target_length)
                 labels = labels[:, :, :len_]
-                # labels = labels[:, :, :(len_)%10+500] # TODO: change
                 # add bos
                 labels = torch.cat([bos_labels, labels], dim=-1)
@@ -1210,7 +1199,6 @@ def main():
                     max_length=labels.shape[-1] + num_codebooks,
                     num_codebooks=num_codebooks)
                 # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
                 # to take care of EOS
                 # we want labels to look like this:
@@ -1223,29 +1211,38 @@ def main():
                 # the first timestamp is associated to a row full of BOS, let's get rid of it
                 # we also remove the last timestampts (full of PAD)
                 output = {"labels": labels[:, 1:].cpu()}
-                output["input_ids"] = input_ids
-                output["prompt_input_ids"] = prompt_input_ids
                 return output
             # TODO(YL): done multiple times, how to deal with it.
             with accelerator.main_process_first():
                 vectorized_datasets[split] = vectorized_datasets[split].map(
                     postprocess_dataset,
-                    num_proc=1, # this one is resource consuming if many processor.
-                    input_columns=["input_ids", "prompt_input_ids"],
+                    num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor.
+                    input_columns=["labels", "target_length", "ratios"],
+                    remove_columns=["ratios"],
                     desc="Postprocessing labeling",
-                    with_indices=True,
-                    writer_batch_size=100,
                 )
             accelerator.free_memory()
-            del generate_labels, all_generated_labels, all_lens, all_ratios
+            del generate_labels, all_lens, all_ratios
+        with accelerator.main_process_first():
+            def is_audio_in_length_range(length):
+                return length > min_target_length and length < max_target_length
+            # filter data that is shorter than min_target_length
+            vectorized_datasets = vectorized_datasets.filter(
+                is_audio_in_length_range,
+                num_proc=num_workers,
+                input_columns=["target_length"],
+            )
     if data_args.save_to_disk is not None and not dataset_was_precomputed:
-        vectorized_datasets.save_to_disk(data_args.save_to_disk)
+        if accelerator.is_main_process:
+            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
         logger.info(f"Dataset saved at {data_args.save_to_disk}")
     # for large datasets it is advised to run the preprocessing on a
......
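For reference, a small hedged sketch of the second trick the diff relies on: joining the reloaded label dataset back onto the tokenized dataset column-wise with concatenate_datasets(..., axis=1), then filtering on the raw audio length, as the new is_audio_in_length_range filter does. The column contents and length bounds below are toy values chosen for illustration, not the script's real configuration.

from datasets import Dataset, concatenate_datasets

# Toy stand-ins: tokenized text features on one side, precomputed audio
# labels (same row order, same number of rows) on the other.
text_ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4, 5], [6]]})
label_ds = Dataset.from_dict({
    "labels": [[[7, 8]], [[9]], [[10, 11, 12]]],
    "target_length": [32000, 4000, 160000],
})

# axis=1 joins the two datasets column-wise (row i with row i), so the heavy
# label column is attached without rewriting the existing text columns.
joined = concatenate_datasets([text_ds, label_ds], axis=1)

# Drop rows whose raw audio length falls outside the configured bounds,
# mirroring the is_audio_in_length_range filter added at the end of the diff.
min_target_length, max_target_length = 8000, 100000
joined = joined.filter(
    lambda length: min_target_length < length < max_target_length,
    input_columns=["target_length"],
)
print(joined)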