Commit 75ae54a8 authored by yoach@huggingface.co

improve pre-processing logic

parent 5e2041eb
@@ -1039,7 +1039,14 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)
 
+    # TODO: remove when releasing
+    # Test all gather - used for warmup and avoiding timeout
+    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
+    gathered_tensor = accelerator.gather(test_tensor)
+    print("gathered_tensor", gathered_tensor)
+    accelerator.wait_for_everyone()
+
     if not dataset_was_precomputed:
         # Filter on text length
         if description_column_name is not None:
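Note on the added block: a throwaway all-gather like this is a common Accelerate warm-up pattern. It forces every rank to initialize its communication backend and acts as a barrier before the long, main-process-heavy preprocessing below, which could otherwise leave idle ranks hitting an NCCL timeout. A minimal self-contained sketch of the same pattern (it uses nothing beyond what the diff itself calls):

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank contributes a one-element tensor identifying itself.
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)

# gather() concatenates across ranks; launched on 4 GPUs this should print
# tensor([0, 1, 2, 3]) on every process.
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)

# Barrier: no rank continues until all ranks have reached this point.
accelerator.wait_for_everyone()

Run it with `accelerate launch warmup_check.py` (hypothetical file name); on a single process, gather() is simply the identity.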
@@ -1158,7 +1165,6 @@ def main():
         data_loader = accelerator.prepare(data_loader)
 
         all_generated_labels = []
-        all_ratios = []
         all_lens = []
         for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
             generate_labels = apply_audio_decoder(batch)
@@ -1166,30 +1172,31 @@ def main():
             generate_labels = accelerator.gather_for_metrics(generate_labels)
 
             if accelerator.is_main_process:
-                all_generated_labels.extend(generate_labels["labels"].cpu())
-                all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
-                all_lens.extend(generate_labels["len_audio"].cpu().squeeze())
+                lab = generate_labels["labels"].cpu().transpose(1,2).to(torch.int16)
+                rat = generate_labels["ratio"].cpu().squeeze()
+                lens = generate_labels["len_audio"].cpu().squeeze()
+                lab = [l[:, :int(ratio*length)] for (l, ratio, length) in zip(lab, rat, lens)]
+
+                all_generated_labels.extend(lab)
+                all_lens.extend(lens)
 
         # (1, codebooks, seq_len) where seq_len=1
         bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
 
         if accelerator.is_main_process:
-            tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
-            tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
+            tmp_labels.save_to_disk(os.path.join(data_args.temporary_save_to_disk, split), num_proc=1 if split == "eval" else data_args.preprocessing_num_workers)
         accelerator.wait_for_everyone()
         del all_generated_labels
 
-        tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
+        tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
         with accelerator.main_process_first():
             vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
 
-        def postprocess_dataset(labels, target_length, ratio):
+        def postprocess_dataset(labels):
             # (1, codebooks, seq_len)
-            labels = torch.tensor(labels).transpose(0,1).unsqueeze(0)
-            len_ = int(ratio * target_length)
-            labels = labels[:, :, :len_]
+            labels = torch.tensor(labels).unsqueeze(0)
 
             # add bos
             labels = torch.cat([bos_labels, labels], dim=-1)
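The substantive change in this hunk: labels used to be stored fully padded alongside `ratios` and `target_length` columns and truncated later in `postprocess_dataset`; they are now truncated to their true audio length per sample right after the gather, and downcast to int16, so padding never reaches the temporary dataset (which is also now saved under a per-`split` subdirectory). A sketch of the truncation step with made-up shapes (variable names mirror the diff; the sizes are illustrative):

import torch

# Hypothetical batch, mirroring the diff: the audio encoder returns padded
# labels of shape (batch, seq_len, codebooks).
labels = torch.randint(0, 1024, (2, 100, 8))
ratio = torch.tensor([0.5, 0.9])      # fraction of each row that is real audio
len_audio = torch.tensor([100, 100])  # padded length per row

# As in the diff: move codebooks before time and shrink the dtype ...
lab = labels.transpose(1, 2).to(torch.int16)  # (batch, codebooks, seq_len)

# ... then cut each sample to its true length before storing it.
lab = [l[:, :int(r * n)] for (l, r, n) in zip(lab, ratio, len_audio)]
print([t.shape for t in lab])  # [torch.Size([8, 50]), torch.Size([8, 90])]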
@@ -1210,7 +1217,7 @@ def main():
             # the first timestamp is associated to a row full of BOS, let's get rid of it
             # we also remove the last timestamps (full of PAD)
-            output = {"labels": labels[:, 1:].cpu()}
+            output = {"labels": labels[:, 1:]}
             return output
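For context, the surrounding (unchanged) logic prepends one all-BOS frame across the codebooks, and this hunk's output then strips that first timestep again; per the comments, the elided code in between leaves the first row full of BOS and the trailing rows full of PAD. A tiny sketch of the BOS concatenation (shapes follow the diff's comments; `num_codebooks` and the token id are illustrative values):

import torch

num_codebooks = 8
audio_encoder_bos_token_id = 1025  # illustrative; the real id comes from the model config

# (1, codebooks, seq_len) where seq_len=1, exactly as in the diff
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id

labels = torch.randint(0, 1024, (1, num_codebooks, 50))  # one truncated sample
labels = torch.cat([bos_labels, labels], dim=-1)  # add bos
print(labels.shape)  # torch.Size([1, 8, 51])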
@@ -1219,14 +1226,13 @@ def main():
         vectorized_datasets[split] = vectorized_datasets[split].map(
             postprocess_dataset,
             num_proc=data_args.preprocessing_num_workers,  # this one is resource-consuming with many processors.
-            input_columns=["labels", "target_length", "ratios"],
-            remove_columns=["ratios"],
+            input_columns=["labels"],
             desc="Postprocessing labeling",
         )
 
     accelerator.free_memory()
-    del generate_labels, all_lens, all_ratios
+    del generate_labels, all_lens
 
     with accelerator.main_process_first():
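Because truncation now happens before saving, `postprocess_dataset` only needs `labels`, so the `map` call drops the extra inputs and the `remove_columns=["ratios"]` cleanup. For reference, `input_columns` in `datasets.Dataset.map` hands the function just the named columns as positional arguments instead of the whole example dict; a toy illustration (hypothetical data):

from datasets import Dataset

ds = Dataset.from_dict({"labels": [[1, 2], [3, 4]], "target_length": [2, 2]})

def postprocess(labels):
    # receives only the "labels" value, not the full example dict
    return {"labels": [x + 1 for x in labels]}

ds = ds.map(postprocess, input_columns=["labels"], desc="Postprocessing labeling")
print(ds["labels"])  # [[2, 3], [4, 5]]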
@@ -1242,7 +1248,7 @@ def main():
     if data_args.save_to_disk is not None and not dataset_was_precomputed:
         if accelerator.is_main_process:
-            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"])-1))
         logger.info(f"Dataset saved at {data_args.save_to_disk}")
 
     # for large datasets it is advised to run the preprocessing on a
...
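The last hunk guards `save_to_disk` against a small eval split: in `datasets`, `num_proc` controls how many shards (and writer processes) a split is divided into, and a split cannot be divided into more shards than it has rows, hence the `min(...)` cap. A minimal illustration of the same guard (toy dataset and sizes):

from datasets import Dataset

eval_ds = Dataset.from_dict({"labels": [[0]] * 4})  # a deliberately tiny eval split
preprocessing_num_workers = 32                      # illustrative worker count

# Cap the shard count the way the diff does, so it never exceeds what the
# smallest split can support.
num_proc = min(preprocessing_num_workers, len(eval_ds) - 1)
eval_ds.save_to_disk("/tmp/tmp_eval_labels", num_proc=num_proc)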