"test/old-api/spectests.cpp" did not exist on "fa0af88dfeef3c6ed06296b34989d548032b13f0"
Commit 046e3fc9 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

return attn mask

parent 294c162d
...@@ -614,13 +614,15 @@ def main(): ...@@ -614,13 +614,15 @@ def main():
def train_transforms(batch): def train_transforms(batch):
"""Apply train_transforms across a batch.""" """Apply train_transforms across a batch."""
subsampled_wavs = [] audios = [audio["array"] for audio in batch["audio"]]
for audio in batch["audio"]: inputs = feature_extractor(
wav = random_subsample(audio["array"], max_length=data_args.max_length_seconds, sample_rate=sampling_rate) audios, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
subsampled_wavs.append(wav) )
inputs = feature_extractor(subsampled_wavs, sampling_rate=sampling_rate) output_batch = {
output_batch = {model_input_name: inputs.get(model_input_name)} model_input_name: inputs.get(model_input_name),
output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]] "attention_mask": inputs.get("attention_mask"),
"labels": [int(label2id[label]) for label in batch["labels"]],
}
return output_batch return output_batch
if training_args.do_train: if training_args.do_train:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment