Commit dbb95132 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

improve mapping

parent d0140745
...@@ -1005,13 +1005,13 @@ def main(): ...@@ -1005,13 +1005,13 @@ def main():
eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
def postprocess_dataset(sample, idx): def postprocess_dataset(input_ids, prompt_input_ids, idx):
# (1, codebooks, seq_len) # (1, codebooks, seq_len)
labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0) labels = all_generated_labels[idx].transpose(0,1).unsqueeze(0)
len_ = int(all_ratios[idx] * all_lens[idx]) len_ = int(all_ratios[idx] * all_lens[idx])
labels = labels[:, :, :len_] labels = labels[:, :, :len_]
labels = labels[:, :, :(len_)%10+20] # TODO: change # labels = labels[:, :, :(len_)%10+20] # TODO: change
# add bos # add bos
labels = torch.cat([bos_labels, labels], dim=-1) labels = torch.cat([bos_labels, labels], dim=-1)
...@@ -1034,14 +1034,17 @@ def main(): ...@@ -1034,14 +1034,17 @@ def main():
# the first timestamp is associated to a row full of BOS, let's get rid of it # the first timestamp is associated to a row full of BOS, let's get rid of it
# we also remove the last timestampts (full of PAD) # we also remove the last timestampts (full of PAD)
sample["labels"] = labels[:, 1:].cpu() output = {"labels": labels[:, 1:].cpu()}
return sample output["input_ids"] = input_ids
output["prompt_input_ids"] = prompt_input_ids
return output
# TODO: done multiple times, how to deal with it. # TODO: done multiple times, how to deal with it.
with accelerator.main_process_first(): with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map( vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset, postprocess_dataset,
num_proc=num_workers, num_proc=num_workers,
input_columns=["input_ids", "prompt_input_ids"],
desc="Postprocessing labeling", desc="Postprocessing labeling",
with_indices=True, with_indices=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment