"vscode:/vscode.git/clone" did not exist on "58631803e5ab484a0a083ba43d5f5507b0d70c4f"
Commit 84e0def5 authored by Yoach Lacombe

fix encodec collator

parent b09eba24
@@ -495,13 +495,10 @@ class DataCollatorEncodecWithPadding:
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        audios = [torch.tensor(feature[self.audio_column_name]).squeeze() for feature in features]
+        audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
-        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
-        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
-        batch["padding_mask"] = batch.pop("attention_mask")
+        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
         return batch
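
For context, a minimal, self-contained sketch of what the fixed collator body does when fed decoded rows from a datasets-style audio column. The toy `features` list, the default `EncodecFeatureExtractor` configuration, and the explicit `sampling_rate` argument are illustrative assumptions, not part of the commit:

# Hypothetical sketch mirroring the patched __call__ above; assumes
# transformers' EncodecFeatureExtractor and rows whose audio column decodes
# to a dict with an "array" entry, as the HF datasets Audio feature does.
import numpy as np
import torch
from transformers import EncodecFeatureExtractor

feature_extractor = EncodecFeatureExtractor()  # defaults to 24 kHz mono

# Two toy rows standing in for decoded dataset examples of different lengths.
features = [
    {"audio": {"array": np.random.randn(24000), "sampling_rate": 24000}},
    {"audio": {"array": np.random.randn(12000), "sampling_rate": 24000}},
]

# Same steps as the patched collator: pull the raw arrays, record their
# original lengths, and let the feature extractor handle padding and masking.
audios = [feature["audio"]["array"] for feature in features]
len_audio = [len(audio) for audio in audios]
batch = feature_extractor(audios, sampling_rate=24000, return_tensors="pt", padding="longest")
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)

# Inspect whatever the extractor returned (input values, padding mask, ...).
for key, value in batch.items():
    print(key, tuple(value.shape))
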
@@ -1083,17 +1080,15 @@ def main():
     # this is a trick to avoid to rewrite the entire audio column which takes ages
     cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
     for split in raw_datasets:
-        vectorized_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
+        raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
-    # TODO: remove
-    logger.info(f"Vectorized datasets {vectorized_datasets}")
     with accelerator.main_process_first():
         def is_audio_in_length_range(length):
             return length > min_target_length and length < max_target_length
         # filter data that is shorter than min_target_length
-        vectorized_datasets = vectorized_datasets.filter(
+        vectorized_datasets = raw_datasets.filter(
             is_audio_in_length_range,
             num_proc=num_workers,
             input_columns=["target_length"],
...
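
The second hunk relies on datasets' concatenate_datasets(..., axis=1) to bolt newly computed columns onto the original dataset without rewriting the audio column, then filters on one of the new columns via Dataset.filter with input_columns. A toy sketch of that pattern follows; the column names and length bounds are illustrative, not taken from the training script:

# Hypothetical sketch of the column-wise concatenation trick; assumes the
# Hugging Face `datasets` library.
from datasets import Dataset, concatenate_datasets

raw = Dataset.from_dict({"audio": ["a.wav", "b.wav", "c.wav"], "text": ["x", "y", "z"]})
tmp = Dataset.from_dict({"target_length": [24000, 12000, 48000]})  # e.g. precomputed lengths

# Keep only the expensive-to-rewrite column, then append the new columns
# side by side (axis=1) instead of re-mapping the whole dataset.
cols_to_remove = [col for col in raw.column_names if col != "audio"]
combined = concatenate_datasets([raw.remove_columns(cols_to_remove), tmp], axis=1)

# Filter on the new column, mirroring is_audio_in_length_range in the diff.
min_target_length, max_target_length = 10000, 30000
filtered = combined.filter(
    lambda length: min_target_length < length < max_target_length,
    input_columns=["target_length"],
)
print(combined.column_names)  # ['audio', 'target_length']
print(filtered.num_rows)      # 2 -- the 48000-sample row is dropped
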