Commit d9e50b17 authored by Soumith Chintala's avatar Soumith Chintala
Browse files

push get_file_path patches to test

parent c74b79c8
...@@ -3,6 +3,7 @@ import unittest ...@@ -3,6 +3,7 @@ import unittest
import os import os
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
from torch._utils_internal import get_file_path_2
def mock_transform(return_value, arg_list): def mock_transform(return_value, arg_list):
...@@ -13,10 +14,10 @@ def mock_transform(return_value, arg_list): ...@@ -13,10 +14,10 @@ def mock_transform(return_value, arg_list):
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
root = 'test/assets/dataset/' root = get_file_path_2('test/assets/dataset/')
classes = ['a', 'b'] classes = ['a', 'b']
class_a_images = [os.path.join('test/assets/dataset/a/', path) for path in ['a1.png', 'a2.png', 'a3.png']] class_a_images = [get_file_path_2(os.path.join('test/assets/dataset/a/', path)) for path in ['a1.png', 'a2.png', 'a3.png']]
class_b_images = [os.path.join('test/assets/dataset/b/', path) for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']] class_b_images = [get_file_path_2(os.path.join('test/assets/dataset/b/', path)) for path in ['b1.png', 'b2.png', 'b3.png', 'b4.png']]
def test_image_folder(self): def test_image_folder(self):
dataset = ImageFolder(Tester.root, loader=lambda x: x) dataset = ImageFolder(Tester.root, loader=lambda x: x)
...@@ -34,7 +35,7 @@ class Tester(unittest.TestCase): ...@@ -34,7 +35,7 @@ class Tester(unittest.TestCase):
self.assertEqual(imgs, outputs) self.assertEqual(imgs, outputs)
def test_transform(self): def test_transform(self):
return_value = 'test/assets/dataset/a/a1.png' return_value = get_file_path_2('test/assets/dataset/a/a1.png')
args = [] args = []
transform = mock_transform(return_value, args) transform = mock_transform(return_value, args)
......
import torch import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2
import unittest import unittest
import math import math
import random import random
...@@ -16,7 +17,7 @@ try: ...@@ -16,7 +17,7 @@ try:
except ImportError: except ImportError:
stats = None stats = None
GRACE_HOPPER = 'assets/grace_hopper_517x606.jpg' GRACE_HOPPER = get_file_path_2('assets/grace_hopper_517x606.jpg')
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment