Unverified Commit 1045a36c authored by Mario Šaško's avatar Mario Šaško Committed by GitHub
Browse files

Fix pytorch image classification example (#14883)

* Update example

* Remove skip in tests
parent 7df4b90c
......@@ -279,12 +279,14 @@ def main():
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
]
return example_batch
def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
return example_batch
if training_args.do_train:
......
......@@ -19,7 +19,6 @@ import json
import logging
import os
import sys
import unittest
from unittest.mock import patch
import torch
......@@ -409,7 +408,6 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_bleu"], 30)
@unittest.skip("Fix me Nate!")
def test_run_image_classification(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment