Commit b62f9f1e authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

apply random cropping to eval to prevent oom

parent 32092d53
......@@ -592,21 +592,13 @@ def main():
output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
return output_batch
def val_transforms(batch):
"""Apply val_transforms across a batch."""
wavs = [audio["array"] for audio in batch["audio"]]
inputs = feature_extractor(wavs, sampling_rate=sampling_rate)
output_batch = {model_input_name: inputs.get(model_input_name)}
output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
return output_batch
if training_args.do_train:
# Set the training transforms
raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
if training_args.do_eval:
# Set the validation transforms
raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
raw_datasets["eval"].set_transform(train_transforms, output_all_columns=False)
# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment