Commit 046e3fc9 authored by sanchit-gandhi
Browse files

return attn mask

parent 294c162d
......@@ -614,13 +614,15 @@ def main():
def train_transforms(batch):
    """Apply the feature extractor across a batch of training examples.

    Args:
        batch: Mapping with an ``"audio"`` entry (a sequence of items, each
            exposing its samples under ``"array"``) and a ``"labels"`` entry
            (a sequence of keys of ``label2id``).

    Returns:
        A dict with three entries: the extractor output stored under
        ``model_input_name``, the extractor's ``"attention_mask"`` entry,
        and ``"labels"`` converted to integers through ``label2id``.
    """
    # Only the full-length audio path is kept: the earlier random-subsample
    # pass had its ``inputs``/``output_batch`` rebound before they were read,
    # so it never contributed to the returned batch.
    audios = [audio["array"] for audio in batch["audio"]]
    inputs = feature_extractor(
        audios, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
    )
    # NOTE(review): presumably ``inputs.get`` yields None when no attention
    # mask was requested -- confirm downstream collation tolerates that.
    output_batch = {
        model_input_name: inputs.get(model_input_name),
        "attention_mask": inputs.get("attention_mask"),
        "labels": [int(label2id[label]) for label in batch["labels"]],
    }
    return output_batch
if training_args.do_train:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment