Commit b09eba24 authored by yoach@huggingface.co's avatar yoach@huggingface.co
Browse files

compute audio features in the collator instead of during dataset preprocessing

parent 441af9a4
......@@ -488,16 +488,18 @@ class DataCollatorEncodecWithPadding:
"""
feature_extractor: AutoFeatureExtractor
audio_column_name: str
feature_extractor_input_name: Optional[str] = "input_values"
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
audios = [torch.tensor(feature[self.audio_column_name]).squeeze() for feature in features]
len_audio = [len(audio) for audio in audios]
input_features = {self.feature_extractor_input_name: audios}
batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1) # add mono-channel
batch["padding_mask"] = batch.pop("attention_mask")
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
......@@ -1032,7 +1034,7 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
# TODO: remove
# TODO: remove when releasing
# Test all gather - used for warmup and avoiding timeout
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
......@@ -1062,24 +1064,30 @@ def main():
text = batch[prompt_column_name]
batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
# load audio
target_sample = batch[target_audio_column_name]
arr = target_sample["array"]
labels = feature_extractor(arr[:min(len(arr), max_target_length+10)], sampling_rate=target_sample["sampling_rate"])
batch["labels"] = labels["input_values"]
# take length of raw audio waveform
batch["target_length"] = len(target_sample["array"].squeeze())
batch["target_length"] = len(batch[target_audio_column_name]["array"].squeeze())
return batch
with accelerator.main_process_first():
vectorized_datasets = raw_datasets.map(
# this is a trick to avoid rewriting the entire audio column, which takes ages
tmp_datasets = raw_datasets.map(
pass_through_processors,
remove_columns=next(iter(raw_datasets.values())).column_names,
num_proc=num_workers,
desc="preprocess datasets",
# cache_file_names={"train": "/scratch/train.arrow", "eval":"/scratch/eval.arrow"} , # TODO: remove - specific to cluster
)
# only keep audio column from the raw datasets
# this is a trick to avoid rewriting the entire audio column, which takes ages
cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
for split in raw_datasets:
vectorized_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
# TODO: remove
logger.info(f"Vectorized datasets {vectorized_datasets}")
with accelerator.main_process_first():
def is_audio_in_length_range(length):
    """Return True when `length` lies strictly between the configured bounds.

    Both bounds are exclusive and are read from the enclosing scope
    (`min_target_length` and `max_target_length`).
    """
    return min_target_length < length < max_target_length
......@@ -1150,7 +1158,7 @@ def main():
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
audio_decoder = model.audio_encoder
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name)
def apply_audio_decoder(batch):
len_audio = batch.pop("len_audio")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment