Unverified Commit 13489248 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Examples] Generalise run audio classification for log-mel models (#21756)

* [Examples] Generalise run audio classification for log-mel models

* batch feature extractor

* make style
parent f7ca656f
...@@ -289,24 +289,27 @@ def main(): ...@@ -289,24 +289,27 @@ def main():
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
) )
model_input_name = feature_extractor.model_input_names[0]
def train_transforms(batch): def train_transforms(batch):
"""Apply train_transforms across a batch.""" """Apply train_transforms across a batch."""
output_batch = {"input_values": []} subsampled_wavs = []
for audio in batch[data_args.audio_column_name]: for audio in batch[data_args.audio_column_name]:
wav = random_subsample( wav = random_subsample(
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
) )
output_batch["input_values"].append(wav) subsampled_wavs.append(wav)
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = list(batch[data_args.label_column_name]) output_batch["labels"] = list(batch[data_args.label_column_name])
return output_batch return output_batch
def val_transforms(batch): def val_transforms(batch):
"""Apply val_transforms across a batch.""" """Apply val_transforms across a batch."""
output_batch = {"input_values": []} wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
for audio in batch[data_args.audio_column_name]: inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
wav = audio["array"] output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["input_values"].append(wav)
output_batch["labels"] = list(batch[data_args.label_column_name]) output_batch["labels"] = list(batch[data_args.label_column_name])
return output_batch return output_batch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment