Commit 75ae54a8 authored by yoach@huggingface.co's avatar yoach@huggingface.co
Browse files

improve pre-processing logics

parent 5e2041eb
......@@ -1039,7 +1039,14 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# TODO: remove when releasing
# Test all gather - used for warmout and avoiding timeout
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)
accelerator.wait_for_everyone()
if not dataset_was_precomputed:
# Filter on text length
if description_column_name is not None:
......@@ -1158,7 +1165,6 @@ def main():
data_loader = accelerator.prepare(data_loader)
all_generated_labels = []
all_ratios = []
all_lens = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
generate_labels = apply_audio_decoder(batch)
......@@ -1166,30 +1172,31 @@ def main():
generate_labels = accelerator.gather_for_metrics(generate_labels)
if accelerator.is_main_process:
all_generated_labels.extend(generate_labels["labels"].cpu())
all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
all_lens.extend(generate_labels["len_audio"].cpu().squeeze())
lab = generate_labels["labels"].cpu().transpose(1,2).to(torch.int16)
rat = generate_labels["ratio"].cpu().squeeze()
lens = generate_labels["len_audio"].cpu().squeeze()
lab = [l[:, :int(ratio*length)] for (l, ratio, length) in zip(lab, rat, lens)]
all_generated_labels.extend(lab)
all_lens.extend(lens)
# (1, codebooks, seq_len) where seq_len=1
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
if accelerator.is_main_process:
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
tmp_labels.save_to_disk(os.path.join(data_args.temporary_save_to_disk, split), num_proc=1 if split == "eval" else data_args.preprocessing_num_workers)
accelerator.wait_for_everyone()
del all_generated_labels
tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
with accelerator.main_process_first():
vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
def postprocess_dataset(labels, target_length, ratio):
def postprocess_dataset(labels):
# (1, codebooks, seq_len)
labels = torch.tensor(labels).transpose(0,1).unsqueeze(0)
len_ = int(ratio * target_length)
labels = labels[:, :, :len_]
labels = torch.tensor(labels).unsqueeze(0)
# add bos
labels = torch.cat([bos_labels, labels], dim=-1)
......@@ -1210,7 +1217,7 @@ def main():
# the first timestamp is associated to a row full of BOS, let's get rid of it
# we also remove the last timestampts (full of PAD)
output = {"labels": labels[:, 1:].cpu()}
output = {"labels": labels[:, 1:]}
return output
......@@ -1219,14 +1226,13 @@ def main():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=data_args.preprocessing_num_workers, # this one is resource consuming if many processor.
input_columns=["labels", "target_length", "ratios"],
remove_columns=["ratios"],
input_columns=["labels"],
desc="Postprocessing labeling",
)
accelerator.free_memory()
del generate_labels, all_lens, all_ratios
del generate_labels, all_lens
with accelerator.main_process_first():
......@@ -1242,7 +1248,7 @@ def main():
if data_args.save_to_disk is not None and not dataset_was_precomputed:
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"])-1))
logger.info(f"Dataset saved at {data_args.save_to_disk}")
# for large datasets it is advised to run the preprocessing on a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment